diff --git a/.clang-format b/.clang-format
index c7370bb66a..018938c588 100644
--- a/.clang-format
+++ b/.clang-format
@@ -1,156 +1,255 @@
----
-Language: Cpp
-# BasedOnStyle: Google
-AccessModifierOffset: -1
-AlignAfterOpenBracket: Align
-AlignConsecutiveMacros: false
-AlignConsecutiveAssignments: false
-AlignConsecutiveDeclarations: false
-AlignEscapedNewlines: Left
-AlignOperands: true
-AlignTrailingComments: true
-AllowAllArgumentsOnNextLine: true
-AllowAllConstructorInitializersOnNextLine: true
-AllowAllParametersOfDeclarationOnNextLine: true
-AllowShortBlocksOnASingleLine: false
-AllowShortCaseLabelsOnASingleLine: false
-AllowShortFunctionsOnASingleLine: All
-AllowShortLambdasOnASingleLine: All
-AllowShortIfStatementsOnASingleLine: WithoutElse
-AllowShortLoopsOnASingleLine: true
-AlwaysBreakAfterDefinitionReturnType: None
-AlwaysBreakAfterReturnType: None
-AlwaysBreakBeforeMultilineStrings: true
-AlwaysBreakTemplateDeclarations: Yes
-BinPackArguments: true
-BinPackParameters: true
-BraceWrapping:
-  AfterCaseLabel: false
-  AfterClass: false
-  AfterControlStatement: false
-  AfterEnum: false
-  AfterFunction: false
-  AfterNamespace: false
-  AfterObjCDeclaration: false
-  AfterStruct: false
-  AfterUnion: false
-  AfterExternBlock: false
-  BeforeCatch: false
-  BeforeElse: false
-  IndentBraces: false
+# Reference: https://clang.llvm.org/docs/ClangFormatStyleOptions.html
+
+# Disable formatting
+DisableFormat: false
+
+# Base style
+BasedOnStyle: LLVM
+
+# Language: None, Cpp, Java, JavaScript, ObjC, Proto, TableGen, TextProto
+Language: Cpp
+
+# Standard: Cpp03, Cpp11, Auto
+Standard: Cpp11
+
+# Tab width
+TabWidth: 4
+
+# Use tab characters: Never, ForIndentation, ForContinuationAndIndentation, Always
+UseTab: Never
+
+# Offset of access specifiers (public, private, etc.)
+AccessModifierOffset: -2
+
+# Indent width
+IndentWidth: 4
+
+# Indent width of constructor initializer lists
+ConstructorInitializerIndentWidth: 4
+
+# Minimum indent width of continuation lines
+ContinuationIndentWidth: 4
+
+# Indent case labels
+IndentCaseLabels: true
+
+# When the return type wraps, indent the function name of declarations and definitions
+IndentWrappedFunctionNames: true
+
+# Namespace indentation: None, Inner (indent contents of nested namespaces), All
+NamespaceIndentation: All
+
+# Preprocessor directive indentation: None, AfterHash, BeforeHash
+IndentPPDirectives: BeforeHash
+
+# Alignment after open brackets (round, angle, and square): Align, DontAlign,
+# AlwaysBreak (always break after an open bracket)
+AlignAfterOpenBracket: Align
+
+# Align the equals signs of consecutive assignments
+#AlignConsecutiveAssignments: AcrossEmptyLinesAndComments
+AlignConsecutiveAssignments: AcrossComments
+
+# Align the variable names of consecutive declarations
+AlignConsecutiveDeclarations: AcrossEmptyLinesAndComments
+#AlignConsecutiveDeclarations: AcrossComments
+
+#AlignEscapedNewlines: Right
+
+# Left-align the backslashes of escaped newlines (backslash line continuations)
+#AlignEscapedNewlinesLeft: true
+
+# Horizontally align operands of binary and ternary expressions
+AlignOperands: true
+
+# Align consecutive trailing comments
+AlignTrailingComments: true
+
+# Pointer and reference alignment: Left, Right, Middle
+PointerAlignment: Left
+
+# Derive the most common pointer and reference alignment from the file
+DerivePointerAlignment: false
+
+# Allow all parameters of a function declaration to be put onto the next line
+AllowAllParametersOfDeclarationOnNextLine: false
+
+# false means call arguments are either all on one line or each on its own line
+BinPackArguments: false
+
+# false means parameters are either all on one line or each on its own line
+BinPackParameters: false
+
+# Allow all arguments of a call to be put onto the next line even when
+# BinPackParameters is false
+AllowAllArgumentsOnNextLine: false
+
+# Allow short blocks on a single line
+AllowShortBlocksOnASingleLine: true
+
+# Allow short case labels on a single line
+AllowShortCaseLabelsOnASingleLine: true
+
+# Allow short functions on a single line: None, InlineOnly (defined in a class),
+# Empty (empty functions), Inline (defined in a class, or empty), All
+AllowShortFunctionsOnASingleLine: Empty
+
+# Allow short if statements on a single line
+AllowShortIfStatementsOnASingleLine: true
+
+# Allow short loops on a single line
+AllowShortLoopsOnASingleLine: true
+
+# Always break after the return type of a definition (deprecated)
+AlwaysBreakAfterDefinitionReturnType: None
+
+# Always break after the return type: None, All, TopLevel (top-level functions,
+# excluding member functions), AllDefinitions (all definitions, excluding
+# declarations), TopLevelDefinitions (definitions of all top-level functions)
+AlwaysBreakAfterReturnType: None
+
+# Always break before multiline string literals
+AlwaysBreakBeforeMultilineStrings: false
+
+# Always break after template declarations
+AlwaysBreakTemplateDeclarations: true
+
+# Constructor initializers are either all on one line or each on its own line
+ConstructorInitializerAllOnOneLineOrOnePerLine: false
+
+# Break constructor initializer lists before the commas and align the initializers
+BreakConstructorInitializers: BeforeComma
+
+# Automatically detect whether calls and definitions are formatted with one
+# argument per line (experimental)
+ExperimentalAutoDetectBinPacking: true
+
+# Remove spaces after { and before } in C++11 braced list initialization
+Cpp11BracedListStyle: true
+
+# Brace wrapping; only effective when BreakBeforeBraces is set to Custom
+BraceWrapping:
+  # after class definitions
+  AfterClass: true
+  # after control statements
+  AfterControlStatement: true
+  # after enum definitions
+  AfterEnum: true
+  # after function definitions
+  AfterFunction: true
+  # after namespace definitions
+  AfterNamespace: true
+  # after ObjC declarations
+  AfterObjCDeclaration: true
+  # after struct definitions
+  AfterStruct: true
+  # after union definitions
+  AfterUnion: true
+  AfterExternBlock: true
+  # before catch
+  BeforeCatch: true
+  # before else
+  BeforeElse: true
+  # indent wrapped braces
+  IndentBraces: false
   SplitEmptyFunction: true
   SplitEmptyRecord: true
   SplitEmptyNamespace: true
-BreakBeforeBinaryOperators: None
-BreakBeforeBraces: Attach
-BreakBeforeInheritanceComma: false
-BreakInheritanceList: BeforeColon
-BreakBeforeTernaryOperators: true
-BreakConstructorInitializersBeforeComma: false
-BreakConstructorInitializers: BeforeColon
-BreakAfterJavaFieldAnnotations: false
-BreakStringLiterals: true
-ColumnLimit: 100
-CommentPragmas: '^ IWYU pragma:'
+
+# Break before binary operators: None (break after operators), NonAssignment
+# (break before operators that are not assignments), All (break before operators)
+BreakBeforeBinaryOperators: None
+
+# Break before braces: Attach (always attach braces to the surrounding context),
+# Linux (like Attach, except for function, namespace, and class definitions),
+# Mozilla (like Attach, except for enum, function, and record definitions),
+# Stroustrup (like Attach, except for function definitions, catch, and else),
+# Allman (always break before braces), GNU (always break before braces, with
+# extra indentation for the braces of control statements), WebKit (break before
+# function braces), Custom
+# Note: statement blocks are treated as functions here
+BreakBeforeBraces: Allman
+
+# Break before ternary operators
+BreakBeforeTernaryOperators: false
+
+# Break string literals
+BreakStringLiterals: false
+
+# Column limit; 0 means no limit
+ColumnLimit: 0
+
+# Penalty for breaking around an assignment
+PenaltyBreakAssignment: 100
+
+# Penalty for breaking a function call after call(
+PenaltyBreakBeforeFirstCallParameter: 100
+
+# Penalty for introducing a line break inside a comment
+PenaltyBreakComment: 100
+
+# Penalty for breaking before the first <<
+PenaltyBreakFirstLessLess: 100
+
+# Penalty for introducing a line break inside a string literal
+PenaltyBreakString: 100
+
+# Penalty for each character outside the column limit
+PenaltyExcessCharacter: 100
+
+# Penalty for putting a function's return type on its own line
+PenaltyReturnTypeOnItsOwnLine: 100
+
+# Add a space after C-style casts
+SpaceAfterCStyleCast: false
+
+# Add a space after the template keyword
+SpaceAfterTemplateKeyword: false
+
+# Add a space before assignment operators
+SpaceBeforeAssignmentOperators: true
+
+# Add a space before opening parentheses: Never, ControlStatements, Always
+SpaceBeforeParens: ControlStatements
+
+# Number of spaces before trailing comments (applies to // only)
+SpacesBeforeTrailingComments: 2
+
+# Add spaces after < and before > in angle brackets
+SpacesInAngles: false
+
+# Add spaces in container literals (e.g. ObjC and JavaScript arrays and dictionaries)
+SpacesInContainerLiterals: false
+
+# Add spaces in the parentheses of C-style casts
+SpacesInCStyleCastParentheses: false
+
+# Add spaces after ( and before )
+SpacesInParentheses: false
+
+# Add a space in empty parentheses
+SpaceInEmptyParentheses: false
+
+# Add spaces after [ and before ]; lambdas and unsized array declarations
+# are not affected
+SpacesInSquareBrackets: false
+
+# Penalty for each character of indentation whitespace
+PenaltyIndentedWhitespace: 10
+
+# Regular expression describing comments with special meaning that should not
+# be split across lines or otherwise changed
+CommentPragmas: '^ IWYU pragma:'
+
+# Merge consecutive namespace declarations
 CompactNamespaces: false
-ConstructorInitializerAllOnOneLineOrOnePerLine: true
-ConstructorInitializerIndentWidth: 4
-ContinuationIndentWidth: 4
-Cpp11BracedListStyle: true
-DerivePointerAlignment: true
-DisableFormat: false
-ExperimentalAutoDetectBinPacking: false
-FixNamespaceComments: true
-ForEachMacros:
-  - foreach
-  - Q_FOREACH
-  - BOOST_FOREACH
-IncludeBlocks: Regroup
-IncludeCategories:
-  - Regex: '^<ext/.*\.h>'
-    Priority: 2
-  - Regex: '^<.*\.h>'
-    Priority: 1
-  - Regex: '^<.*'
-    Priority: 2
-  - Regex: '.*'
-    Priority: 3
-IncludeIsMainRegex: '([-_](test|unittest))?$'
-IndentCaseLabels: true
-IndentPPDirectives: None
-IndentWidth: 2
-IndentWrappedFunctionNames: false
-JavaScriptQuotes: Leave
-JavaScriptWrapImports: true
-KeepEmptyLinesAtTheStartOfBlocks: false
-MacroBlockBegin: ''
-MacroBlockEnd: ''
-MaxEmptyLinesToKeep: 1
-NamespaceIndentation: None
-ObjCBinPackProtocolList: Never
-ObjCBlockIndentWidth: 2
-ObjCSpaceAfterProperty: false
-ObjCSpaceBeforeProtocolList: true
-PenaltyBreakAssignment: 2
-PenaltyBreakBeforeFirstCallParameter: 1
-PenaltyBreakComment: 300
-PenaltyBreakFirstLessLess: 120
-PenaltyBreakString: 1000
-PenaltyBreakTemplateDeclaration: 10
-PenaltyExcessCharacter: 1000000
-PenaltyReturnTypeOnItsOwnLine: 200
-PointerAlignment: Left
-RawStringFormats:
-  - Language: Cpp
-    Delimiters:
-      - cc
-      - CC
-      - cpp
-      - Cpp
-      - CPP
-      - 'c++'
-      - 'C++'
-    CanonicalDelimiter: ''
-    BasedOnStyle: google
-  - Language: TextProto
-    Delimiters:
-      - pb
-      - PB
-      - proto
-      - PROTO
-    EnclosingFunctions:
-      - EqualsProto
-      - EquivToProto
-      - PARSE_PARTIAL_TEXT_PROTO
-      - PARSE_TEST_PROTO
-      - PARSE_TEXT_PROTO
-      - ParseTextOrDie
-      - ParseTextProtoOrDie
-    CanonicalDelimiter: ''
-    BasedOnStyle: google
-ReflowComments: true
-SortIncludes: true
-SortUsingDeclarations: true
-SpaceAfterCStyleCast: false
-SpaceAfterLogicalNot: false
-SpaceAfterTemplateKeyword: true
-SpaceBeforeAssignmentOperators: true
-SpaceBeforeCpp11BracedList: false
-SpaceBeforeCtorInitializerColon: true
-SpaceBeforeInheritanceColon: true
-SpaceBeforeParens: ControlStatements
-SpaceBeforeRangeBasedForLoopColon: true
-SpaceInEmptyParentheses: false
-SpacesBeforeTrailingComments: 2
-SpacesInAngles: false
-SpacesInContainerLiterals: true
-SpacesInCStyleCastParentheses: false
-SpacesInParentheses: false
-SpacesInSquareBrackets: false
-Standard: Auto
-StatementMacros:
-  - Q_UNUSED
-  - QT_REQUIRE_VERSION
-TabWidth: 8
-UseTab: Never
-...
+
+# Keep empty lines at the start of blocks
+KeepEmptyLinesAtTheStartOfBlocks: false
+
+# Maximum number of consecutive empty lines
+MaxEmptyLinesToKeep: 2
+
+# Allow reflowing comments
+ReflowComments: true
+
+# Allow sorting #include directives
+SortIncludes: false
+
+# Sort #includes by category: an #include matching a regex gets that category's
+# priority, and unmatched ones default to INT_MAX (lower priorities sort first).
+# Negative priorities can be used to keep certain #includes at the very top.
+IncludeCategories:
+  - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
+    Priority: 2
+  - Regex: '^(<|"(gtest|isl|json)/)'
+    Priority: 3
+  - Regex: '.*'
+    Priority: 1
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 2d78033dcf..452e541fc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -113,7 +113,7 @@ venv*
 data/
 data
 
-.vscode
+# .vscode
 .idea
 .DS_Store
 
@@ -172,3 +172,5 @@ demo/csharp/*/Properties
 
 # doxygen
 docs/cppapi/docs
+
+*debug*
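Note: the rewritten .clang-format moves the code base from the Google preset to an LLVM-based custom style (Allman braces, 4-space indent, no column limit). A minimal sketch of applying it tree-wide from CMake; the target name and source glob are assumptions for illustration, not part of this diff, and it only presumes clang-format is on PATH and picks the config up via `--style=file`:

```cmake
# Hypothetical formatting helper; not part of this change.
find_program(CLANG_FORMAT_EXE NAMES clang-format)
if(CLANG_FORMAT_EXE)
  file(GLOB_RECURSE _FORMAT_SOURCES ${CMAKE_SOURCE_DIR}/csrc/*.cpp
       ${CMAKE_SOURCE_DIR}/csrc/*.h)
  # --style=file makes clang-format search upward for the .clang-format above
  add_custom_target(
    clang-format
    COMMAND ${CLANG_FORMAT_EXE} -i --style=file ${_FORMAT_SOURCES}
    COMMENT "Reformatting csrc/ with the repository .clang-format")
endif()
```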
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a241d90674..a4252ef0d0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,12 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-if (NOT DEFINED CMAKE_INSTALL_PREFIX)
-    set(CMAKE_INSTALL_PREFIX "${CMAKE_BINARY_DIR}/install" CACHE PATH "installation directory")
-endif ()
+if(NOT DEFINED CMAKE_INSTALL_PREFIX)
+  set(CMAKE_INSTALL_PREFIX
+      "${CMAKE_BINARY_DIR}/install"
+      CACHE PATH "installation directory")
+endif()
 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
 
-if (NOT CMAKE_BUILD_TYPE)
-    set(CMAKE_BUILD_TYPE Release CACHE STRING "choose 'Release' as default build type" FORCE)
-endif ()
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE
+      Release
+      CACHE STRING "choose 'Release' as default build type" FORCE)
+endif()
 
 cmake_minimum_required(VERSION 3.14)
 project(MMDeploy VERSION 1.3.1)
@@ -18,11 +22,11 @@ set(MMDEPLOY_VERSION_MINOR ${PROJECT_VERSION_MINOR})
 set(MMDEPLOY_VERSION_PATCH ${PROJECT_VERSION_PATCH})
 
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
-if (MSVC)
-    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-else ()
-    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
-endif ()
+if(MSVC)
+  set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+else()
+  set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+endif()
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 
 # options
@@ -41,141 +45,146 @@ option(MMDEPLOY_COVERAGE "build SDK for coverage" OFF)
 option(MMDEPLOY_USE_MSCV_STATIC "statically linked CRT" OFF)
 option(MMDEPLOY_ELENA_FUSION "use elena to fuse preprocess" OFF)
 
-set(MMDEPLOY_TARGET_DEVICES "cpu" CACHE STRING "target devices to support")
-set(MMDEPLOY_TARGET_BACKENDS "" CACHE STRING "target inference engines to support")
-set(MMDEPLOY_CODEBASES "all" CACHE STRING "select OpenMMLab codebases")
-
-if ((NOT MMDEPLOY_BUILD_SDK_MONOLITHIC) AND MMDEPLOY_DYNAMIC_BACKEND)
-    set(MMDEPLOY_DYNAMIC_BACKEND OFF)
-endif ()
+set(MMDEPLOY_TARGET_DEVICES
+    "cpu"
+    CACHE STRING "target devices to support")
+set(MMDEPLOY_TARGET_BACKENDS
+    ""
+    CACHE STRING "target inference engines to support")
+set(MMDEPLOY_CODEBASES
+    "all"
+    CACHE STRING "select OpenMMLab codebases")
+
+if((NOT MMDEPLOY_BUILD_SDK_MONOLITHIC) AND MMDEPLOY_DYNAMIC_BACKEND)
+  set(MMDEPLOY_DYNAMIC_BACKEND OFF)
+endif()
 
-if (MMDEPLOY_SHARED_LIBS)
-    set(MMDEPLOY_LIB_TYPE SHARED)
-else ()
-    set(MMDEPLOY_LIB_TYPE STATIC)
-endif ()
+if(MMDEPLOY_SHARED_LIBS)
+  set(MMDEPLOY_LIB_TYPE SHARED)
+else()
+  set(MMDEPLOY_LIB_TYPE STATIC)
+endif()
 
-set(MMDEPLOY_TASKS "" CACHE INTERNAL "")
+set(MMDEPLOY_TASKS
+    ""
+    CACHE INTERNAL "")
 
-if (MMDEPLOY_COVERAGE)
-    add_compile_options(-coverage -fprofile-arcs -ftest-coverage)
-    add_link_options(-coverage -lgcov)
-endif ()
+if(MMDEPLOY_COVERAGE)
+  add_compile_options(-coverage -fprofile-arcs -ftest-coverage)
+  add_link_options(-coverage -lgcov)
+endif()
 
-# when CUDA devices are enabled, the environment variable ASAN_OPTIONS=protect_shadow_gap=0
-# must be set at runtime
-if (MMDEPLOY_ASAN_ENABLE)
-    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>)
-    add_link_options(-fsanitize=address)
-endif ()
+# when CUDA devices are enabled, the environment variable
+# ASAN_OPTIONS=protect_shadow_gap=0 must be set at runtime
+if(MMDEPLOY_ASAN_ENABLE)
+  add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>)
+  add_link_options(-fsanitize=address)
+endif()
 
 # notice that ubsan has linker issues for ubuntu < 18.04, see
 # https://stackoverflow.com/questions/50024731/ld-unrecognized-option-push-state-no-as-needed
-if (MMDEPLOY_UBSAN_ENABLE)
-    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=undefined>)
-    add_link_options(-fsanitize=undefined)
-endif ()
-
-if (MSVC)
-    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/diagnostics:classic>)
-    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/wd4251>)
-    if (MMDEPLOY_USE_MSCV_STATIC)
-        foreach(lang C CXX)
-            string(REPLACE /MD /MT CMAKE_${lang}_FLAGS_DEBUG "${CMAKE_${lang}_FLAGS_DEBUG}")
-            string(REPLACE /MD /MT CMAKE_${lang}_FLAGS_RELEASE "${CMAKE_${lang}_FLAGS_RELEASE}")
-        endforeach()
-    endif ()
-endif ()
+if(MMDEPLOY_UBSAN_ENABLE)
+  add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=undefined>)
+  add_link_options(-fsanitize=undefined)
+endif()
+
+if(MSVC)
+  add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/diagnostics:classic>)
+  add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/wd4251>)
+  if(MMDEPLOY_USE_MSCV_STATIC)
+    foreach(lang C CXX)
+      string(REPLACE /MD /MT CMAKE_${lang}_FLAGS_DEBUG
+                     "${CMAKE_${lang}_FLAGS_DEBUG}")
+      string(REPLACE /MD /MT CMAKE_${lang}_FLAGS_RELEASE
+                     "${CMAKE_${lang}_FLAGS_RELEASE}")
+    endforeach()
+  endif()
+endif()
 
 if(APPLE)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fobjc-arc")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fobjc-arc")
 endif()
 
 add_library(MMDeployStaticModules INTERFACE)
 add_library(MMDeployDynamicModules INTERFACE)
 add_library(MMDeployLibs INTERFACE)
 
-if ((cuda IN_LIST MMDEPLOY_TARGET_DEVICES) OR (trt IN_LIST MMDEPLOY_TARGET_BACKENDS))
-    include(cmake/cuda.cmake NO_POLICY_SCOPE)
-endif ()
+if((cuda IN_LIST MMDEPLOY_TARGET_DEVICES) OR (trt IN_LIST
+                                              MMDEPLOY_TARGET_BACKENDS))
+  include(cmake/cuda.cmake NO_POLICY_SCOPE)
+endif()
 
-# this must come after including cuda.cmake because policies in function scope is captured
-# at function definition
+# this must come after including cuda.cmake because policies in function scope
+# are captured at function definition
 include(cmake/MMDeploy.cmake)
 
 add_subdirectory(csrc/mmdeploy)
 
-if (MMDEPLOY_BUILD_SDK)
-    if (NOT MMDEPLOY_BUILD_SDK_MONOLITHIC)
-        install(TARGETS MMDeployStaticModules
-                MMDeployDynamicModules
-                MMDeployLibs
-                EXPORT MMDeployTargets)
-    endif ()
-
-    if (MMDEPLOY_BUILD_TEST)
-        add_subdirectory(tests/test_csrc)
-    endif ()
-
-    if (MMDEPLOY_BUILD_EXAMPLES)
-        include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake)
-        add_subdirectory(demo/csrc)
-    endif ()
-
-    # export MMDeploy package
-    install(EXPORT MMDeployTargets
-            FILE MMDeployTargets.cmake
+if(MMDEPLOY_BUILD_SDK)
+  if(NOT MMDEPLOY_BUILD_SDK_MONOLITHIC)
+    install(TARGETS MMDeployStaticModules MMDeployDynamicModules MMDeployLibs
+            EXPORT MMDeployTargets)
+  endif()
+
+  if(MMDEPLOY_BUILD_TEST)
+    add_subdirectory(tests/test_csrc)
+  endif()
+
+  if(MMDEPLOY_BUILD_EXAMPLES)
+    include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake)
+    add_subdirectory(demo/csrc)
+  endif()
+
+  # export MMDeploy package
+  install(
+    EXPORT MMDeployTargets
+    FILE MMDeployTargets.cmake
+    DESTINATION lib/cmake/MMDeploy)
+
+  if(MMDEPLOY_SPDLOG_EXTERNAL)
+    set(SPDLOG_DEPENDENCY "find_package(spdlog QUIET)")
+  endif()
+  # append backend deps
+  mmdeploy_add_deps(trt BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS TENSORRT
+                    CUDNN)
+  mmdeploy_add_deps(ort BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS ONNXRUNTIME)
+  mmdeploy_add_deps(ncnn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS ncnn)
+  mmdeploy_add_deps(openvino BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS
+                    InferenceEngine)
+  if(NOT MMDEPLOY_SHARED_LIBS)
+    mmdeploy_add_deps(pplnn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS pplnn)
+  endif()
+  mmdeploy_add_deps(snpe BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS snpe)
+  mmdeploy_add_deps(rknn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS rknn)
+
+  include(CMakePackageConfigHelpers)
+  # generate the config file that includes the exports
+  configure_package_config_file(
+    ${CMAKE_SOURCE_DIR}/cmake/MMDeployConfig.cmake.in
+    "${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfig.cmake"
+    INSTALL_DESTINATION "lib/cmake"
+    NO_SET_AND_CHECK_MACRO NO_CHECK_REQUIRED_COMPONENTS_MACRO)
+
+  write_basic_package_version_file(
+    "${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfigVersion.cmake"
+    VERSION "${MMDeploy_VERSION_MAJOR}.${MMDeploy_VERSION_MINOR}"
+    COMPATIBILITY AnyNewerVersion)
+
+  install(
+    FILES ${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfig.cmake
+          ${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfigVersion.cmake
+          ${CMAKE_CURRENT_SOURCE_DIR}/cmake/MMDeploy.cmake
+    DESTINATION lib/cmake/MMDeploy)
+
+  if(MSVC)
+    install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/cmake/loader.cpp.in DESTINATION lib/cmake/MMDeploy)
+  endif()
+
+  install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules
+          DESTINATION lib/cmake/MMDeploy)
 
-    if (MMDEPLOY_SPDLOG_EXTERNAL)
-        set(SPDLOG_DEPENDENCY "find_package(spdlog QUIET)")
-    endif ()
-    # append backend deps
-    mmdeploy_add_deps(trt BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS TENSORRT CUDNN)
-    mmdeploy_add_deps(ort BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS ONNXRUNTIME)
-    mmdeploy_add_deps(ncnn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS ncnn)
-    mmdeploy_add_deps(openvino BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS InferenceEngine)
-    if (NOT MMDEPLOY_SHARED_LIBS)
-        mmdeploy_add_deps(pplnn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS pplnn)
-    endif ()
-    mmdeploy_add_deps(snpe BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS snpe)
-    mmdeploy_add_deps(rknn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS rknn)
-
-    include(CMakePackageConfigHelpers)
-    # generate the config file that is includes the exports
-    configure_package_config_file(${CMAKE_SOURCE_DIR}/cmake/MMDeployConfig.cmake.in
-            "${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfig.cmake"
-            INSTALL_DESTINATION "lib/cmake"
-            NO_SET_AND_CHECK_MACRO
-            NO_CHECK_REQUIRED_COMPONENTS_MACRO
-            )
-
-    write_basic_package_version_file(
-            "${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfigVersion.cmake"
-            VERSION "${MMDeploy_VERSION_MAJOR}.${MMDeploy_VERSION_MINOR}"
-            COMPATIBILITY AnyNewerVersion
-            )
-
-    install(FILES
-            ${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfig.cmake
-            ${CMAKE_CURRENT_BINARY_DIR}/MMDeployConfigVersion.cmake
-            ${CMAKE_CURRENT_SOURCE_DIR}/cmake/MMDeploy.cmake
-            DESTINATION lib/cmake/MMDeploy
-            )
-
-    if (MSVC)
-        install(FILES
-                ${CMAKE_CURRENT_SOURCE_DIR}/cmake/loader.cpp.in
-                DESTINATION lib/cmake/MMDeploy
-                )
-    endif ()
-
-    install(DIRECTORY
-            ${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules
-            DESTINATION lib/cmake/MMDeploy
-            )
-
-    if (${CMAKE_VERSION} VERSION_LESS "3.17.0")
-        install(SCRIPT cmake/post-install.cmake)
-    endif ()
-endif ()
+
+  if(${CMAKE_VERSION} VERSION_LESS "3.17.0")
+    install(SCRIPT cmake/post-install.cmake)
+  endif()
+endif()
diff --git a/cmake/MMDeploy.cmake b/cmake/MMDeploy.cmake
index 304c7b1bc1..30e15c4c7c 100644
--- a/cmake/MMDeploy.cmake
+++ b/cmake/MMDeploy.cmake
@@ -1,220 +1,228 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-function (mmdeploy_export_impl NAME)
-    set(_LIB_DIR lib)
-    if (MSVC)
-        set(_LIB_DIR bin)
-    endif ()
-    install(TARGETS ${NAME}
-            EXPORT MMDeployTargets
-            ARCHIVE DESTINATION lib
-            LIBRARY DESTINATION ${_LIB_DIR}
-            RUNTIME DESTINATION bin)
-endfunction ()
-
-function (mmdeploy_add_rpath NAME)
-    if (MSVC)
-        return()
-    elseif(APPLE)
-        set_target_properties(${NAME} PROPERTIES
-                INSTALL_RPATH "@loader_path"
-                BUILD_RPATH "@loader_path")
-    else ()
-        set_target_properties(${NAME} PROPERTIES
-                INSTALL_RPATH "\$ORIGIN"
-                BUILD_RPATH "\$ORIGIN")
-        target_link_libraries(${NAME} PRIVATE -Wl,--disable-new-dtags)
-    endif ()
-endfunction ()
+function(mmdeploy_export_impl NAME)
+  set(_LIB_DIR lib)
+  if(MSVC)
+    set(_LIB_DIR bin)
+  endif()
+  install(
+    TARGETS ${NAME}
+    EXPORT MMDeployTargets
+    ARCHIVE DESTINATION lib
+    LIBRARY DESTINATION ${_LIB_DIR}
+    RUNTIME DESTINATION bin)
+endfunction()
+
+function(mmdeploy_add_rpath NAME)
+  if(MSVC)
+    return()
+  elseif(APPLE)
+    set_target_properties(${NAME} PROPERTIES INSTALL_RPATH "@loader_path"
+                                             BUILD_RPATH "@loader_path")
+  else()
+    set_target_properties(${NAME} PROPERTIES INSTALL_RPATH "\$ORIGIN"
+                                             BUILD_RPATH "\$ORIGIN")
+    target_link_libraries(${NAME} PRIVATE -Wl,--disable-new-dtags)
+  endif()
+endfunction()
 
 macro(mmdeploy_add_net NAME)
-    if (MMDEPLOY_DYNAMIC_BACKEND)
-        mmdeploy_add_library(${NAME} SHARED ${ARGN})
-        mmdeploy_add_rpath(${NAME})
-        # DYNAMIC_BACKEND implies BUILD_SDK_MONOLITHIC
-        mmdeploy_export_impl(${NAME})
-        target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy)
-        set(BACKEND_LIB_NAMES ${BACKEND_LIB_NAMES} ${PROJECT_NAME} PARENT_SCOPE)
-    else ()
-        mmdeploy_add_module(${NAME} ${ARGN})
-    endif ()
+  if(MMDEPLOY_DYNAMIC_BACKEND)
+    mmdeploy_add_library(${NAME} SHARED ${ARGN})
+    mmdeploy_add_rpath(${NAME})
+    # DYNAMIC_BACKEND implies BUILD_SDK_MONOLITHIC
+    mmdeploy_export_impl(${NAME})
+    target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy)
+    set(BACKEND_LIB_NAMES
+        ${BACKEND_LIB_NAMES} ${PROJECT_NAME}
+        PARENT_SCOPE)
+  else()
+    mmdeploy_add_module(${NAME} ${ARGN})
+  endif()
 endmacro()
 
-function (mmdeploy_export NAME)
-    if (NOT MMDEPLOY_BUILD_SDK_MONOLITHIC)
-        mmdeploy_export_impl(${NAME})
-    endif ()
-endfunction ()
-
-
-function (mmdeploy_add_library NAME)
-    # EXCLUDE: exclude from registering & exporting
-    cmake_parse_arguments(_MMDEPLOY "EXCLUDE" "" "" ${ARGN})
-    # search for add_library keywords
-    cmake_parse_arguments(_TYPE "STATIC;SHARED;MODULE" "" "" ${_MMDEPLOY_UNPARSED_ARGUMENTS})
-    set(_MAYBE_TYPE)
-    if (NOT (_TYPE_STATIC OR _TYPE_SHARED OR _TYPE_MODULE))
-        set(_MAYBE_TYPE ${MMDEPLOY_LIB_TYPE})
-    endif ()
-    add_library(${NAME} ${_MAYBE_TYPE} ${_MMDEPLOY_UNPARSED_ARGUMENTS})
-    if (NOT MSVC)
-        target_compile_options(${NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
-    endif ()
+function(mmdeploy_export NAME)
+  if(NOT MMDEPLOY_BUILD_SDK_MONOLITHIC)
+    mmdeploy_export_impl(${NAME})
+  endif()
+endfunction()
+
+function(mmdeploy_add_library NAME)
+  # EXCLUDE: exclude from registering & exporting
+  cmake_parse_arguments(_MMDEPLOY "EXCLUDE" "" "" ${ARGN})
+  # search for add_library keywords
+  cmake_parse_arguments(_TYPE "STATIC;SHARED;MODULE" "" ""
+                        ${_MMDEPLOY_UNPARSED_ARGUMENTS})
+  set(_MAYBE_TYPE)
+  if(NOT
+     (_TYPE_STATIC
+      OR _TYPE_SHARED
+      OR _TYPE_MODULE))
+    set(_MAYBE_TYPE ${MMDEPLOY_LIB_TYPE})
+  endif()
+  add_library(${NAME} ${_MAYBE_TYPE} ${_MMDEPLOY_UNPARSED_ARGUMENTS})
+  if(NOT MSVC)
+    target_compile_options(
+      ${NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
+  endif()
+  target_compile_definitions(${NAME} PRIVATE -DMMDEPLOY_API_EXPORTS=1)
+  get_target_property(_TYPE ${NAME} TYPE)
+  if(_TYPE STREQUAL STATIC_LIBRARY)
+    set_target_properties(${NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
+  elseif(_TYPE STREQUAL SHARED_LIBRARY)
+
+  else()
+    message(FATAL_ERROR "unsupported type: ${_TYPE}")
+  endif()
+  if(NOT _MMDEPLOY_EXCLUDE)
+    target_link_libraries(MMDeployLibs INTERFACE ${NAME})
+    mmdeploy_export(${NAME})
+  endif()
+endfunction()
+
+function(mmdeploy_add_module NAME)
+  # EXCLUDE: exclude from registering & exporting as SDK module LIBRARY: the
+  # module is also a library (add_library with SHARED instead of MODULE)
+  cmake_parse_arguments(_MMDEPLOY "EXCLUDE;LIBRARY" "" "" ${ARGN})
+  # search for add_library keywords
+  cmake_parse_arguments(_TYPE "STATIC;SHARED;MODULE" "" ""
+                        ${_MMDEPLOY_UNPARSED_ARGUMENTS})
+
+  set(_MAYBE_TYPE)
+  # no library type specified
+  if(NOT
+     (_TYPE_STATIC
+      OR _TYPE_SHARED
+      OR _TYPE_MODULE))
+    # shared but not marked as a library, build module library so that no .lib
+    # dependency will be generated for MSVC
+    if(MSVC
+       AND MMDEPLOY_SHARED_LIBS
+       AND NOT _MMDEPLOY_LIBRARY)
+      set(_MAYBE_TYPE MODULE)
+    else()
+      set(_MAYBE_TYPE ${MMDEPLOY_LIB_TYPE})
+    endif()
+  endif()
+
+  add_library(${NAME} ${_MAYBE_TYPE} ${_MMDEPLOY_UNPARSED_ARGUMENTS})
+
+  if(NOT MSVC)
+    target_compile_options(
+      ${NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
+  endif()
+
+  # automatically link mmdeploy::core if exists
+  if(TARGET mmdeploy::core)
+    target_link_libraries(${NAME} PRIVATE mmdeploy::core)
+  endif()
+
+  # export public symbols when marked as a library
+  if(_MMDEPLOY_LIBRARY)
     target_compile_definitions(${NAME} PRIVATE -DMMDEPLOY_API_EXPORTS=1)
-    get_target_property(_TYPE ${NAME} TYPE)
-    if (_TYPE STREQUAL STATIC_LIBRARY)
-        set_target_properties(${NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
-    elseif (_TYPE STREQUAL SHARED_LIBRARY)
-    else ()
-        message(FATAL_ERROR "unsupported type: ${_TYPE}")
-    endif ()
-    if (NOT _MMDEPLOY_EXCLUDE)
-        target_link_libraries(MMDeployLibs INTERFACE ${NAME})
-        mmdeploy_export(${NAME})
-    endif ()
-endfunction ()
-
-
-function (mmdeploy_add_module NAME)
-    # EXCLUDE: exclude from registering & exporting as SDK module
-    # LIBRARY: the module is also a library (add_libray with SHARED instead of MODULE)
-    cmake_parse_arguments(_MMDEPLOY "EXCLUDE;LIBRARY" "" "" ${ARGN})
-    # search for add_library keywords
-    cmake_parse_arguments(_TYPE "STATIC;SHARED;MODULE" "" "" ${_MMDEPLOY_UNPARSED_ARGUMENTS})
-
-    set(_MAYBE_TYPE)
-    # no library type specified
-    if (NOT (_TYPE_STATIC OR _TYPE_SHARED OR _TYPE_MODULE))
-        # shared but not marked as a library, build module library so that no .lib dependency
-        # will be generated for MSVC
-        if (MSVC AND MMDEPLOY_SHARED_LIBS AND NOT _MMDEPLOY_LIBRARY)
-            set(_MAYBE_TYPE MODULE)
-        else ()
-            set(_MAYBE_TYPE ${MMDEPLOY_LIB_TYPE})
-        endif ()
-    endif ()
-
-    add_library(${NAME} ${_MAYBE_TYPE} ${_MMDEPLOY_UNPARSED_ARGUMENTS})
-
-    if (NOT MSVC)
-        target_compile_options(${NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
-    endif ()
-
-    # automatically link mmdeploy::core if exists
-    if (TARGET mmdeploy::core)
-        target_link_libraries(${NAME} PRIVATE mmdeploy::core)
-    endif ()
-
-    # export public symbols when marked as a library
-    if (_MMDEPLOY_LIBRARY)
-        target_compile_definitions(${NAME} PRIVATE -DMMDEPLOY_API_EXPORTS=1)
-    endif ()
-
-    get_target_property(_TYPE ${NAME} TYPE)
-    if (_TYPE STREQUAL STATIC_LIBRARY)
-        set_target_properties(${NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
-        if (MSVC)
-            target_link_options(${NAME} INTERFACE "/WHOLEARCHIVE:${NAME}")
-        endif ()
-        # register static modules
-        if (NOT _MMDEPLOY_EXCLUDE)
-            target_link_libraries(MMDeployStaticModules INTERFACE ${NAME})
-        endif ()
-    elseif (_TYPE STREQUAL SHARED_LIBRARY OR _TYPE STREQUAL MODULE_LIBRARY)
-        # register dynamic modules
-        if (NOT _MMDEPLOY_EXCLUDE)
-            target_link_libraries(MMDeployDynamicModules INTERFACE ${NAME})
-        endif ()
-    else ()
-        message(FATAL_ERROR "unsupported type: ${_TYPE}")
-    endif ()
-    if (NOT _MMDEPLOY_EXCLUDE)
-        mmdeploy_export(${NAME})
-    endif ()
-endfunction ()
-
-
-function (_mmdeploy_flatten_modules RETVAL)
-    set(_RETVAL)
-    foreach (ARG IN LISTS ARGN)
-        get_target_property(TYPE ${ARG} TYPE)
-        if (TYPE STREQUAL "INTERFACE_LIBRARY")
-            get_target_property(LIBS ${ARG} INTERFACE_LINK_LIBRARIES)
-            if (LIBS)
-                # pattern for 3.17+
-                list(FILTER LIBS EXCLUDE REGEX "^::@")
-                # pattern for 3.13-3.16
-                list(TRANSFORM LIBS REPLACE "(.+)::@.*" "\\1")
-                list(APPEND _RETVAL ${LIBS})
-            endif ()
-        else ()
-            list(APPEND _RETVAL ${ARG})
-        endif ()
-    endforeach ()
-    set(${RETVAL} ${_RETVAL} PARENT_SCOPE)
-endfunction ()
-
-
-function (mmdeploy_load_static NAME)
-    if (MSVC)
-        target_link_libraries(${NAME} PRIVATE ${ARGN})
-    else ()
-        _mmdeploy_flatten_modules(_MODULE_LIST ${ARGN})
-        if (APPLE)
-            foreach (module IN LISTS _MODULE_LIST)
-                target_link_libraries(${NAME} PRIVATE -force_load ${module})
-            endforeach ()
-        else ()
-            target_link_libraries(${NAME} PRIVATE
-                    -Wl,--whole-archive
-                    ${_MODULE_LIST}
-                    -Wl,--no-whole-archive)
-        endif ()
-    endif ()
-endfunction ()
-
-function (mmdeploy_load_dynamic NAME)
+  endif()
+
+  get_target_property(_TYPE ${NAME} TYPE)
+  if(_TYPE STREQUAL STATIC_LIBRARY)
+    set_target_properties(${NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
+    if(MSVC)
+      target_link_options(${NAME} INTERFACE "/WHOLEARCHIVE:${NAME}")
+    endif()
+    # register static modules
+    if(NOT _MMDEPLOY_EXCLUDE)
+      target_link_libraries(MMDeployStaticModules INTERFACE ${NAME})
+    endif()
+  elseif(_TYPE STREQUAL SHARED_LIBRARY OR _TYPE STREQUAL MODULE_LIBRARY)
+    # register dynamic modules
+    if(NOT _MMDEPLOY_EXCLUDE)
+      target_link_libraries(MMDeployDynamicModules INTERFACE ${NAME})
+    endif()
+  else()
+    message(FATAL_ERROR "unsupported type: ${_TYPE}")
+  endif()
+  if(NOT _MMDEPLOY_EXCLUDE)
+    mmdeploy_export(${NAME})
+  endif()
+endfunction()
+
+function(_mmdeploy_flatten_modules RETVAL)
+  set(_RETVAL)
+  foreach(ARG IN LISTS ARGN)
+    get_target_property(TYPE ${ARG} TYPE)
+    if(TYPE STREQUAL "INTERFACE_LIBRARY")
+      get_target_property(LIBS ${ARG} INTERFACE_LINK_LIBRARIES)
+      if(LIBS)
+        # pattern for 3.17+
+        list(FILTER LIBS EXCLUDE REGEX "^::@")
+        # pattern for 3.13-3.16
+        list(TRANSFORM LIBS REPLACE "(.+)::@.*" "\\1")
+        list(APPEND _RETVAL ${LIBS})
+      endif()
+    else()
+      list(APPEND _RETVAL ${ARG})
+    endif()
+  endforeach()
+  set(${RETVAL}
+      ${_RETVAL}
+      PARENT_SCOPE)
+endfunction()
+
+function(mmdeploy_load_static NAME)
+  if(MSVC)
+    target_link_libraries(${NAME} PRIVATE ${ARGN})
+  else()
     _mmdeploy_flatten_modules(_MODULE_LIST ${ARGN})
-    if (MSVC)
-        if (NOT _MODULE_LIST)
-            return ()
-        endif ()
-        # MSVC has nothing like "-Wl,--no-as-needed ... -Wl,--as-needed", as a
-        # workaround we build a static module which loads the dynamic modules
-        set(_MODULE_STR ${_MODULE_LIST})
-        list(TRANSFORM _MODULE_STR REPLACE "(.+)" "\"\\1\"")
-        string(JOIN ",\n  " _MODULE_STR ${_MODULE_STR})
-        set(_MMDEPLOY_DYNAMIC_MODULES ${_MODULE_STR})
-
-        set(_LOADER_NAME ${NAME}_loader)
-
-        add_dependencies(${NAME} ${_MODULE_LIST})
-
-        set(_LOADER_PATH ${CMAKE_BINARY_DIR}/${_LOADER_NAME}.cpp)
-        # ! CMAKE_CURRENT_FUNCTION_LIST_DIR requires cmake 3.17+
-        configure_file(
-                ${CMAKE_CURRENT_FUNCTION_LIST_DIR}/loader.cpp.in
-                ${_LOADER_PATH})
-
-        mmdeploy_add_module(${_LOADER_NAME} STATIC EXCLUDE ${_LOADER_PATH})
-        mmdeploy_load_static(${NAME} ${_LOADER_NAME})
-    elseif (APPLE)
-        target_link_libraries(${NAME} PRIVATE ${_MODULE_LIST})
-    else ()
-        target_link_libraries(${NAME} PRIVATE
-                -Wl,--no-as-needed
-                ${_MODULE_LIST}
-                -Wl,--as-needed)
-    endif ()
-endfunction ()
+    if(APPLE)
+      foreach(module IN LISTS _MODULE_LIST)
+        target_link_libraries(${NAME} PRIVATE -force_load ${module})
+      endforeach()
+    else()
+      target_link_libraries(${NAME} PRIVATE -Wl,--whole-archive ${_MODULE_LIST}
+                                    -Wl,--no-whole-archive)
+    endif()
+  endif()
+endfunction()
+
+function(mmdeploy_load_dynamic NAME)
+  _mmdeploy_flatten_modules(_MODULE_LIST ${ARGN})
+  if(MSVC)
+    if(NOT _MODULE_LIST)
+      return()
+    endif()
+    # MSVC has nothing like "-Wl,--no-as-needed ... -Wl,--as-needed", as a
+    # workaround we build a static module which loads the dynamic modules
+    set(_MODULE_STR ${_MODULE_LIST})
+    list(TRANSFORM _MODULE_STR REPLACE "(.+)" "\"\\1\"")
+    string(JOIN ",\n  " _MODULE_STR ${_MODULE_STR})
+    set(_MMDEPLOY_DYNAMIC_MODULES ${_MODULE_STR})
+
+    set(_LOADER_NAME ${NAME}_loader)
+
+    add_dependencies(${NAME} ${_MODULE_LIST})
+
+    set(_LOADER_PATH ${CMAKE_BINARY_DIR}/${_LOADER_NAME}.cpp)
+    # ! CMAKE_CURRENT_FUNCTION_LIST_DIR requires cmake 3.17+
+    configure_file(${CMAKE_CURRENT_FUNCTION_LIST_DIR}/loader.cpp.in
+                   ${_LOADER_PATH})
+
+    mmdeploy_add_module(${_LOADER_NAME} STATIC EXCLUDE ${_LOADER_PATH})
+    mmdeploy_load_static(${NAME} ${_LOADER_NAME})
+  elseif(APPLE)
+    target_link_libraries(${NAME} PRIVATE ${_MODULE_LIST})
+  else()
+    target_link_libraries(${NAME} PRIVATE -Wl,--no-as-needed ${_MODULE_LIST}
+                                  -Wl,--as-needed)
+  endif()
+endfunction()
 
 macro(mmdeploy_add_deps backend)
-    set(multiValueArgs BACKENDS DEPS)
-    cmake_parse_arguments(INFO "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-    set(has_backend OFF)
-    if (${backend} IN_LIST INFO_BACKENDS)
-        foreach(pkg IN LISTS INFO_DEPS)
-            set(${pkg}_DEPENDENCY "find_package(${pkg} REQUIRED)")
-        endforeach()
-    endif()
+  set(multiValueArgs BACKENDS DEPS)
+  cmake_parse_arguments(INFO "${options}" "${oneValueArgs}" "${multiValueArgs}"
+                        ${ARGN})
+  set(has_backend OFF)
+  if(${backend} IN_LIST INFO_BACKENDS)
+    foreach(pkg IN LISTS INFO_DEPS)
+      set(${pkg}_DEPENDENCY "find_package(${pkg} REQUIRED)")
+    endforeach()
+  endif()
 endmacro()
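Note on the helpers reformatted above: SDK modules register themselves through static initializers, so mmdeploy_load_static wraps static modules in --whole-archive (or -force_load on Apple, /WHOLEARCHIVE on MSVC) to keep the linker from discarding them, while mmdeploy_load_dynamic uses --no-as-needed (or a generated loader stub on MSVC). A sketch of the intended call pattern; the module and tool names below are made up for illustration:

```cmake
# Hypothetical usage of the helpers defined in cmake/MMDeploy.cmake.
mmdeploy_add_module(mmdeploy_custom_op STATIC custom_op.cpp)

add_executable(image_demo main.cpp)
# Force-link every registered static module so its registrars run at startup.
mmdeploy_load_static(image_demo MMDeployStaticModules)
# Keep dynamic modules as runtime dependencies even though no symbol is referenced.
mmdeploy_load_dynamic(image_demo MMDeployDynamicModules)
```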
diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake
index 578fdc7e74..7b2e1c7d83 100644
--- a/cmake/cuda.cmake
+++ b/cmake/cuda.cmake
@@ -1,110 +1,114 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-if (${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.18.0")
-    # suppress 'CMAKE_CUDA_ARCHITECTURES' warning
-    cmake_policy(SET CMP0104 OLD)
-endif ()
+if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.18.0")
+  # suppress 'CMAKE_CUDA_ARCHITECTURES' warning
+  cmake_policy(SET CMP0104 OLD)
+endif()
 
-if (MSVC OR (NOT DEFINED CMAKE_CUDA_RUNTIME_LIBRARY))
-    # use shared, on windows, python api can't build with static lib.
-    set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
-    set(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
-endif ()
+if(MSVC OR (NOT DEFINED CMAKE_CUDA_RUNTIME_LIBRARY))
+  # use the shared runtime; on windows, the python api can't be built with the static lib
+  set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
+  set(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
+endif()
 
-if (MSVC)
-    # no plugin in BuildCustomizations and no specify cuda toolset
-    if (NOT CMAKE_VS_PLATFORM_TOOLSET_CUDA)
-        message(FATAL_ERROR "Please install CUDA MSBuildExtensions")
-    endif ()
+if(MSVC)
+  # no plugin in BuildCustomizations and no cuda toolset specified
+  if(NOT CMAKE_VS_PLATFORM_TOOLSET_CUDA)
+    message(FATAL_ERROR "Please install CUDA MSBuildExtensions")
+  endif()
 
-    if (CMAKE_VS_PLATFORM_TOOLSET_CUDA_CUSTOM_DIR)
-        # find_package(CUDA) required ENV{CUDA_PATH}
-        set(ENV{CUDA_PATH} ${CMAKE_VS_PLATFORM_TOOLSET_CUDA_CUSTOM_DIR})
-    else ()
-        # we use CUDA_PATH and ignore nvcc.exe
-        # cmake will import highest cuda props version, which may not equal to CUDA_PATH
-        if (NOT (DEFINED ENV{CUDA_PATH}))
-            message(FATAL_ERROR "Please set CUDA_PATH environment variable")
-        endif ()
+  if(CMAKE_VS_PLATFORM_TOOLSET_CUDA_CUSTOM_DIR)
+    # find_package(CUDA) requires ENV{CUDA_PATH}
+    set(ENV{CUDA_PATH} ${CMAKE_VS_PLATFORM_TOOLSET_CUDA_CUSTOM_DIR})
+  else()
+    # we use CUDA_PATH and ignore nvcc.exe; cmake will import the highest cuda
+    # props version, which may not equal CUDA_PATH
+    if(NOT (DEFINED ENV{CUDA_PATH}))
+      message(FATAL_ERROR "Please set CUDA_PATH environment variable")
+    endif()
 
-        string(REGEX REPLACE ".*v([0-9]+)\\..*" "\\1" _MAJOR $ENV{CUDA_PATH})
-        string(REGEX REPLACE ".*v[0-9]+\\.([0-9]+).*" "\\1" _MINOR $ENV{CUDA_PATH})
-        if (NOT (${CMAKE_VS_PLATFORM_TOOLSET_CUDA} STREQUAL "${_MAJOR}.${_MINOR}"))
-            message(FATAL_ERROR "Auto detected cuda version ${CMAKE_VS_PLATFORM_TOOLSET_CUDA}"
-                    " is mismatch with ENV{CUDA_PATH} $ENV{CUDA_PATH}. Please modify CUDA_PATH"
-                    " to match ${CMAKE_VS_PLATFORM_TOOLSET_CUDA} or specify cuda toolset by"
-                    " cmake -T cuda=/path/to/cuda ..")
-        endif ()
+    string(REGEX REPLACE ".*v([0-9]+)\\..*" "\\1" _MAJOR $ENV{CUDA_PATH})
+    string(REGEX REPLACE ".*v[0-9]+\\.([0-9]+).*" "\\1" _MINOR $ENV{CUDA_PATH})
+    if(NOT (${CMAKE_VS_PLATFORM_TOOLSET_CUDA} STREQUAL "${_MAJOR}.${_MINOR}"))
+      message(
+        FATAL_ERROR
+          "Auto detected cuda version ${CMAKE_VS_PLATFORM_TOOLSET_CUDA}"
+          " does not match ENV{CUDA_PATH} $ENV{CUDA_PATH}. Please modify CUDA_PATH"
+          " to match ${CMAKE_VS_PLATFORM_TOOLSET_CUDA} or specify cuda toolset by"
+          " cmake -T cuda=/path/to/cuda ..")
+    endif()
 
-        if (NOT (DEFINED ENV{CUDA_PATH_V${_MAJOR}_${_MINOR}}))
-            message(FATAL_ERROR "Please set CUDA_PATH_V${_MAJOR}_${_MINOR} environment variable")
-        endif ()
-    endif ()
-endif ()
+    if(NOT (DEFINED ENV{CUDA_PATH_V${_MAJOR}_${_MINOR}}))
+      message(
+        FATAL_ERROR
+          "Please set CUDA_PATH_V${_MAJOR}_${_MINOR} environment variable")
+    endif()
+  endif()
+endif()
 
 # nvcc compiler settings
 find_package(CUDA REQUIRED)
 
-if (MSVC)
-    set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc.exe)
-    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/wd4819,/wd4828")
-    if (HAVE_CXX_FLAG_UTF_8)
-        set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/utf-8")
-    endif ()
-else ()
-    set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
-    # Explicitly set the cuda host compiler. Because the default host compiler #
-    # selected by cmake maybe wrong.
-    set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})
-    set(CMAKE_CUDA_FLAGS
-            "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC,-Wall,-fvisibility=hidden")
-    if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
-        set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-gnu-unique")
-    endif ()
-endif ()
+if(MSVC)
+  set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc.exe)
+  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/wd4819,/wd4828")
+  if(HAVE_CXX_FLAG_UTF_8)
+    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/utf-8")
+  endif()
+else()
+  set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
+  # Explicitly set the cuda host compiler, because the default host compiler
+  # selected by cmake may be wrong.
+  set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})
+  set(CMAKE_CUDA_FLAGS
+      "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC,-Wall,-fvisibility=hidden")
+  if(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
+    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-gnu-unique")
+  endif()
+endif()
 
 enable_language(CUDA)
 
 # set virtual compute architecture and real ones
 set(_NVCC_FLAGS)
-if (NOT CMAKE_CUDA_ARCHITECTURES)
-    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")
-    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_53,code=sm_53")
-    if (CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "8")
-        set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60")
-        set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61")
-        set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62")
-    endif ()
-    if (CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "9")
-        set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70")
-    endif ()
-    if (CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "10")
-        set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_72,code=sm_72")
-        set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_75,code=sm_75")
-    endif ()
-    if (CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "11")
-        set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_80,code=sm_80")
-        if (CUDA_VERSION_MINOR VERSION_GREATER_EQUAL "1")
-            # cuda doesn't support `sm_86` until version 11.1
-            set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_86,code=sm_86")
-        endif ()
-        if (CUDA_VERSION_MINOR VERSION_GREATER_EQUAL "4")
-            set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_87,code=sm_87")
-        endif ()
-    endif ()
-endif ()
+if(NOT CMAKE_CUDA_ARCHITECTURES)
+  set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")
+  set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_53,code=sm_53")
+  if(CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "8")
+    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60")
+    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61")
+    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62")
+  endif()
+  if(CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "9")
+    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70")
+  endif()
+  if(CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "10")
+    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_72,code=sm_72")
+    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_75,code=sm_75")
+  endif()
+  if(CUDA_VERSION_MAJOR VERSION_GREATER_EQUAL "11")
+    set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_80,code=sm_80")
+    if(CUDA_VERSION_MINOR VERSION_GREATER_EQUAL "1")
+      # cuda doesn't support `sm_86` until version 11.1
+      set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_86,code=sm_86")
+    endif()
+    if(CUDA_VERSION_MINOR VERSION_GREATER_EQUAL "4")
+      set(_NVCC_FLAGS "${_NVCC_FLAGS} -gencode arch=compute_87,code=sm_87")
+    endif()
+  endif()
+endif()
 
 set(CMAKE_CUDA_FLAGS_DEBUG "-g -O0")
 set(CMAKE_CUDA_FLAGS_RELEASE "-O3")
 
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMMDEPLOY_USE_CUDA=1")
 
-if (NOT MSVC)
-    set(CMAKE_CUDA_STANDARD 14)
-endif ()
+if(NOT MSVC)
+  set(CMAKE_CUDA_STANDARD 14)
+endif()
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${_NVCC_FLAGS}")
 
-if (MSVC AND MMDEPLOY_USE_MSCV_STATIC)
-    string(REPLACE -MD -MT CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG}")
-    string(REPLACE -MD -MT CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE}")
-endif ()
+if(MSVC AND MMDEPLOY_USE_MSCV_STATIC)
+  string(REPLACE -MD -MT CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG}")
+  string(REPLACE -MD -MT CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE}")
+endif()
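Note: the -gencode list above is only a fallback; it is built solely when CMAKE_CUDA_ARCHITECTURES is unset, and every listed architecture is compiled into one fat binary. A sketch of narrowing the build instead (the values are examples, not project defaults; whether the pinned value then drives nvcc flags depends on the CMP0104 policy set at the top of this file):

```cmake
# Sketch: pin CUDA architectures before cmake/cuda.cmake is included,
# e.g. from the command line: cmake -DCMAKE_CUDA_ARCHITECTURES="75;86" ..
# Setting it skips the -gencode fallback loop above entirely.
set(CMAKE_CUDA_ARCHITECTURES 75 86 CACHE STRING "target CUDA architectures")
```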
diff --git a/cmake/filesystem.cmake b/cmake/filesystem.cmake
index 787923f2cc..14f1aaaadf 100644
--- a/cmake/filesystem.cmake
+++ b/cmake/filesystem.cmake
@@ -1,43 +1,48 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# Modified from https://github.com/pybind/pybind11/blob/master/tests/CMakeLists.txt
+# Copyright (c) OpenMMLab. All rights reserved. Modified from
+# https://github.com/pybind/pybind11/blob/master/tests/CMakeLists.txt
 
-if (MSVC)
-    set(STD_FS_NO_LIB_NEEDED TRUE)
-else ()
-    file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
-            "#include <filesystem>\nint main(int,char**argv){return std::filesystem::path(argv[0]).string().length();}")
-    try_compile(HAS_INC_FS ${CMAKE_CURRENT_BINARY_DIR}
-            SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
-            COMPILE_DEFINITIONS -std=c++17 -c)
+if(MSVC)
+  set(STD_FS_NO_LIB_NEEDED TRUE)
+else()
+  file(
+    WRITE ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
+    "#include <filesystem>\nint main(int,char**argv){return std::filesystem::path(argv[0]).string().length();}"
+  )
+  try_compile(
+    HAS_INC_FS ${CMAKE_CURRENT_BINARY_DIR}
+    SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
+    COMPILE_DEFINITIONS -std=c++17 -c)
 
-    if (NOT HAS_INC_FS)
-        file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
-                "#include <experimental/filesystem>\nint main(int,char**argv){return std::experimental::filesystem::path(argv[0]).string().length();}")
-    endif ()
+  if(NOT HAS_INC_FS)
+    file(
+      WRITE ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
+      "#include <experimental/filesystem>\nint main(int,char**argv){return std::experimental::filesystem::path(argv[0]).string().length();}"
+    )
+  endif()
 
-    try_compile(
-        STD_FS_NO_LIB_NEEDED ${CMAKE_CURRENT_BINARY_DIR}
-        SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
-        COMPILE_DEFINITIONS -std=c++17)
-    try_compile(
-        STD_FS_NEEDS_STDCXXFS ${CMAKE_CURRENT_BINARY_DIR}
-        SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
-        COMPILE_DEFINITIONS -std=c++17
-        LINK_LIBRARIES stdc++fs)
-    try_compile(
-        STD_FS_NEEDS_CXXFS ${CMAKE_CURRENT_BINARY_DIR}
-        SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
-        COMPILE_DEFINITIONS -std=c++17
-        LINK_LIBRARIES c++fs)
-endif ()
+  try_compile(
+    STD_FS_NO_LIB_NEEDED ${CMAKE_CURRENT_BINARY_DIR}
+    SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
+    COMPILE_DEFINITIONS -std=c++17)
+  try_compile(
+    STD_FS_NEEDS_STDCXXFS ${CMAKE_CURRENT_BINARY_DIR}
+    SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
+    COMPILE_DEFINITIONS -std=c++17
+    LINK_LIBRARIES stdc++fs)
+  try_compile(
+    STD_FS_NEEDS_CXXFS ${CMAKE_CURRENT_BINARY_DIR}
+    SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp
+    COMPILE_DEFINITIONS -std=c++17
+    LINK_LIBRARIES c++fs)
+endif()
 
-if (${STD_FS_NO_LIB_NEEDED})
-    set(STD_FS_LIB "")
-elseif (${STD_FS_NEEDS_STDCXXFS})
-    set(STD_FS_LIB stdc++fs)
-elseif (${STD_FS_NEEDS_CXXFS})
-    set(STD_FS_LIB c++fs)
-else ()
-    message(WARNING "Unknown C++17 compiler - not passing -lstdc++fs")
-    set(STD_FS_LIB "")
-endif ()
+if(${STD_FS_NO_LIB_NEEDED})
+  set(STD_FS_LIB "")
+elseif(${STD_FS_NEEDS_STDCXXFS})
+  set(STD_FS_LIB stdc++fs)
+elseif(${STD_FS_NEEDS_CXXFS})
+  set(STD_FS_LIB c++fs)
+else()
+  message(WARNING "Unknown C++17 compiler - not passing -lstdc++fs")
+  set(STD_FS_LIB "")
+endif()
diff --git a/cmake/modules/FindCUDNN.cmake b/cmake/modules/FindCUDNN.cmake
index 3f3f9b893a..332fad48eb 100644
--- a/cmake/modules/FindCUDNN.cmake
+++ b/cmake/modules/FindCUDNN.cmake
@@ -1,36 +1,39 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-if (NOT DEFINED CUDNN_DIR)
-    set(CUDNN_DIR $ENV{CUDNN_DIR})
-endif ()
+if(NOT DEFINED CUDNN_DIR)
+  set(CUDNN_DIR $ENV{CUDNN_DIR})
+endif()
 
 find_path(
-    CUDNN_INCLUDE_DIR cudnn.h
-    HINTS ${CUDNN_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
-    PATH_SUFFIXES include)
+  CUDNN_INCLUDE_DIR cudnn.h
+  HINTS ${CUDNN_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
+  PATH_SUFFIXES include)
 
 find_library(
-    CUDNN_LIBRARY_CUDNN_PATH cudnn
-    HINTS ${CUDNN_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
-    PATH_SUFFIXES lib lib64 lib/x64)
+  CUDNN_LIBRARY_CUDNN_PATH cudnn
+  HINTS ${CUDNN_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
+  PATH_SUFFIXES lib lib64 lib/x64)
 
-if (NOT (CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY_CUDNN_PATH))
-    message(FATAL_ERROR "Couldn't find cuDNN in CUDNN_DIR: ${CUDNN_DIR}, "
-            "or in CUDA_TOOLKIT_ROOT_DIR: ${CUDA_TOOLKIT_ROOT_DIR}, "
-            "please check if the path is correct.")
+if(NOT (CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY_CUDNN_PATH))
+  message(
+    FATAL_ERROR
+      "Couldn't find cuDNN in CUDNN_DIR: ${CUDNN_DIR}, "
+      "or in CUDA_TOOLKIT_ROOT_DIR: ${CUDA_TOOLKIT_ROOT_DIR}, "
+      "please check if the path is correct.")
 endif()
 
 add_library(cudnn SHARED IMPORTED)
-set_property(TARGET cudnn APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
-if (MSVC)
-    set_target_properties(cudnn PROPERTIES
-            IMPORTED_IMPLIB_RELEASE ${CUDNN_LIBRARY_CUDNN_PATH}
-            INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
-            )
+set_property(
+  TARGET cudnn
+  APPEND
+  PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+if(MSVC)
+  set_target_properties(
+    cudnn PROPERTIES IMPORTED_IMPLIB_RELEASE ${CUDNN_LIBRARY_CUDNN_PATH}
+                     INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR})
 else()
-    set_target_properties(cudnn PROPERTIES
-            IMPORTED_LOCATION_RELEASE ${CUDNN_LIBRARY_CUDNN_PATH}
-            INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
-            )
+  set_target_properties(
+    cudnn PROPERTIES IMPORTED_LOCATION_RELEASE ${CUDNN_LIBRARY_CUDNN_PATH}
+                     INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR})
 endif()
diff --git a/cmake/modules/FindONNXRUNTIME.cmake b/cmake/modules/FindONNXRUNTIME.cmake
index 63ea176595..d3eff87f65 100644
--- a/cmake/modules/FindONNXRUNTIME.cmake
+++ b/cmake/modules/FindONNXRUNTIME.cmake
@@ -1,36 +1,40 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-if (NOT DEFINED ONNXRUNTIME_DIR) - set(ONNXRUNTIME_DIR $ENV{ONNXRUNTIME_DIR}) -endif () -if (NOT ONNXRUNTIME_DIR) - message(FATAL_ERROR "Please set ONNXRUNTIME_DIR with cmake -D option.") +if(NOT DEFINED ONNXRUNTIME_DIR) + set(ONNXRUNTIME_DIR $ENV{ONNXRUNTIME_DIR}) +endif() +if(NOT ONNXRUNTIME_DIR) + message(FATAL_ERROR "Please set ONNXRUNTIME_DIR with cmake -D option.") endif() find_path( - ONNXRUNTIME_INCLUDE_DIR onnxruntime_cxx_api.h - HINTS ${ONNXRUNTIME_DIR} - PATH_SUFFIXES include) + ONNXRUNTIME_INCLUDE_DIR onnxruntime_cxx_api.h + HINTS ${ONNXRUNTIME_DIR} + PATH_SUFFIXES include) find_library( - ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH onnxruntime - HINTS ${ONNXRUNTIME_DIR} - PATH_SUFFIXES lib lib64 lib/x64) -if (NOT (ONNXRUNTIME_INCLUDE_DIR AND ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH)) - message(FATAL_ERROR "Couldn't find onnxruntime in ONNXRUNTIME_DIR: " - "${ONNXRUNTIME_DIR}, please check if the path is correct.") + ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH onnxruntime + HINTS ${ONNXRUNTIME_DIR} + PATH_SUFFIXES lib lib64 lib/x64) +if(NOT (ONNXRUNTIME_INCLUDE_DIR AND ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH)) + message( + FATAL_ERROR "Couldn't find onnxruntime in ONNXRUNTIME_DIR: " + "${ONNXRUNTIME_DIR}, please check if the path is correct.") endif() add_library(onnxruntime SHARED IMPORTED) -set_property(TARGET onnxruntime APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) -if (MSVC) - set_target_properties(onnxruntime PROPERTIES - IMPORTED_IMPLIB_RELEASE ${ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH} - INTERFACE_INCLUDE_DIRECTORIES ${ONNXRUNTIME_INCLUDE_DIR} - ) +set_property( + TARGET onnxruntime + APPEND + PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +if(MSVC) + set_target_properties( + onnxruntime + PROPERTIES IMPORTED_IMPLIB_RELEASE ${ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH} + INTERFACE_INCLUDE_DIRECTORIES ${ONNXRUNTIME_INCLUDE_DIR}) else() - set_target_properties(onnxruntime PROPERTIES - IMPORTED_LOCATION_RELEASE ${ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH} - INTERFACE_INCLUDE_DIRECTORIES ${ONNXRUNTIME_INCLUDE_DIR} - ) + set_target_properties( + onnxruntime + PROPERTIES IMPORTED_LOCATION_RELEASE ${ONNXRUNTIME_LIBRARY_ONNXRUNTIME_PATH} + INTERFACE_INCLUDE_DIRECTORIES ${ONNXRUNTIME_INCLUDE_DIR}) endif() diff --git a/cmake/modules/FindTENSORRT.cmake b/cmake/modules/FindTENSORRT.cmake index e2c328923e..25d015a52c 100644 --- a/cmake/modules/FindTENSORRT.cmake +++ b/cmake/modules/FindTENSORRT.cmake @@ -1,51 +1,56 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-if (NOT DEFINED TENSORRT_DIR) - set(TENSORRT_DIR $ENV{TENSORRT_DIR}) -endif () -if (NOT TENSORRT_DIR) - message(FATAL_ERROR "Please set TENSORRT_DIR with cmake -D option.") +if(NOT DEFINED TENSORRT_DIR) + set(TENSORRT_DIR $ENV{TENSORRT_DIR}) +endif() +if(NOT TENSORRT_DIR) + message(FATAL_ERROR "Please set TENSORRT_DIR with cmake -D option.") endif() find_path( - TENSORRT_INCLUDE_DIR NvInfer.h - HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} - PATH_SUFFIXES include) + TENSORRT_INCLUDE_DIR NvInfer.h + HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES include) -if (NOT TENSORRT_INCLUDE_DIR) - message(FATAL_ERROR "Cannot find TensorRT header NvInfer.h " - "in TENSORRT_DIR: ${TENSORRT_DIR} or in CUDA_TOOLKIT_ROOT_DIR: " - "${CUDA_TOOLKIT_ROOT_DIR}, please check if the path is correct.") -endif () +if(NOT TENSORRT_INCLUDE_DIR) + message( + FATAL_ERROR + "Cannot find TensorRT header NvInfer.h " + "in TENSORRT_DIR: ${TENSORRT_DIR} or in CUDA_TOOLKIT_ROOT_DIR: " + "${CUDA_TOOLKIT_ROOT_DIR}, please check if the path is correct.") +endif() set(__TENSORRT_LIB_COMPONENTS nvinfer;nvinfer_plugin) foreach(__component ${__TENSORRT_LIB_COMPONENTS}) - find_library( - __component_path ${__component} - HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} - PATH_SUFFIXES lib lib64 lib/x64) - if (NOT __component_path) - message(FATAL_ERROR "Cannot find TensorRT lib ${__component} in " - "TENSORRT_DIR: ${TENSORRT_DIR} or CUDA_TOOLKIT_ROOT_DIR: ${CUDA_TOOLKIT_ROOT_DIR}, " - "please check if the path is correct") - endif() + find_library( + __component_path ${__component} + HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/x64) + if(NOT __component_path) + message( + FATAL_ERROR + "Cannot find TensorRT lib ${__component} in " + "TENSORRT_DIR: ${TENSORRT_DIR} or CUDA_TOOLKIT_ROOT_DIR: ${CUDA_TOOLKIT_ROOT_DIR}, " + "please check if the path is correct") + endif() - add_library(${__component} SHARED IMPORTED) - set_property(TARGET ${__component} APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) - if (MSVC) - set_target_properties( - ${__component} PROPERTIES - IMPORTED_IMPLIB_RELEASE ${__component_path} - INTERFACE_INCLUDE_DIRECTORIES ${TENSORRT_INCLUDE_DIR} - ) - else() - set_target_properties( - ${__component} PROPERTIES - IMPORTED_LOCATION_RELEASE ${__component_path} - INTERFACE_INCLUDE_DIRECTORIES ${TENSORRT_INCLUDE_DIR} - ) - endif() - unset(__component_path CACHE) + add_library(${__component} SHARED IMPORTED) + set_property( + TARGET ${__component} + APPEND + PROPERTY IMPORTED_CONFIGURATIONS RELEASE) + if(MSVC) + set_target_properties( + ${__component} + PROPERTIES IMPORTED_IMPLIB_RELEASE ${__component_path} + INTERFACE_INCLUDE_DIRECTORIES ${TENSORRT_INCLUDE_DIR}) + else() + set_target_properties( + ${__component} + PROPERTIES IMPORTED_LOCATION_RELEASE ${__component_path} + INTERFACE_INCLUDE_DIRECTORIES ${TENSORRT_INCLUDE_DIR}) + endif() + unset(__component_path CACHE) endforeach() set(TENSORRT_LIBS ${__TENSORRT_LIB_COMPONENTS}) diff --git a/cmake/modules/FindTVM.cmake b/cmake/modules/FindTVM.cmake index f6443609e4..8ae3a48abd 100644 --- a/cmake/modules/FindTVM.cmake +++ b/cmake/modules/FindTVM.cmake @@ -1,47 +1,56 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-if (NOT DEFINED TVM_DIR) - set(TVM_DIR $ENV{TVM_DIR}) -endif () -if (NOT TVM_DIR) - message(FATAL_ERROR "Please set TVM_DIR with cmake -D option.") +if(NOT DEFINED TVM_DIR) + set(TVM_DIR $ENV{TVM_DIR}) +endif() +if(NOT TVM_DIR) + message(FATAL_ERROR "Please set TVM_DIR with cmake -D option.") endif() find_path( - TVM_INCLUDE_DIR tvm/runtime/c_runtime_api.h - HINTS ${TVM_DIR} - PATH_SUFFIXES include) + TVM_INCLUDE_DIR tvm/runtime/c_runtime_api.h + HINTS ${TVM_DIR} + PATH_SUFFIXES include) find_path( - DMLC_CORE_INCLUDE_DIR dmlc/io.h - HINTS ${TVM_DIR}/3rdparty/dmlc-core - PATH_SUFFIXES include) + DMLC_CORE_INCLUDE_DIR dmlc/io.h + HINTS ${TVM_DIR}/3rdparty/dmlc-core + PATH_SUFFIXES include) find_path( - DLPACK_INCLUDE_DIR dlpack/dlpack.h - HINTS ${TVM_DIR}/3rdparty/dlpack - PATH_SUFFIXES include) + DLPACK_INCLUDE_DIR dlpack/dlpack.h + HINTS ${TVM_DIR}/3rdparty/dlpack + PATH_SUFFIXES include) find_library( - TVM_LIBRARY_PATH tvm_runtime - HINTS ${TVM_DIR} - PATH_SUFFIXES build lib build/${CMAKE_BUILD_TYPE}) -if (NOT (TVM_INCLUDE_DIR AND DMLC_CORE_INCLUDE_DIR AND DLPACK_INCLUDE_DIR AND TVM_LIBRARY_PATH)) - message(FATAL_ERROR "Couldn't find tvm in TVM_DIR: " - "${TVM_DIR}, please check if the path is correct.") + TVM_LIBRARY_PATH tvm_runtime + HINTS ${TVM_DIR} + PATH_SUFFIXES build lib build/${CMAKE_BUILD_TYPE}) +if(NOT + (TVM_INCLUDE_DIR + AND DMLC_CORE_INCLUDE_DIR + AND DLPACK_INCLUDE_DIR + AND TVM_LIBRARY_PATH)) + message(FATAL_ERROR "Couldn't find tvm in TVM_DIR: " + "${TVM_DIR}, please check if the path is correct.") endif() add_library(tvm_runtime SHARED IMPORTED) -set_property(TARGET tvm_runtime APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) -if (MSVC) - set_target_properties(tvm_runtime PROPERTIES - IMPORTED_IMPLIB_RELEASE ${TVM_LIBRARY_PATH} - INTERFACE_INCLUDE_DIRECTORIES ${TVM_INCLUDE_DIR} ${DMLC_CORE_INCLUDE_DIR} ${DLPACK_INCLUDE_DIR} - ) +set_property( + TARGET tvm_runtime + APPEND + PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +if(MSVC) + set_target_properties( + tvm_runtime + PROPERTIES IMPORTED_IMPLIB_RELEASE ${TVM_LIBRARY_PATH} + INTERFACE_INCLUDE_DIRECTORIES ${TVM_INCLUDE_DIR} + ${DMLC_CORE_INCLUDE_DIR} ${DLPACK_INCLUDE_DIR}) else() - set_target_properties(tvm_runtime PROPERTIES - IMPORTED_LOCATION_RELEASE ${TVM_LIBRARY_PATH} - INTERFACE_INCLUDE_DIRECTORIES ${TVM_INCLUDE_DIR} ${DMLC_CORE_INCLUDE_DIR} ${DLPACK_INCLUDE_DIR} - ) + set_target_properties( + tvm_runtime + PROPERTIES IMPORTED_LOCATION_RELEASE ${TVM_LIBRARY_PATH} + INTERFACE_INCLUDE_DIRECTORIES ${TVM_INCLUDE_DIR} + ${DMLC_CORE_INCLUDE_DIR} ${DLPACK_INCLUDE_DIR}) endif() diff --git a/cmake/post-install.cmake b/cmake/post-install.cmake index d289e53996..c9ae0d6dd9 100644 --- a/cmake/post-install.cmake +++ b/cmake/post-install.cmake @@ -1,10 +1,10 @@ - -set(_TARGETS_PATH ${CMAKE_INSTALL_PREFIX}/lib/cmake/MMDeploy/MMDeployTargets.cmake) +set(_TARGETS_PATH + ${CMAKE_INSTALL_PREFIX}/lib/cmake/MMDeploy/MMDeployTargets.cmake) file(READ ${_TARGETS_PATH} _MMDEPLOY_TARGETS) -string(REGEX REPLACE "::@<0x[a-z0-9]+>" "" - _MMDEPLOY_TARGETS_FIXED "${_MMDEPLOY_TARGETS}") +string(REGEX REPLACE "::@<0x[a-z0-9]+>" "" _MMDEPLOY_TARGETS_FIXED + "${_MMDEPLOY_TARGETS}") file(WRITE ${_TARGETS_PATH} "${_MMDEPLOY_TARGETS_FIXED}") diff --git a/cmake/stacktrace.cmake b/cmake/stacktrace.cmake index bd0761a217..4ef719aaa2 100644 --- a/cmake/stacktrace.cmake +++ b/cmake/stacktrace.cmake @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
find_package(Boost 1.65 COMPONENTS stacktrace_backtrace) -if (Boost_FOUND) - target_link_libraries(mmdeploy_core PUBLIC Boost::stacktrace_backtrace) - target_compile_definitions(mmdeploy_core PUBLIC -DMMDEPLOY_STATUS_USE_STACKTRACE=1) +if(Boost_FOUND) + target_link_libraries(mmdeploy_core PUBLIC Boost::stacktrace_backtrace) + target_compile_definitions(mmdeploy_core + PUBLIC -DMMDEPLOY_STATUS_USE_STACKTRACE=1) endif() diff --git a/cmake/tensorrt.cmake b/cmake/tensorrt.cmake index 546a85070d..1302f4bbcf 100644 --- a/cmake/tensorrt.cmake +++ b/cmake/tensorrt.cmake @@ -3,35 +3,33 @@ include(${CMAKE_SOURCE_DIR}/cmake/modules/FindTENSORRT.cmake) include(${CMAKE_SOURCE_DIR}/cmake/modules/FindCUDNN.cmake) find_path( - TENSORRT_INCLUDE_DIR NvInfer.h - HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} - PATH_SUFFIXES include) -if (TENSORRT_INCLUDE_DIR) - message(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}") -else () - message(ERROR "Cannot find TensorRT headers") -endif () + TENSORRT_INCLUDE_DIR NvInfer.h + HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES include) +if(TENSORRT_INCLUDE_DIR) + message(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}") +else() + message(ERROR "Cannot find TensorRT headers") +endif() find_library( - TENSORRT_LIBRARY_INFER nvinfer - HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} - PATH_SUFFIXES lib lib64 lib/x64) + TENSORRT_LIBRARY_INFER nvinfer + HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/x64) find_library( - TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin - HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} - PATH_SUFFIXES lib lib64 lib/x64) -set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} - ${TENSORRT_LIBRARY_INFER_PLUGIN}) -if (TENSORRT_LIBRARY_INFER - AND TENSORRT_LIBRARY_INFER_PLUGIN) - message(STATUS "Found TensorRT libs at ${TENSORRT_LIBRARY}") -else () - message(FATAL_ERROR "Cannot find TensorRT libs") -endif () + TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin + HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/x64) +set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) +if(TENSORRT_LIBRARY_INFER AND TENSORRT_LIBRARY_INFER_PLUGIN) + message(STATUS "Found TensorRT libs at ${TENSORRT_LIBRARY}") +else() + message(FATAL_ERROR "Cannot find TensorRT libs") +endif() include(FindPackageHandleStandardArgs) find_package_handle_standard_args(TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR - TENSORRT_LIBRARY) -if (NOT TENSORRT_FOUND) - message(ERROR "Cannot find TensorRT library.") -endif () + TENSORRT_LIBRARY) +if(NOT TENSORRT_FOUND) + message(ERROR "Cannot find TensorRT library.") +endif() diff --git a/cmake/toolchains/aarch64-linux-gnu.cmake b/cmake/toolchains/aarch64-linux-gnu.cmake index f95911efd1..dfb3fced21 100644 --- a/cmake/toolchains/aarch64-linux-gnu.cmake +++ b/cmake/toolchains/aarch64-linux-gnu.cmake @@ -13,5 +13,9 @@ set(CMAKE_C_FLAGS "-march=armv8-a") set(CMAKE_CXX_FLAGS "-march=armv8-a") # cache flags -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") +set(CMAKE_C_FLAGS + "${CMAKE_C_FLAGS}" + CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS}" + CACHE STRING "c++ flags") diff --git a/cmake/toolchains/arm-linux-gnueabihf.cmake b/cmake/toolchains/arm-linux-gnueabihf.cmake index 74ed5bf935..d4cfe513a0 100644 --- a/cmake/toolchains/arm-linux-gnueabihf.cmake +++ 
b/cmake/toolchains/arm-linux-gnueabihf.cmake @@ -12,5 +12,9 @@ set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon") set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon") # cache flags -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") +set(CMAKE_C_FLAGS + "${CMAKE_C_FLAGS}" + CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS}" + CACHE STRING "c++ flags") diff --git a/cmake/toolchains/riscv64-linux-gnu.cmake b/cmake/toolchains/riscv64-linux-gnu.cmake index e3b3b2adbc..a6515dbd7f 100644 --- a/cmake/toolchains/riscv64-linux-gnu.cmake +++ b/cmake/toolchains/riscv64-linux-gnu.cmake @@ -13,5 +13,9 @@ set(CMAKE_C_FLAGS "-march=rv64gc") set(CMAKE_CXX_FLAGS "-march=rv64gc") # cache flags -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") +set(CMAKE_C_FLAGS + "${CMAKE_C_FLAGS}" + CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS}" + CACHE STRING "c++ flags") diff --git a/cmake/toolchains/riscv64-unknown-linux-gnu.cmake b/cmake/toolchains/riscv64-unknown-linux-gnu.cmake index c24661f6e6..93ddc583fe 100644 --- a/cmake/toolchains/riscv64-unknown-linux-gnu.cmake +++ b/cmake/toolchains/riscv64-unknown-linux-gnu.cmake @@ -2,15 +2,17 @@ set(CMAKE_SYSTEM_NAME Linux) set(CMAKE_SYSTEM_PROCESSOR riscv) if(DEFINED ENV{RISCV_ROOT_PATH}) - file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) else() - message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") endif() set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc) set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++) -set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot" CACHE PATH "riscv sysroot") +set(CMAKE_SYSROOT + "${RISCV_ROOT_PATH}/sysroot" + CACHE PATH "riscv sysroot") set(CMAKE_FIND_ROOT_PATH ${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu) set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) @@ -22,5 +24,9 @@ set(CMAKE_C_FLAGS "-march=rv64gc") set(CMAKE_CXX_FLAGS "-march=rv64gc") # cache flags -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") +set(CMAKE_C_FLAGS + "${CMAKE_C_FLAGS}" + CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS}" + CACHE STRING "c++ flags") diff --git a/cmake/toolchains/rknpu2-linux-gnu.cmake b/cmake/toolchains/rknpu2-linux-gnu.cmake index 2bb6835430..4a94f8b238 100644 --- a/cmake/toolchains/rknpu2-linux-gnu.cmake +++ b/cmake/toolchains/rknpu2-linux-gnu.cmake @@ -2,9 +2,9 @@ set(CMAKE_SYSTEM_NAME Linux) set(CMAKE_SYSTEM_PROCESSOR rockchip) if(DEFINED ENV{RKNN_TOOL_CHAIN}) - file(TO_CMAKE_PATH $ENV{RKNN_TOOL_CHAIN} RKNN_TOOL_CHAIN) + file(TO_CMAKE_PATH $ENV{RKNN_TOOL_CHAIN} RKNN_TOOL_CHAIN) else() - message(FATAL_ERROR "RKNN_TOOL_CHAIN env must be defined") + message(FATAL_ERROR "RKNN_TOOL_CHAIN env must be defined") endif() set(CMAKE_C_COMPILER ${RKNN_TOOL_CHAIN}/bin/aarch64-rockchip-linux-gnu-gcc) @@ -19,5 +19,9 @@ set(CMAKE_C_FLAGS "-Wl,--allow-shlib-undefined") set(CMAKE_CXX_FLAGS "-Wl,--allow-shlib-undefined") # cache flags -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") +set(CMAKE_C_FLAGS + "${CMAKE_C_FLAGS}" + CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS}" + CACHE 
STRING "c++ flags")
diff --git a/csrc/mmdeploy/CMakeLists.txt b/csrc/mmdeploy/CMakeLists.txt
index 6bfbd3a95a..26dce4f586 100644
--- a/csrc/mmdeploy/CMakeLists.txt
+++ b/csrc/mmdeploy/CMakeLists.txt
@@ -2,20 +2,20 @@
add_subdirectory(backend_ops)
-if (MMDEPLOY_BUILD_SDK)
- # include OpenCV for SDK modules since many of them depends on it
- include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake)
+if(MMDEPLOY_BUILD_SDK)
+ # include OpenCV for SDK modules since many of them depend on it
+ include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake)
- add_subdirectory(core)
- add_subdirectory(execution)
- add_subdirectory(utils)
- add_subdirectory(archive)
- add_subdirectory(device)
- add_subdirectory(graph)
- add_subdirectory(model)
- add_subdirectory(operation)
- add_subdirectory(preprocess)
- add_subdirectory(net)
- add_subdirectory(codebase)
- add_subdirectory(apis)
-endif ()
+ add_subdirectory(core)
+ add_subdirectory(execution)
+ add_subdirectory(utils)
+ add_subdirectory(archive)
+ add_subdirectory(device)
+ add_subdirectory(graph)
+ add_subdirectory(model)
+ add_subdirectory(operation)
+ add_subdirectory(preprocess)
+ add_subdirectory(net)
+ add_subdirectory(codebase)
+ add_subdirectory(apis)
+endif()
diff --git a/csrc/mmdeploy/apis/CMakeLists.txt b/csrc/mmdeploy/apis/CMakeLists.txt
index 1ab877be90..e137bce311 100644
--- a/csrc/mmdeploy/apis/CMakeLists.txt
+++ b/csrc/mmdeploy/apis/CMakeLists.txt
@@ -4,8 +4,8 @@
add_subdirectory(c)
add_subdirectory(cxx)
add_subdirectory(java)
-# add python subdir conditionally since it's designed to work as
-# a standalone project also
-if (MMDEPLOY_BUILD_SDK_PYTHON_API)
- add_subdirectory(python)
-endif ()
+# add python subdir conditionally since it's designed to work as a standalone
+# project also
+if(MMDEPLOY_BUILD_SDK_PYTHON_API)
+ add_subdirectory(python)
+endif()
diff --git a/csrc/mmdeploy/apis/c/CMakeLists.txt b/csrc/mmdeploy/apis/c/CMakeLists.txt
index f08fa8cf86..4c1755b168 100644
--- a/csrc/mmdeploy/apis/c/CMakeLists.txt
+++ b/csrc/mmdeploy/apis/c/CMakeLists.txt
@@ -6,81 +6,76 @@
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)
set(CAPI_OBJS)
macro(add_object name)
- add_library(${name} OBJECT ${ARGN})
- set_target_properties(${name} PROPERTIES POSITION_INDEPENDENT_CODE 1)
- target_compile_definitions(${name} PRIVATE -DMMDEPLOY_API_EXPORTS=1)
- if (NOT MSVC)
- target_compile_options(${name} PRIVATE $<$:-fvisibility=hidden>)
- endif ()
- target_link_libraries(${name} PRIVATE mmdeploy::core)
- target_include_directories(${name} PUBLIC
- $
- $)
- set(CAPI_OBJS ${CAPI_OBJS} ${name})
- mmdeploy_export(${name})
+ add_library(${name} OBJECT ${ARGN})
+ set_target_properties(${name} PROPERTIES POSITION_INDEPENDENT_CODE 1)
+ target_compile_definitions(${name} PRIVATE -DMMDEPLOY_API_EXPORTS=1)
+ if(NOT MSVC)
+ target_compile_options(
+ ${name} PRIVATE $<$:-fvisibility=hidden>)
+ endif()
+ target_link_libraries(${name} PRIVATE mmdeploy::core)
+ target_include_directories(
+ ${name} PUBLIC $
+ $)
+ set(CAPI_OBJS ${CAPI_OBJS} ${name})
+ mmdeploy_export(${name})
endmacro()
-set(COMMON_LIST
- common
- model
- executor
- pipeline)
+set(COMMON_LIST common model executor pipeline)
set(TASK_LIST ${MMDEPLOY_TASKS})
-foreach (TASK ${COMMON_LIST})
- set(TARGET_NAME mmdeploy_${TASK})
- set(OBJECT_NAME mmdeploy_${TASK}_obj)
- add_object(${OBJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.cpp)
- mmdeploy_add_library(${TARGET_NAME})
- target_link_libraries(${TARGET_NAME} PRIVATE ${OBJECT_NAME})
- target_include_directories(${TARGET_NAME} PUBLIC
- $
- $)
- install(FILES
${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.h - DESTINATION include/mmdeploy) -endforeach () +foreach(TASK ${COMMON_LIST}) + set(TARGET_NAME mmdeploy_${TASK}) + set(OBJECT_NAME mmdeploy_${TASK}_obj) + add_object(${OBJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.cpp) + mmdeploy_add_library(${TARGET_NAME}) + target_link_libraries(${TARGET_NAME} PRIVATE ${OBJECT_NAME}) + target_include_directories( + ${TARGET_NAME} PUBLIC $ + $) + install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.h + DESTINATION include/mmdeploy) +endforeach() -target_link_libraries(mmdeploy_executor PUBLIC - mmdeploy_execution mmdeploy_common) -target_link_libraries(mmdeploy_pipeline PUBLIC - mmdeploy_executor mmdeploy_model mmdeploy_common) +target_link_libraries(mmdeploy_executor PUBLIC mmdeploy_execution + mmdeploy_common) +target_link_libraries(mmdeploy_pipeline PUBLIC mmdeploy_executor mmdeploy_model + mmdeploy_common) -foreach (TASK ${TASK_LIST}) - set(TARGET_NAME mmdeploy_${TASK}) - set(OBJECT_NAME mmdeploy_${TASK}_obj) - add_object(${OBJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.cpp) - mmdeploy_add_library(${TARGET_NAME}) - target_link_libraries(${TARGET_NAME} PRIVATE ${OBJECT_NAME} - mmdeploy_pipeline) - target_include_directories(${TARGET_NAME} PUBLIC - $ - $) - install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.h - DESTINATION include/mmdeploy) -endforeach () +foreach(TASK ${TASK_LIST}) + set(TARGET_NAME mmdeploy_${TASK}) + set(OBJECT_NAME mmdeploy_${TASK}_obj) + add_object(${OBJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.cpp) + mmdeploy_add_library(${TARGET_NAME}) + target_link_libraries(${TARGET_NAME} PRIVATE ${OBJECT_NAME} mmdeploy_pipeline) + target_include_directories( + ${TARGET_NAME} PUBLIC $ + $) + install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${TASK}.h + DESTINATION include/mmdeploy) +endforeach() -install(DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/ DESTINATION example/cpp - FILES_MATCHING - PATTERN "*.cpp" - PATTERN "CMakeLists.txt" - ) +install( + DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/ + DESTINATION example/cpp + FILES_MATCHING + PATTERN "*.cpp" + PATTERN "CMakeLists.txt") -if (MMDEPLOY_BUILD_SDK_CSHARP_API OR MMDEPLOY_BUILD_SDK_MONOLITHIC) - add_library(mmdeploy SHARED) - mmdeploy_load_static(mmdeploy MMDeployStaticModules) - mmdeploy_load_dynamic(mmdeploy MMDeployDynamicModules) - target_link_libraries(mmdeploy PRIVATE ${CAPI_OBJS} mmdeploy_execution) - target_include_directories(mmdeploy PUBLIC - $ - $) - set(MMDEPLOY_VERSION ${MMDEPLOY_VERSION_MAJOR} - .${MMDEPLOY_VERSION_MINOR} - .${MMDEPLOY_VERSION_PATCH}) - string(REPLACE ";" "" MMDEPLOY_VERSION ${MMDEPLOY_VERSION}) - set_target_properties(mmdeploy PROPERTIES - VERSION ${MMDEPLOY_VERSION} - SOVERSION ${MMDEPLOY_VERSION_MAJOR}) - mmdeploy_add_rpath(mmdeploy) - mmdeploy_export_impl(mmdeploy) -endif () +if(MMDEPLOY_BUILD_SDK_CSHARP_API OR MMDEPLOY_BUILD_SDK_MONOLITHIC) + add_library(mmdeploy SHARED) + mmdeploy_load_static(mmdeploy MMDeployStaticModules) + mmdeploy_load_dynamic(mmdeploy MMDeployDynamicModules) + target_link_libraries(mmdeploy PRIVATE ${CAPI_OBJS} mmdeploy_execution) + target_include_directories( + mmdeploy PUBLIC $ + $) + set(MMDEPLOY_VERSION ${MMDEPLOY_VERSION_MAJOR} .${MMDEPLOY_VERSION_MINOR} + .${MMDEPLOY_VERSION_PATCH}) + string(REPLACE ";" "" MMDEPLOY_VERSION ${MMDEPLOY_VERSION}) + set_target_properties(mmdeploy PROPERTIES VERSION ${MMDEPLOY_VERSION} + SOVERSION ${MMDEPLOY_VERSION_MAJOR}) + mmdeploy_add_rpath(mmdeploy) + mmdeploy_export_impl(mmdeploy) +endif() diff --git 
a/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp b/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp index 3eec4ef90b..9faf47f349 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp @@ -16,118 +16,132 @@ using namespace mmdeploy; using namespace std; -int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_classifier_t* classifier) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_classifier_t* classifier) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_classifier_create_v2(model, context, classifier); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_classifier_create_v2(model, context, classifier); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_classifier_t* classifier) { - mmdeploy_model_t model{}; +int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_classifier_t* classifier) +{ + mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_classifier_create(model, device_name, device_id, classifier); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_classifier_create(model, device_name, device_id, classifier); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_classifier_t* classifier) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)classifier); +int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_classifier_t* classifier) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)classifier); } -int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value) { - return mmdeploy_common_create_input(mats, mat_count, value); +int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value) +{ + return mmdeploy_common_create_input(mats, mat_count, value); } -int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_classification_t** results, - int** result_count) { - wrapped input; - if (auto ec = mmdeploy_classifier_create_input(mats, mat_count, input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_classifier_apply_v2(classifier, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_classifier_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_classification_t** results, int** result_count) +{ + wrapped input; + if (auto ec = mmdeploy_classifier_create_input(mats, mat_count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = 
mmdeploy_classifier_apply_v2(classifier, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_classifier_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)classifier, input, output); +int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)classifier, input, output); } -int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)classifier, input, output); +int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)classifier, input, output); } -int mmdeploy_classifier_get_result(mmdeploy_value_t output, mmdeploy_classification_t** results, - int** result_count) { - if (!output || !results || !result_count) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value& value = Cast(output)->front(); - - auto classify_outputs = from_value>(value); - - vector _result_count; - _result_count.reserve(classify_outputs.size()); - - for (const auto& cls_output : classify_outputs) { - _result_count.push_back((int)cls_output.size()); +int mmdeploy_classifier_get_result(mmdeploy_value_t output, mmdeploy_classification_t** results, int** result_count) +{ + if (!output || !results || !result_count) + { + return MMDEPLOY_E_INVALID_ARG; } - - auto total = std::accumulate(begin(_result_count), end(_result_count), 0); - - std::unique_ptr result_count_data(new int[_result_count.size()]{}); - std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); - - std::unique_ptr result_data( - new mmdeploy_classification_t[total]{}); - auto result_ptr = result_data.get(); - for (const auto& cls_output : classify_outputs) { - for (const auto& label : cls_output) { - result_ptr->label_id = label.label_id; - result_ptr->score = label.score; - ++result_ptr; - } + try + { + Value& value = Cast(output)->front(); + + auto classify_outputs = from_value>(value); + + vector _result_count; + _result_count.reserve(classify_outputs.size()); + + for (const auto& cls_output : classify_outputs) + { + _result_count.push_back((int)cls_output.size()); + } + + auto total = std::accumulate(begin(_result_count), end(_result_count), 0); + + std::unique_ptr result_count_data(new int[_result_count.size()]{}); + std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); + + std::unique_ptr result_data( + new mmdeploy_classification_t[total]{}); + auto result_ptr = result_data.get(); + for (const auto& cls_output : classify_outputs) + { + for (const auto& label : cls_output) + { + result_ptr->label_id = label.label_id; + result_ptr->score = label.score; + ++result_ptr; + } + } + + *result_count = result_count_data.release(); + *results = result_data.release(); + + return MMDEPLOY_SUCCESS; } - - *result_count = result_count_data.release(); - *results = result_data.release(); - - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) 
{ - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, const int* result_count, - int count) { - delete[] results; - delete[] result_count; +void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, const int* result_count, int count) +{ + delete[] results; + delete[] result_count; } -void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)classifier); +void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)classifier); } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/classifier.h b/csrc/mmdeploy/apis/c/mmdeploy/classifier.h index 54e9d0215b..1681cf7fae 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/classifier.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/classifier.h @@ -13,124 +13,125 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_classification_t { - int label_id; - float score; -} mmdeploy_classification_t; - -typedef struct mmdeploy_classifier* mmdeploy_classifier_t; - -/** - * @brief Create classifier's handle - * @param[in] model an instance of mmclassification sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] classifier instance of a classifier, which must be destroyed - * by \ref mmdeploy_classifier_destroy - * @return status of creating classifier's handle - */ -MMDEPLOY_API int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_classifier_t* classifier); - -/** - * @brief Create classifier's handle - * @param[in] model_path path of mmclassification sdk model exported by mmdeploy model converter - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] classifier instance of a classifier, which must be destroyed - * by \ref mmdeploy_classifier_destroy - * @return status of creating classifier's handle - */ -MMDEPLOY_API int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, - int device_id, - mmdeploy_classifier_t* classifier); - -/** - * @brief Use classifier created by \ref mmdeploy_classifier_create_by_path to get label - * information of each image in a batch - * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] results a linear buffer to save classification results of each - * image, which must be freed by \ref mmdeploy_classifier_release_result - * @param[out] result_count a linear buffer with length being \p mat_count to save the number of - * classification results of each image. 
It must be released by \ref - * mmdeploy_classifier_release_result - * @return status of inference - */ -MMDEPLOY_API int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_classification_t** results, int** result_count); - -/** - * @brief Release the inference result buffer created \ref mmdeploy_classifier_apply - * @param[in] results classification results buffer - * @param[in] result_count \p results size buffer - * @param[in] count length of \p result_count - */ -MMDEPLOY_API void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, - const int* result_count, int count); - -/** - * @brief Destroy classifier's handle - * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path - */ -MMDEPLOY_API void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_classifier_create, but allows to control execution context of tasks - * via context - */ -MMDEPLOY_API int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_classifier_t* classifier); - -/** - * @brief Pack classifier inputs into mmdeploy_value_t - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] value the packed value - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value); - -/** - * @brief Same as \ref mmdeploy_classifier_apply, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, - mmdeploy_value_t input, mmdeploy_value_t* output); - -/** - * @brief Apply classifier asynchronously - * @param[in] classifier handle of the classifier - * @param[in] input input sender that will be consumed by the operation - * @param[out] output output sender - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -/** - * - * @param[in] output output obtained by applying a classifier - * @param[out] results a linear buffer containing classification results of each image, released by - * \ref mmdeploy_classifier_release_result - * @param[out] result_count a linear buffer containing the number of results for each input image, - * released by \ref mmdeploy_classifier_release_result - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_classifier_get_result(mmdeploy_value_t output, - mmdeploy_classification_t** results, - int** result_count); + typedef struct mmdeploy_classification_t + { + int label_id; + float score; + } mmdeploy_classification_t; + + typedef struct mmdeploy_classifier* mmdeploy_classifier_t; + + /** + * @brief Create classifier's handle + * @param[in] model an instance of mmclassification sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+ * @param[out] classifier instance of a classifier, which must be destroyed
+ * by \ref mmdeploy_classifier_destroy
+ * @return status of creating classifier's handle
+ */
+ MMDEPLOY_API int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_classifier_t* classifier);
+
+ /**
+ * @brief Create classifier's handle
+ * @param[in] model_path path of mmclassification sdk model exported by mmdeploy model converter
+ * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+ * @param[in] device_id id of device.
+ * @param[out] classifier instance of a classifier, which must be destroyed
+ * by \ref mmdeploy_classifier_destroy
+ * @return status of creating classifier's handle
+ */
+ MMDEPLOY_API int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_classifier_t* classifier);
+
+ /**
+ * @brief Use classifier created by \ref mmdeploy_classifier_create_by_path to get label
+ * information of each image in a batch
+ * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path
+ * @param[in] mats a batch of images
+ * @param[in] mat_count number of images in the batch
+ * @param[out] results a linear buffer to save classification results of each
+ * image, which must be freed by \ref mmdeploy_classifier_release_result
+ * @param[out] result_count a linear buffer with length being \p mat_count to save the number of
+ * classification results of each image. It must be released by \ref
+ * mmdeploy_classifier_release_result
+ * @return status of inference
+ */
+ MMDEPLOY_API int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier,
+ const mmdeploy_mat_t* mats,
+ int mat_count,
+ mmdeploy_classification_t** results,
+ int** result_count);
+
+ /**
+ * @brief Release the inference result buffer created by \ref mmdeploy_classifier_apply
+ * @param[in] results classification results buffer
+ * @param[in] result_count \p results size buffer
+ * @param[in] count length of \p result_count
+ */
+ MMDEPLOY_API void mmdeploy_classifier_release_result(mmdeploy_classification_t* results,
+ const int* result_count,
+ int count);
+
+ /**
+ * @brief Destroy classifier's handle
+ * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path
+ */
+ MMDEPLOY_API void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier);
+
+ /******************************************************************************
+ * Experimental asynchronous APIs */
+
+ /**
+ * @brief Same as \ref mmdeploy_classifier_create, but allows one to control the execution
+ * context of tasks via context
+ */
+ MMDEPLOY_API int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_classifier_t* classifier);
+
+ /**
+ * @brief Pack classifier inputs into mmdeploy_value_t
+ * @param[in] mats a batch of images
+ * @param[in] mat_count number of images in the batch
+ * @param[out] value the packed value
+ * @return status of the operation
+ */
+ MMDEPLOY_API int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value);
+
+ /**
+ * @brief Same as \ref mmdeploy_classifier_apply, but input and output are packed in \ref
+ * mmdeploy_value_t.
+ */ + MMDEPLOY_API int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + /** + * @brief Apply classifier asynchronously + * @param[in] classifier handle of the classifier + * @param[in] input input sender that will be consumed by the operation + * @param[out] output output sender + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + /** + * + * @param[in] output output obtained by applying a classifier + * @param[out] results a linear buffer containing classification results of each image, released by + * \ref mmdeploy_classifier_release_result + * @param[out] result_count a linear buffer containing the number of results for each input image, + * released by \ref mmdeploy_classifier_release_result + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_classifier_get_result(mmdeploy_value_t output, + mmdeploy_classification_t** results, + int** result_count); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/common.cpp b/csrc/mmdeploy/apis/c/mmdeploy/common.cpp index e00cc3f1cf..81e43ffce3 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/common.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/common.cpp @@ -5,111 +5,144 @@ #include "mmdeploy/core/profiler.h" #include "mmdeploy/executor_internal.h" -mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value) { - if (!value) { - return nullptr; - } - return Guard([&] { return Take(Value(*Cast(value))); }); +mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value) +{ + if (!value) + { + return nullptr; + } + return Guard([&] + { return Take(Value(*Cast(value))); }); } -void mmdeploy_value_destroy(mmdeploy_value_t value) { delete Cast(value); } +void mmdeploy_value_destroy(mmdeploy_value_t value) +{ + delete Cast(value); +} -int mmdeploy_context_create(mmdeploy_context_t* context) { - *context = (mmdeploy_context_t) new Value; - return 0; +int mmdeploy_context_create(mmdeploy_context_t* context) +{ + *context = (mmdeploy_context_t) new Value; + return 0; } -int mmdeploy_context_create_by_device(const char* device_name, int device_id, - mmdeploy_context_t* context) { - mmdeploy_device_t device{}; - int ec = MMDEPLOY_SUCCESS; - mmdeploy_context_t _context{}; - ec = mmdeploy_context_create(&_context); - if (ec != MMDEPLOY_SUCCESS) { - return ec; - } - ec = mmdeploy_device_create(device_name, device_id, &device); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_context_create_by_device(const char* device_name, int device_id, mmdeploy_context_t* context) +{ + mmdeploy_device_t device{}; + int ec = MMDEPLOY_SUCCESS; + mmdeploy_context_t _context{}; + ec = mmdeploy_context_create(&_context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_device_create(device_name, device_id, &device); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_context_add(_context, MMDEPLOY_TYPE_DEVICE, nullptr, device); + mmdeploy_device_destroy(device); + if (ec == MMDEPLOY_SUCCESS) + { + *context = _context; + } return ec; - } - ec = mmdeploy_context_add(_context, MMDEPLOY_TYPE_DEVICE, nullptr, device); - mmdeploy_device_destroy(device); - if (ec == MMDEPLOY_SUCCESS) { - *context = _context; - } - return ec; } -void mmdeploy_context_destroy(mmdeploy_context_t context) { delete Cast(context); } +void mmdeploy_context_destroy(mmdeploy_context_t context) +{ + delete Cast(context); +} -int mmdeploy_common_create_input(const 
mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value) { - if (mat_count && mats == nullptr) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - auto input = std::make_unique(Value{Value::kArray}); - for (int i = 0; i < mat_count; ++i) { - input->front().push_back({{"ori_img", Cast(mats[i])}}); +int mmdeploy_common_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value) +{ + if (mat_count && mats == nullptr) + { + return MMDEPLOY_E_INVALID_ARG; } - *value = Cast(input.release()); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_SUCCESS; -} -int mmdeploy_device_create(const char* device_name, int device_id, mmdeploy_device_t* device) { - Device tmp(device_name, device_id); - if (tmp.platform_id() == -1) { - MMDEPLOY_ERROR("Device \"{}\" not found", device_name); - return MMDEPLOY_E_INVALID_ARG; - } - *device = (mmdeploy_device_t) new Device(tmp); - return MMDEPLOY_SUCCESS; + try + { + auto input = std::make_unique(Value{Value::kArray}); + for (int i = 0; i < mat_count; ++i) + { + input->front().push_back({{"ori_img", Cast(mats[i])}}); + } + *value = Cast(input.release()); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + + return MMDEPLOY_SUCCESS; } -void mmdeploy_device_destroy(mmdeploy_device_t device) { delete (Device*)device; } +int mmdeploy_device_create(const char* device_name, int device_id, mmdeploy_device_t* device) +{ + Device tmp(device_name, device_id); + if (tmp.platform_id() == -1) + { + MMDEPLOY_ERROR("Device \"{}\" not found", device_name); + return MMDEPLOY_E_INVALID_ARG; + } + *device = (mmdeploy_device_t) new Device(tmp); + return MMDEPLOY_SUCCESS; +} -int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler) { - *profiler = (mmdeploy_profiler_t) new profiler::Profiler(path); - return MMDEPLOY_SUCCESS; +void mmdeploy_device_destroy(mmdeploy_device_t device) +{ + delete (Device*)device; } -void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler) { - if (profiler) { - auto p = (profiler::Profiler*)profiler; - p->Release(); - delete p; - } +int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler) +{ + *profiler = (mmdeploy_profiler_t) new profiler::Profiler(path); + return MMDEPLOY_SUCCESS; } -int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, const char* name, - const void* object) { - auto& ctx = *Cast(context); - switch (type) { - case MMDEPLOY_TYPE_DEVICE: { - const auto& device = *(Device*)object; - ctx["device"] = device; - ctx["stream"] = Stream(device); - break; +void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler) +{ + if (profiler) + { + auto p = (profiler::Profiler*)profiler; + p->Release(); + delete p; } - case MMDEPLOY_TYPE_SCHEDULER: - ctx["scheduler"][name] = *Cast((const mmdeploy_scheduler_t)object); - break; - case MMDEPLOY_TYPE_MODEL: - ctx["model"][name] = *Cast((const mmdeploy_model_t)object); - break; - case MMDEPLOY_TYPE_PROFILER: { - const auto& profiler = *(profiler::Profiler*)object; - profiler::Scope* root(profiler.scope()); - ctx["scope"] = root; - break; +} + +int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, const char* name, const void* object) +{ + auto& ctx = *Cast(context); + switch (type) + { + case MMDEPLOY_TYPE_DEVICE: + { + const auto& device = 
*(Device*)object; + ctx["device"] = device; + ctx["stream"] = Stream(device); + break; + } + case MMDEPLOY_TYPE_SCHEDULER: + ctx["scheduler"][name] = *Cast((const mmdeploy_scheduler_t)object); + break; + case MMDEPLOY_TYPE_MODEL: + ctx["model"][name] = *Cast((const mmdeploy_model_t)object); + break; + case MMDEPLOY_TYPE_PROFILER: + { + const auto& profiler = *(profiler::Profiler*)object; + profiler::Scope* root(profiler.scope()); + ctx["scope"] = root; + break; + } + default: + return MMDEPLOY_E_NOT_SUPPORTED; } - default: - return MMDEPLOY_E_NOT_SUPPORTED; - } - return 0; + return 0; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/common.h b/csrc/mmdeploy/apis/c/mmdeploy/common.h index c665134cbf..26b92973ca 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/common.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/common.h @@ -6,19 +6,19 @@ #include // NOLINT #ifndef MMDEPLOY_EXPORT -#ifdef _MSC_VER -#define MMDEPLOY_EXPORT __declspec(dllexport) -#else -#define MMDEPLOY_EXPORT __attribute__((visibility("default"))) -#endif + #ifdef _MSC_VER + #define MMDEPLOY_EXPORT __declspec(dllexport) + #else + #define MMDEPLOY_EXPORT __attribute__((visibility("default"))) + #endif #endif #ifndef MMDEPLOY_API -#ifdef MMDEPLOY_API_EXPORTS -#define MMDEPLOY_API MMDEPLOY_EXPORT -#else -#define MMDEPLOY_API -#endif + #ifdef MMDEPLOY_API_EXPORTS + #define MMDEPLOY_API MMDEPLOY_EXPORT + #else + #define MMDEPLOY_API + #endif #endif // clang-format off @@ -54,136 +54,137 @@ typedef enum mmdeploy_status_t { // clang-format on -typedef struct mmdeploy_device* mmdeploy_device_t; +typedef struct mmdeploy_device* mmdeploy_device_t; typedef struct mmdeploy_profiler* mmdeploy_profiler_t; -typedef struct mmdeploy_mat_t { - uint8_t* data; - int height; - int width; - int channel; - mmdeploy_pixel_format_t format; - mmdeploy_data_type_t type; - mmdeploy_device_t device; +typedef struct mmdeploy_mat_t +{ + uint8_t* data; + int height; + int width; + int channel; + mmdeploy_pixel_format_t format; + mmdeploy_data_type_t type; + mmdeploy_device_t device; } mmdeploy_mat_t; -typedef struct mmdeploy_rect_t { - float left; - float top; - float right; - float bottom; +typedef struct mmdeploy_rect_t +{ + float left; + float top; + float right; + float bottom; } mmdeploy_rect_t; -typedef struct mmdeploy_point_t { - float x; - float y; +typedef struct mmdeploy_point_t +{ + float x; + float y; } mmdeploy_point_t; -typedef struct mmdeploy_value* mmdeploy_value_t; +typedef struct mmdeploy_value* mmdeploy_value_t; typedef struct mmdeploy_context* mmdeploy_context_t; -typedef enum mmdeploy_context_type_t { - MMDEPLOY_TYPE_DEVICE = 0, - MMDEPLOY_TYPE_STREAM = 1, - MMDEPLOY_TYPE_MODEL = 2, - MMDEPLOY_TYPE_SCHEDULER = 3, - MMDEPLOY_TYPE_MAT = 4, - MMDEPLOY_TYPE_PROFILER = 5, +typedef enum mmdeploy_context_type_t +{ + MMDEPLOY_TYPE_DEVICE = 0, + MMDEPLOY_TYPE_STREAM = 1, + MMDEPLOY_TYPE_MODEL = 2, + MMDEPLOY_TYPE_SCHEDULER = 3, + MMDEPLOY_TYPE_MAT = 4, + MMDEPLOY_TYPE_PROFILER = 5, } mmdeploy_context_type_t; #if __cplusplus -extern "C" { +extern "C" +{ #endif -/** - * Copy value - * @param value - * @return - */ -MMDEPLOY_API mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value); - -/** - * Destroy value - * @param value - */ -MMDEPLOY_API void mmdeploy_value_destroy(mmdeploy_value_t value); - -/** - * Create device handle - * @param device_name - * @param device_id - * @param device - * @return - */ -MMDEPLOY_API int mmdeploy_device_create(const char* device_name, int device_id, - mmdeploy_device_t* device); - -/** - * Destroy device handle - * @param 
device - */ -MMDEPLOY_API void mmdeploy_device_destroy(mmdeploy_device_t device); - -/** - * Create profiler - * @param path path to save the profile data - * @param profiler handle for profiler, should be added to context and deleted by - * mmdeploy_profiler_destroy - * @return status of create - */ -MMDEPLOY_API int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler); - -/** - * Destroy profiler handle - * @param profiler handle for profiler, profile data will be written to disk after this call - */ -MMDEPLOY_API void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler); - -/** - * Create context - * @param context - * @return - */ -MMDEPLOY_API int mmdeploy_context_create(mmdeploy_context_t* context); - -/** - * Create context - * @param device_name - * @param device_id - * @param context - * @return - */ -MMDEPLOY_API int mmdeploy_context_create_by_device(const char* device_name, int device_id, - mmdeploy_context_t* context); - -/** - * Destroy context - * @param context - */ -MMDEPLOY_API void mmdeploy_context_destroy(mmdeploy_context_t context); - -/** - * Add context object - * @param context - * @param type - * @param name - * @param object - * @return - */ -MMDEPLOY_API int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, - const char* name, const void* object); - -/** - * Create input value from array of mats - * @param mats - * @param mat_count - * @param value - * @return - */ -MMDEPLOY_API int mmdeploy_common_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value); + /** + * Copy value + * @param value + * @return + */ + MMDEPLOY_API mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value); + + /** + * Destroy value + * @param value + */ + MMDEPLOY_API void mmdeploy_value_destroy(mmdeploy_value_t value); + + /** + * Create device handle + * @param device_name + * @param device_id + * @param device + * @return + */ + MMDEPLOY_API int mmdeploy_device_create(const char* device_name, int device_id, mmdeploy_device_t* device); + + /** + * Destroy device handle + * @param device + */ + MMDEPLOY_API void mmdeploy_device_destroy(mmdeploy_device_t device); + + /** + * Create profiler + * @param path path to save the profile data + * @param profiler handle for profiler, should be added to context and deleted by + * mmdeploy_profiler_destroy + * @return status of create + */ + MMDEPLOY_API int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler); + + /** + * Destroy profiler handle + * @param profiler handle for profiler, profile data will be written to disk after this call + */ + MMDEPLOY_API void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler); + + /** + * Create context + * @param context + * @return + */ + MMDEPLOY_API int mmdeploy_context_create(mmdeploy_context_t* context); + + /** + * Create context + * @param device_name + * @param device_id + * @param context + * @return + */ + MMDEPLOY_API int mmdeploy_context_create_by_device(const char* device_name, int device_id, mmdeploy_context_t* context); + + /** + * Destroy context + * @param context + */ + MMDEPLOY_API void mmdeploy_context_destroy(mmdeploy_context_t context); + + /** + * Add context object + * @param context + * @param type + * @param name + * @param object + * @return + */ + MMDEPLOY_API int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, const char* name, const void* object); + + /** + * Create input value from array of mats + * @param mats + * @param mat_count + * @param 
value + * @return + */ + MMDEPLOY_API int mmdeploy_common_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value); #if __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h b/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h index a1ddecb54d..24a776d8be 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h @@ -12,93 +12,160 @@ using namespace mmdeploy; -namespace { - -inline mmdeploy_value_t Cast(Value* s) { return reinterpret_cast(s); } - -inline Value* Cast(mmdeploy_value_t s) { return reinterpret_cast(s); } - -inline Value Take(mmdeploy_value_t v) { - auto value = std::move(*Cast(v)); - mmdeploy_value_destroy(v); - return value; -} - -inline Value* Cast(mmdeploy_context_t c) { return reinterpret_cast(c); } - -inline mmdeploy_value_t Take(Value v) { - return Cast(new Value(std::move(v))); // NOLINT -} - -inline mmdeploy_pipeline_t Cast(AsyncHandle* pipeline) { - return reinterpret_cast(pipeline); -} - -inline AsyncHandle* Cast(mmdeploy_pipeline_t pipeline) { - return reinterpret_cast(pipeline); -} - -inline mmdeploy_model_t Cast(Model* model) { return reinterpret_cast(model); } - -inline Model* Cast(mmdeploy_model_t model) { return reinterpret_cast(model); } - -inline Mat Cast(const mmdeploy_mat_t& mat) { - return Mat{mat.height, mat.width, PixelFormat(mat.format), - DataType(mat.type), mat.data, mat.device ? *(const Device*)mat.device : Device{0}}; -} - -template -std::invoke_result_t Guard(F f) { - try { - return f(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return nullptr; -} - -template -class wrapped {}; - -template -class wrapped> { - public: - wrapped() noexcept : v_(nullptr) {} - explicit wrapped(T v) noexcept : v_(v) {} - - void reset() { - if (v_) { - delete Cast(v_); - v_ = nullptr; +namespace +{ + + inline mmdeploy_value_t Cast(Value* s) + { + return reinterpret_cast(s); + } + + inline Value* Cast(mmdeploy_value_t s) + { + return reinterpret_cast(s); + } + + inline Value Take(mmdeploy_value_t v) + { + auto value = std::move(*Cast(v)); + mmdeploy_value_destroy(v); + return value; + } + + inline Value* Cast(mmdeploy_context_t c) + { + return reinterpret_cast(c); } - } - ~wrapped() { reset(); } + inline mmdeploy_value_t Take(Value v) + { + return Cast(new Value(std::move(v))); // NOLINT + } - wrapped(const wrapped&) = delete; - wrapped& operator=(const wrapped&) = delete; + inline mmdeploy_pipeline_t Cast(AsyncHandle* pipeline) + { + return reinterpret_cast(pipeline); + } - wrapped(wrapped&& other) noexcept : v_(other.release()) {} - wrapped& operator=(wrapped&& other) noexcept { - reset(); - v_ = other.release(); - return *this; - } + inline AsyncHandle* Cast(mmdeploy_pipeline_t pipeline) + { + return reinterpret_cast(pipeline); + } - T release() noexcept { return std::exchange(v_, nullptr); } + inline mmdeploy_model_t Cast(Model* model) + { + return reinterpret_cast(model); + } - auto operator*() { return Cast(v_); } - auto operator-> () { return Cast(v_); } + inline Model* Cast(mmdeploy_model_t model) + { + return reinterpret_cast(model); + } - T* ptr() noexcept { return &v_; } + inline Mat Cast(const mmdeploy_mat_t& mat) + { + return Mat{mat.height, + mat.width, + PixelFormat(mat.format), + DataType(mat.type), + mat.data, + mat.device ? 
*(const Device*)mat.device : Device{0}}; + } - operator T() const noexcept { return v_; } // NOLINT + template + std::invoke_result_t Guard(F f) + { + try + { + return f(); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + + return nullptr; + } - private: - T v_; -}; + template + class wrapped + { + }; + + template + class wrapped> + { + public: + wrapped() noexcept + : v_(nullptr) + { + } + explicit wrapped(T v) noexcept + : v_(v) + { + } + + void reset() + { + if (v_) + { + delete Cast(v_); + v_ = nullptr; + } + } + + ~wrapped() + { + reset(); + } + + wrapped(const wrapped&) = delete; + wrapped& operator=(const wrapped&) = delete; + + wrapped(wrapped&& other) noexcept + : v_(other.release()) + { + } + + wrapped& operator=(wrapped&& other) noexcept + { + reset(); + v_ = other.release(); + return *this; + } + + T release() noexcept + { + return std::exchange(v_, nullptr); + } + + auto operator*() + { + return Cast(v_); + } + + auto operator->() + { + return Cast(v_); + } + + T* ptr() noexcept + { + return &v_; + } + + operator T() const noexcept + { + return v_; + } // NOLINT + + private: + T v_; + }; } // namespace diff --git a/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp index aadf92fb62..6ad627be50 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp @@ -24,126 +24,143 @@ using ResultType = mmdeploy::Structure, // std::vector>; // -int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_detector_t* detector) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_detector_t* detector) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_detector_create_v2(model, context, detector); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_detector_create_v2(model, context, detector); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_detector_t* detector) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); +int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_detector_t* detector) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); } -int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, int device_id, - mmdeploy_detector_t* detector) { - mmdeploy_model_t model{}; +int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_detector_t* detector) +{ + mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_detector_create(model, device_name, device_id, detector); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_detector_create(model, device_name, device_id, detector); - mmdeploy_model_destroy(model); - return ec; } -int 
mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input) { - return mmdeploy_common_create_input(mats, mat_count, input); +int mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input) +{ + return mmdeploy_common_create_input(mats, mat_count, input); } -int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_detection_t** results, int** result_count) { - wrapped input; - if (auto ec = mmdeploy_detector_create_input(mats, mat_count, input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_detector_apply_v2(detector, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_detector_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_detection_t** results, int** result_count) +{ + wrapped input; + if (auto ec = mmdeploy_detector_create_input(mats, mat_count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = mmdeploy_detector_apply_v2(detector, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_detector_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_detector_get_result(mmdeploy_value_t output, mmdeploy_detection_t** results, - int** result_count) { - if (!output || !results || !result_count) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value& value = Cast(output)->front(); - auto detector_outputs = from_value>(value); - - vector _result_count(detector_outputs.size()); - size_t total = 0; - for (size_t i = 0; i < detector_outputs.size(); ++i) { - _result_count[i] = static_cast(detector_outputs[i].size()); - total += detector_outputs[i].size(); +int mmdeploy_detector_get_result(mmdeploy_value_t output, mmdeploy_detection_t** results, int** result_count) +{ + if (!output || !results || !result_count) + { + return MMDEPLOY_E_INVALID_ARG; } - ResultType r({total, 1, 1, 1}); - auto [result_data, result_count_vec, masks, buffers] = r.pointers(); - - auto result_ptr = result_data; - - for (const auto& det_output : detector_outputs) { - for (const auto& detection : det_output) { - result_ptr->label_id = detection.label_id; - result_ptr->score = detection.score; - const auto& bbox = detection.bbox; - result_ptr->bbox = {bbox[0], bbox[1], bbox[2], bbox[3]}; - auto mask_byte_size = detection.mask.byte_size(); - if (mask_byte_size) { - auto& mask = detection.mask; - result_ptr->mask = &masks->emplace_back(); - buffers->push_back(mask.buffer()); - result_ptr->mask->data = mask.data(); - 
result_ptr->mask->width = mask.width(); - result_ptr->mask->height = mask.height(); + try + { + Value& value = Cast(output)->front(); + auto detector_outputs = from_value>(value); + + vector _result_count(detector_outputs.size()); + size_t total = 0; + for (size_t i = 0; i < detector_outputs.size(); ++i) + { + _result_count[i] = static_cast(detector_outputs[i].size()); + total += detector_outputs[i].size(); } - ++result_ptr; - } - } - *result_count_vec = std::move(_result_count); - *result_count = result_count_vec->data(); - *results = result_data; - r.release(); + ResultType r({total, 1, 1, 1}); + auto [result_data, result_count_vec, masks, buffers] = r.pointers(); + + auto result_ptr = result_data; + + for (const auto& det_output : detector_outputs) + { + for (const auto& detection : det_output) + { + result_ptr->label_id = detection.label_id; + result_ptr->score = detection.score; + const auto& bbox = detection.bbox; + result_ptr->bbox = {bbox[0], bbox[1], bbox[2], bbox[3]}; + auto mask_byte_size = detection.mask.byte_size(); + if (mask_byte_size) + { + auto& mask = detection.mask; + result_ptr->mask = &masks->emplace_back(); + buffers->push_back(mask.buffer()); + result_ptr->mask->data = mask.data(); + result_ptr->mask->width = mask.width(); + result_ptr->mask->height = mask.height(); + } + ++result_ptr; + } + } - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + *result_count_vec = std::move(_result_count); + *result_count = result_count_vec->data(); + *results = result_data; + r.release(); + + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) 
+ { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_detector_release_result(mmdeploy_detection_t* results, const int* result_count, - int count) { - auto num_dets = std::accumulate(result_count, result_count + count, 0); - ResultType deleter({static_cast(num_dets), 1, 1, 1}, results); +void mmdeploy_detector_release_result(mmdeploy_detection_t* results, const int* result_count, int count) +{ + auto num_dets = std::accumulate(result_count, result_count + count, 0); + ResultType deleter({static_cast(num_dets), 1, 1, 1}, results); } -void mmdeploy_detector_destroy(mmdeploy_detector_t detector) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); +void mmdeploy_detector_destroy(mmdeploy_detector_t detector) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/detector.h b/csrc/mmdeploy/apis/c/mmdeploy/detector.h index 5c5ba2f356..713214ca4f 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/detector.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/detector.h @@ -13,124 +13,123 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_instance_mask_t { - char* data; - int height; - int width; -} mmdeploy_instance_mask_t; - -typedef struct mmdeploy_detection_t { - int label_id; - float score; - mmdeploy_rect_t bbox; - mmdeploy_instance_mask_t* mask; -} mmdeploy_detection_t; - -typedef struct mmdeploy_detector* mmdeploy_detector_t; - -/** - * @brief Create detector's handle - * @param[in] model an instance of mmdetection sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector instance of a detector - * @return status of creating detector's handle - */ -MMDEPLOY_API int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_detector_t* detector); - -/** - * @brief Create detector's handle - * @param[in] model_path path of mmdetection sdk model exported by mmdeploy model converter - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector instance of a detector - * @return status of creating detector's handle - */ -MMDEPLOY_API int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_detector_t* detector); - -/** - * @brief Apply detector to batch images and get their inference results - * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] results a linear buffer to save detection results of each image. It must be released - * by \ref mmdeploy_detector_release_result - * @param[out] result_count a linear buffer with length being \p mat_count to save the number of - * detection results of each image. 
And it must be released by \ref - * mmdeploy_detector_release_result - * @return status of inference - */ -MMDEPLOY_API int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_detection_t** results, - int** result_count); - -/** @brief Release the inference result buffer created by \ref mmdeploy_detector_apply - * @param[in] results detection results buffer - * @param[in] result_count \p results size buffer - * @param[in] count length of \p result_count - */ -MMDEPLOY_API void mmdeploy_detector_release_result(mmdeploy_detection_t* results, - const int* result_count, int count); - -/** - * @brief Destroy detector's handle - * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path - */ -MMDEPLOY_API void mmdeploy_detector_destroy(mmdeploy_detector_t detector); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_detector_create, but allows to control execution context of tasks - * via context - */ -MMDEPLOY_API int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_detector_t* detector); - -/** - * @brief Pack detector inputs into mmdeploy_value_t - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @return the created value - */ -MMDEPLOY_API int mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input); - -/** - * @brief Same as \ref mmdeploy_detector_apply, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output); - -/** - * @brief Apply detector asynchronously - * @param[in] detector handle to the detector - * @param[in] input input sender - * @return output sender - */ -MMDEPLOY_API int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, - mmdeploy_sender_t input, mmdeploy_sender_t* output); - -/** - * @brief Unpack detector output from a mmdeploy_value_t - * @param[in] output output obtained by applying a detector - * @param[out] results a linear buffer to save detection results of each image. It must be released - * by \ref mmdeploy_detector_release_result - * @param[out] result_count a linear buffer with length number of input images to save the number of - * detection results of each image. Must be released by \ref - * mmdeploy_detector_release_result - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_detector_get_result(mmdeploy_value_t output, - mmdeploy_detection_t** results, int** result_count); + typedef struct mmdeploy_instance_mask_t + { + char* data; + int height; + int width; + } mmdeploy_instance_mask_t; + + typedef struct mmdeploy_detection_t + { + int label_id; + float score; + mmdeploy_rect_t bbox; + mmdeploy_instance_mask_t* mask; + } mmdeploy_detection_t; + + typedef struct mmdeploy_detector* mmdeploy_detector_t; + + /** + * @brief Create detector's handle + * @param[in] model an instance of mmdetection sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+ * @param[out] detector instance of a detector + * @return status of creating detector's handle + */ + MMDEPLOY_API int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_detector_t* detector); + + /** + * @brief Create detector's handle + * @param[in] model_path path of mmdetection sdk model exported by mmdeploy model converter + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] detector instance of a detector + * @return status of creating detector's handle + */ + MMDEPLOY_API int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_detector_t* detector); + + /** + * @brief Apply detector to a batch of images and get their inference results + * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @param[out] results a linear buffer to save detection results of each image. It must be released + * by \ref mmdeploy_detector_release_result + * @param[out] result_count a linear buffer with length being \p mat_count to save the number of + * detection results of each image. And it must be released by \ref + * mmdeploy_detector_release_result + * @return status of inference + */ + MMDEPLOY_API int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_detection_t** results, int** result_count); + + /** @brief Release the inference result buffer created by \ref mmdeploy_detector_apply + * @param[in] results detection results buffer + * @param[in] result_count \p results size buffer + * @param[in] count length of \p result_count + */ + MMDEPLOY_API void mmdeploy_detector_release_result(mmdeploy_detection_t* results, + const int* result_count, + int count); + + /** + * @brief Destroy detector's handle + * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path + */ + MMDEPLOY_API void mmdeploy_detector_destroy(mmdeploy_detector_t detector); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + /** + * @brief Same as \ref mmdeploy_detector_create, but allows controlling the execution context of tasks + * via context + */ + MMDEPLOY_API int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_detector_t* detector); + + /** + * @brief Pack detector inputs into mmdeploy_value_t + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @param[out] input the packed input value + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input); + + /** + * @brief Same as \ref mmdeploy_detector_apply, but input and output are packed in \ref + * mmdeploy_value_t.
+ */ + MMDEPLOY_API int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output); + + /** + * @brief Apply detector asynchronously + * @param[in] detector handle to the detector + * @param[in] input input sender + * @return output sender + */ + MMDEPLOY_API int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + /** + * @brief Unpack detector output from a mmdeploy_value_t + * @param[in] output output obtained by applying a detector + * @param[out] results a linear buffer to save detection results of each image. It must be released + * by \ref mmdeploy_detector_release_result + * @param[out] result_count a linear buffer with length number of input images to save the number of + * detection results of each image. Must be released by \ref + * mmdeploy_detector_release_result + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_detector_get_result(mmdeploy_value_t output, + mmdeploy_detection_t** results, + int** result_count); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp b/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp index 2fdfb9091f..0de722b58c 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp @@ -9,199 +9,283 @@ using namespace mmdeploy; -namespace { +namespace +{ -mmdeploy_scheduler_t CreateScheduler(const char* type, const Value& config = Value()) { - try { - auto creator = gRegistry<SchedulerType>().Get(type); - if (!creator) { - MMDEPLOY_ERROR("Creator for {} not found. Available schedulers: {}", type, - gRegistry<SchedulerType>().List()); - return nullptr; + mmdeploy_scheduler_t CreateScheduler(const char* type, const Value& config = Value()) + { + try + { + auto creator = gRegistry<SchedulerType>().Get(type); + if (!creator) + { + MMDEPLOY_ERROR("Creator for {} not found. 
Available schedulers: {}", + type, + gRegistry().List()); + return nullptr; + } + return Cast(new SchedulerType(creator->Create(config))); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("failed to create Scheduler: {} ({}), config: {}", type, e.what(), config); + return nullptr; + } } - return Cast(new SchedulerType(creator->Create(config))); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("failed to create Scheduler: {} ({}), config: {}", type, e.what(), config); - return nullptr; - } -} } // namespace -mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Take(SenderType(*Cast(input))); +mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Take(SenderType(*Cast(input))); } -int mmdeploy_sender_destroy(mmdeploy_sender_t sender) { - delete Cast(sender); - return 0; +int mmdeploy_sender_destroy(mmdeploy_sender_t sender) +{ + delete Cast(sender); + return 0; } -mmdeploy_scheduler_t mmdeploy_executor_inline() { return CreateScheduler("Inline"); } +mmdeploy_scheduler_t mmdeploy_executor_inline() +{ + return CreateScheduler("Inline"); +} -mmdeploy_scheduler_t mmdeploy_executor_system_pool() { - // create a thread pool context and hold its shared handle - static auto scheduler = *Cast(CreateScheduler("ThreadPool")); - // return a copy of the handle to the thread pool - return Cast(new SchedulerType(scheduler)); +mmdeploy_scheduler_t mmdeploy_executor_system_pool() +{ + // create a thread pool context and hold its shared handle + static auto scheduler = *Cast(CreateScheduler("ThreadPool")); + // return a copy of the handle to the thread pool + return Cast(new SchedulerType(scheduler)); } -mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads) { - return CreateScheduler("ThreadPool", {{"num_threads", num_threads}}); +mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads) +{ + return CreateScheduler("ThreadPool", {{"num_threads", num_threads}}); } -mmdeploy_scheduler_t mmdeploy_executor_create_thread() { return CreateScheduler("SingleThread"); } +mmdeploy_scheduler_t mmdeploy_executor_create_thread() +{ + return CreateScheduler("SingleThread"); +} mmdeploy_scheduler_t mmdeploy_executor_dynamic_batch(mmdeploy_scheduler_t scheduler, - int max_batch_size, int timeout) { - if (!scheduler) { - return nullptr; - } - return CreateScheduler( - "DynamicBatch", - {{"scheduler", *Cast(scheduler)}, {"max_batch_size", max_batch_size}, {"timeout", timeout}}); + int max_batch_size, + int timeout) +{ + if (!scheduler) + { + return nullptr; + } + return CreateScheduler("DynamicBatch", + {{"scheduler", *Cast(scheduler)}, + {"max_batch_size", max_batch_size}, + {"timeout", timeout}}); } -int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler) { - delete Cast(scheduler); - return 0; +int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler) +{ + delete Cast(scheduler); + return 0; } -mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value) { - if (value) { - return Guard([&] { return Take(Just(*Cast(value))); }); - } else { - return Take(Just(Value())); - } +mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value) +{ + if (value) + { + return Guard([&] + { return Take(Just(*Cast(value))); }); + } + else + { + return Take(Just(Value())); + } } -mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler) { - if (!scheduler) { - return nullptr; - } - return Guard([&] { return 
Take(Then(Schedule(*Cast(scheduler)), [] { return Value(); })); }); +mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler) +{ + if (!scheduler) + { + return nullptr; + } + return Guard([&] + { return Take(Then(Schedule(*Cast(scheduler)), + [] + { + return Value(); + })); }); } mmdeploy_sender_t mmdeploy_executor_transfer_just(mmdeploy_scheduler_t scheduler, - mmdeploy_value_t value) { - if (!scheduler || !value) { - return nullptr; - } - return Guard([&] { return Take(TransferJust(*Cast(scheduler), *Cast(value))); }); -} - -mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, - mmdeploy_scheduler_t scheduler) { - if (!input || !scheduler) { - return nullptr; - } - return Guard([&] { return Take(Transfer(Take(input), *Cast(scheduler))); }); -} - -mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, mmdeploy_sender_t input) { - if (!scheduler || !input) { - return nullptr; - } - return Guard([&] { return Take(On(*Cast(scheduler), Take(input))); }); -} - -mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, mmdeploy_then_fn_t fn, - void* context) { - if (!input || !fn) { - return nullptr; - } - return Guard([&] { - return Take(Then(Take(input), [fn, context](Value args) { - auto out = Cast(fn(Take(std::move(args)), context)); - Value ret(std::move(*out)); - delete out; - return ret; - })); - }); -} - -mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, mmdeploy_let_value_fn_t fn, - void* context) { - if (!input || !fn) { - return nullptr; - } - return Guard([&] { - return Take(LetValue(Take(input), [fn, context](Value& args) { - auto out = Cast(fn(Cast(&args), context)); - SenderType ret(std::move(*out)); - delete out; - return ret; - })); - }); -} - -mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Guard([&] { return Take(Split(Take(input))); }); -} - -mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n) { - if (!inputs) { - return nullptr; - } - return Guard([&] { - std::vector<SenderType> senders; - senders.reserve(n); - for (int i = 0; i < n; ++i) { - senders.emplace_back(Take(inputs[i])); - } - return Take( - Then(WhenAll(std::move(senders)), [](Value::Array&& v) { return Value(std::move(v)); })); - }); -} - -mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Guard([&] { return Take(EnsureStarted(Take(input))); }); -} - -int mmdeploy_executor_start_detached(mmdeploy_sender_t input) { - if (!input) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - StartDetached(Take(input)); - return 0; - } catch (...) 
{ - } - return MMDEPLOY_E_FAIL; + mmdeploy_value_t value) +{ + if (!scheduler || !value) + { + return nullptr; + } + return Guard([&] + { return Take(TransferJust(*Cast(scheduler), *Cast(value))); }); +} + +mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, + mmdeploy_scheduler_t scheduler) +{ + if (!input || !scheduler) + { + return nullptr; + } + return Guard([&] + { return Take(Transfer(Take(input), *Cast(scheduler))); }); +} + +mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, mmdeploy_sender_t input) +{ + if (!scheduler || !input) + { + return nullptr; + } + return Guard([&] + { return Take(On(*Cast(scheduler), Take(input))); }); +} + +mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, mmdeploy_then_fn_t fn, void* context) +{ + if (!input || !fn) + { + return nullptr; + } + return Guard([&] + { return Take(Then(Take(input), + [fn, context](Value args) + { + auto out = Cast(fn(Take(std::move(args)), context)); + Value ret(std::move(*out)); + delete out; + return ret; + })); }); +} + +mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, mmdeploy_let_value_fn_t fn, void* context) +{ + if (!input || !fn) + { + return nullptr; + } + return Guard([&] + { return Take(LetValue(Take(input), + [fn, context](Value& args) + { + auto out = Cast(fn(Cast(&args), context)); + SenderType ret(std::move(*out)); + delete out; + return ret; + })); }); } -mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Guard([&] { return Take(std::get<Value>(SyncWait(Take(input)))); }); +mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Guard([&] + { return Take(Split(Take(input))); }); +} + +mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n) +{ + if (!inputs) + { + return nullptr; + } + return Guard([&] + { + std::vector<SenderType> senders; + senders.reserve(n); + for (int i = 0; i < n; ++i) + { + senders.emplace_back(Take(inputs[i])); + } + return Take(Then(WhenAll(std::move(senders)), + [](Value::Array&& v) + { + return Value(std::move(v)); + })); }); +} + +mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Guard([&] + { return Take(EnsureStarted(Take(input))); }); } -int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t sender, mmdeploy_value_t* value) { - if (!sender) { - return MMDEPLOY_E_INVALID_ARG; - } - auto result = mmdeploy_executor_sync_wait(sender); - if (!result) { +int mmdeploy_executor_start_detached(mmdeploy_sender_t input) +{ + if (!input) + { + return MMDEPLOY_E_INVALID_ARG; + } + + try + { + StartDetached(Take(input)); + return 0; + } + catch (...) 
+ { + } + return MMDEPLOY_E_FAIL; - } - if (value) { - *value = result; - } else { - mmdeploy_value_destroy(result); - } - return MMDEPLOY_SUCCESS; } -void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*), void* context) { - Execute(*Cast(scheduler), [fn, context] { fn(context); }); +mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Guard([&] + { return Take(std::get<Value>(SyncWait(Take(input)))); }); +} + +int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t sender, mmdeploy_value_t* value) +{ + if (!sender) + { + return MMDEPLOY_E_INVALID_ARG; + } + + auto result = mmdeploy_executor_sync_wait(sender); + if (!result) + { + return MMDEPLOY_E_FAIL; + } + + if (value) + { + *value = result; + } + else + { + mmdeploy_value_destroy(result); + } + + return MMDEPLOY_SUCCESS; +} + +void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*), void* context) +{ + Execute(*Cast(scheduler), + [fn, context] + { + fn(context); + }); } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/executor.h b/csrc/mmdeploy/apis/c/mmdeploy/executor.h index a2c8ffa387..4b044a6b51 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/executor.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/executor.h @@ -6,133 +6,135 @@ #include "mmdeploy/common.h" #if __cplusplus -extern "C" { +extern "C" +{ #endif -/****************************************************************************** - * Experimental asynchronous APIs */ + /****************************************************************************** + * Experimental asynchronous APIs */ -typedef mmdeploy_value_t (*mmdeploy_then_fn_t)(mmdeploy_value_t, void*); + typedef mmdeploy_value_t (*mmdeploy_then_fn_t)(mmdeploy_value_t, void*); -typedef mmdeploy_value_t (*mmdeploy_then_fn_v2_t)(mmdeploy_value_t*, void*); - -typedef int (*mmdeploy_then_fn_v3_t)(mmdeploy_value_t* input, mmdeploy_value_t* output, void*); + typedef mmdeploy_value_t (*mmdeploy_then_fn_v2_t)(mmdeploy_value_t*, void*); + + typedef int (*mmdeploy_then_fn_v3_t)(mmdeploy_value_t* input, mmdeploy_value_t* output, void*); + + struct mmdeploy_sender; + struct mmdeploy_scheduler; + + typedef struct mmdeploy_sender* mmdeploy_sender_t; + typedef struct mmdeploy_scheduler* mmdeploy_scheduler_t; -struct mmdeploy_sender; -struct mmdeploy_scheduler; + typedef mmdeploy_sender_t (*mmdeploy_let_value_fn_t)(mmdeploy_value_t, void*); -typedef struct mmdeploy_sender* mmdeploy_sender_t; -typedef struct mmdeploy_scheduler* mmdeploy_scheduler_t; + /////////////////////////////////////////////////////////////////////////////// + // Scheduler + /////////////////////////////////////////////////////////////////////////////// + MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_inline(); -typedef mmdeploy_sender_t (*mmdeploy_let_value_fn_t)(mmdeploy_value_t, void*); + MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_system_pool(); -/////////////////////////////////////////////////////////////////////////////// -// Scheduler -/////////////////////////////////////////////////////////////////////////////// -MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_inline(); + /** + * Create a thread pool with the given number of worker threads + * @param[in] num_threads + * @return the handle to the created thread pool + */ + MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads); -MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_system_pool(); + MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread();
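To make the declarations that follow easier to read, here is a sketch of how a scheduler, a sender, and a consumer chain together. It is illustrative rather than part of the patch; the pass-through callback is hypothetical, and error handling is reduced to the return code.

#include "mmdeploy/executor.h"

/* pass-through callback for mmdeploy_executor_then; returns its input unchanged */
static mmdeploy_value_t passthrough(mmdeploy_value_t v, void* context)
{
    (void)context;
    return v;
}

int run_on_pool(mmdeploy_value_t input)
{
    mmdeploy_scheduler_t pool = mmdeploy_executor_system_pool();
    /* transfer_just -> then -> blocking wait */
    mmdeploy_sender_t sender = mmdeploy_executor_transfer_just(pool, input);
    sender = mmdeploy_executor_then(sender, passthrough, NULL);
    mmdeploy_value_t output = NULL;
    int ec = mmdeploy_executor_sync_wait_v2(sender, &output);
    if (ec == MMDEPLOY_SUCCESS)
    {
        mmdeploy_value_destroy(output);
    }
    /* the adapters consumed the sender; the scheduler handle is destroyed separately */
    mmdeploy_scheduler_destroy(pool);
    return ec;
}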
-/** - * Create a thread pool with the given number of worker threads - * @param[in] num_threads - * @return the handle to the created thread pool - */ -MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads); + MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_dynamic_batch(mmdeploy_scheduler_t scheduler, + int max_batch_size, + int timeout); -MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread(); + MMDEPLOY_API int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler); -MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_dynamic_batch(mmdeploy_scheduler_t scheduler, - int max_batch_size, int timeout); + /////////////////////////////////////////////////////////////////////////////// + // Utilities + /////////////////////////////////////////////////////////////////////////////// -MMDEPLOY_API int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler); + /** + * @brief Create a copy of a copyable sender. Only senders created by \ref mmdeploy_executor_split + * are copyable for now. + * @param[in] input copyable sender + * @return the sender created, or nullptr if the sender is not copyable + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input); -/////////////////////////////////////////////////////////////////////////////// -// Utilities -/////////////////////////////////////////////////////////////////////////////// + /** + * @brief Destroy a sender. Note that all sender adapters consume input senders; only unused + * senders should be destroyed using this function. + * @param[in] sender + */ + MMDEPLOY_API int mmdeploy_sender_destroy(mmdeploy_sender_t sender); -/** - * @brief Create a copy of a copyable sender. Only senders created by \ref mmdeploy_executor_split - * is copyable for now. - * @param[in] input copyable sender, - * @return the sender created, or nullptr if the sender is not copyable - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input); + /////////////////////////////////////////////////////////////////////////////// + // Sender factories + /////////////////////////////////////////////////////////////////////////////// -/** - * @brief Destroy a sender, notice that all sender adapters will consume input senders, only unused - * senders should be destroyed using this function.
- * @param[in] input - */ -MMDEPLOY_API int mmdeploy_sender_destroy(mmdeploy_sender_t sender); + /** + * @brief Create a sender that sends the provided value + * @param[in] value + * @return created sender + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value); -/////////////////////////////////////////////////////////////////////////////// -// Sender factories -/////////////////////////////////////////////////////////////////////////////// + /** + * @brief Create a sender that starts on the provided scheduler and completes with an empty value + * @param[in] scheduler + * @return the sender created + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler); -/** - * @brief Create a sender that sends the provided value - * @param[in] value - * @return created sender - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value); + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer_just(mmdeploy_scheduler_t scheduler, + mmdeploy_value_t value); -/** - * @brief - * @param[in] scheduler - * @return the sender created - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler); + /////////////////////////////////////////////////////////////////////////////// + // Sender adapters + /////////////////////////////////////////////////////////////////////////////// -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer_just(mmdeploy_scheduler_t scheduler, - mmdeploy_value_t value); + /** + * Transfer the execution to the execution agent of the provided scheduler + * @param[in] input + * @param[in] scheduler + * @return the sender created + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, + mmdeploy_scheduler_t scheduler); -/////////////////////////////////////////////////////////////////////////////// -// Sender adapters -/////////////////////////////////////////////////////////////////////////////// + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, + mmdeploy_sender_t input); -/** - * Transfer the execution to the execution agent of the provided scheduler - * @param[in] input - * @param[in] scheduler - * @return the sender created - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, - mmdeploy_scheduler_t scheduler); + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, + mmdeploy_then_fn_t fn, + void* context); -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, - mmdeploy_sender_t input); + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, + mmdeploy_let_value_fn_t fn, + void* context); -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, - mmdeploy_then_fn_t fn, void* context); - -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, - mmdeploy_let_value_fn_t fn, - void* context); - -/** - * Convert the input sender into a sender that is copyable via \ref mmdeploy_sender_copy. Notice - * that this function doesn't make the sender multi-shot, it just return a sender that is copyable.
- * @param[in] input - * @return the sender that is copyable - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input); -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n); -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input); -/////////////////////////////////////////////////////////////////////////////// -// Sender consumers -/////////////////////////////////////////////////////////////////////////////// -MMDEPLOY_API int mmdeploy_executor_start_detached(mmdeploy_sender_t input); -MMDEPLOY_API mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input); -MMDEPLOY_API int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t input, mmdeploy_value_t* output); -MMDEPLOY_API void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*), - void* context); + /** + * Convert the input sender into a sender that is copyable via \ref mmdeploy_sender_copy. Notice + * that this function doesn't make the sender multi-shot; it just returns a sender that is copyable. + * @param[in] input + * @return the sender that is copyable + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input); + + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n); + + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input); + + /////////////////////////////////////////////////////////////////////////////// + // Sender consumers + /////////////////////////////////////////////////////////////////////////////// + MMDEPLOY_API int mmdeploy_executor_start_detached(mmdeploy_sender_t input); + + MMDEPLOY_API mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input); + + MMDEPLOY_API int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t input, mmdeploy_value_t* output); + + MMDEPLOY_API void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*), void* context); #if __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h b/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h index 95f39fe009..0ae8c2a529 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h @@ -8,33 +8,49 @@ using namespace mmdeploy; -using SenderType = TypeErasedSender<Value>; +using SenderType = TypeErasedSender<Value>; using SchedulerType = TypeErasedScheduler<Value>; -namespace { - -inline SchedulerType* Cast(mmdeploy_scheduler_t s) { return reinterpret_cast<SchedulerType*>(s); } - -inline mmdeploy_scheduler_t Cast(SchedulerType* s) { - return reinterpret_cast<mmdeploy_scheduler_t>(s); -} - -inline SenderType* Cast(mmdeploy_sender_t s) { return reinterpret_cast<SenderType*>(s); } - -inline mmdeploy_sender_t Cast(SenderType* s) { return reinterpret_cast<mmdeploy_sender_t>(s); } - -inline SenderType Take(mmdeploy_sender_t s) { - auto sender = std::move(*Cast(s)); - mmdeploy_sender_destroy(s); - return sender; -} - -inline mmdeploy_sender_t Take(SenderType s) { return Cast(new SenderType(std::move(s))); } - -template <class T, std::enable_if_t<!std::is_same_v<std::decay_t<T>, SenderType>, int> = 0> -inline mmdeploy_sender_t Take(T& s) { - return Take(SenderType(std::move(s))); -} +namespace +{ + + inline SchedulerType* Cast(mmdeploy_scheduler_t s) + { + return reinterpret_cast<SchedulerType*>(s); + } + + inline mmdeploy_scheduler_t Cast(SchedulerType* s) + { + return reinterpret_cast<mmdeploy_scheduler_t>(s); + } + + inline SenderType* Cast(mmdeploy_sender_t s) + { + return reinterpret_cast<SenderType*>(s); + } + + inline mmdeploy_sender_t Cast(SenderType* s) + { + return reinterpret_cast<mmdeploy_sender_t>(s); + } + + inline 
SenderType Take(mmdeploy_sender_t s) + { + auto sender = std::move(*Cast(s)); + mmdeploy_sender_destroy(s); + return sender; + } + + inline mmdeploy_sender_t Take(SenderType s) + { + return Cast(new SenderType(std::move(s))); + } + + template<class T, std::enable_if_t<!std::is_same_v<std::decay_t<T>, SenderType>, int> = 0> + inline mmdeploy_sender_t Take(T& s) + { + return Take(SenderType(std::move(s))); + } } // namespace diff --git a/csrc/mmdeploy/apis/c/mmdeploy/handle.h b/csrc/mmdeploy/apis/c/mmdeploy/handle.h index 006ddaae3d..d2ccde1ef5 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/handle.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/handle.h @@ -11,42 +11,53 @@ #include "mmdeploy/graph/common.h" #include "mmdeploy/graph/static_router.h" -namespace mmdeploy { - -using namespace framework; - -namespace { - -class AsyncHandle { - public: - AsyncHandle(const char* device_name, int device_id, Value config) - : AsyncHandle(SetContext(std::move(config), device_name, device_id)) {} - - explicit AsyncHandle(const Value& config) { - if (auto builder = graph::Builder::CreateFromConfig(config).value()) { - node_ = builder->Build().value(); - } else { - MMDEPLOY_ERROR("failed to find creator for node"); - throw_exception(eEntryNotFound); - } - } - - graph::Sender<Value> Process(graph::Sender<Value> input) { - return node_->Process(std::move(input)); - } - - private: - static Value SetContext(Value config, const char* device_name, int device_id) { - Device device(device_name, device_id); - Stream stream(device); - config["context"].update({{"device", device}, {"stream", stream}}); - return config; - } - - std::unique_ptr<graph::Node> node_; -}; - -} // namespace +namespace mmdeploy +{ + + using namespace framework; + + namespace + { + + class AsyncHandle + { + public: + AsyncHandle(const char* device_name, int device_id, Value config) + : AsyncHandle(SetContext(std::move(config), device_name, device_id)) + { + } + + explicit AsyncHandle(const Value& config) + { + if (auto builder = graph::Builder::CreateFromConfig(config).value()) + { + node_ = builder->Build().value(); + } + else + { + MMDEPLOY_ERROR("failed to find creator for node"); + throw_exception(eEntryNotFound); + } + } + + graph::Sender<Value> Process(graph::Sender<Value> input) + { + return node_->Process(std::move(input)); + } + + private: + static Value SetContext(Value config, const char* device_name, int device_id) + { + Device device(device_name, device_id); + Stream stream(device); + config["context"].update({{"device", device}, {"stream", stream}}); + return config; + } + + std::unique_ptr<graph::Node> node_; + }; + + } // namespace } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/c/mmdeploy/model.cpp b/csrc/mmdeploy/apis/c/mmdeploy/model.cpp index 6d202bce81..08af517522 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/model.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/model.cpp @@ -12,30 +12,45 @@ using namespace mmdeploy; -int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model) { - try { - auto ptr = std::make_unique<Model>(path); - *model = reinterpret_cast<mmdeploy_model_t>(ptr.release()); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("failed to create model: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; +int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model) +{ + try + { + auto ptr = std::make_unique<Model>(path); + *model = reinterpret_cast<mmdeploy_model_t>(ptr.release()); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("failed to create model: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; }
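As a usage note (not part of the patch), both model constructors yield an opaque handle owned by the caller; a minimal sketch:

#include "mmdeploy/model.h"

/* load a model twice: from disk and from an in-memory buffer */
int load_models(const char* path, const void* buffer, int size)
{
    mmdeploy_model_t by_path = NULL;
    if (mmdeploy_model_create_by_path(path, &by_path) != MMDEPLOY_SUCCESS)
    {
        return MMDEPLOY_E_FAIL;
    }
    mmdeploy_model_t from_memory = NULL;
    int ec = mmdeploy_model_create(buffer, size, &from_memory);
    if (ec == MMDEPLOY_SUCCESS)
    {
        mmdeploy_model_destroy(from_memory);
    }
    mmdeploy_model_destroy(by_path);
    return ec;
}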
-int mmdeploy_model_create(const void* buffer, int size, mmdeploy_model_t* model) { - try { - auto ptr = std::make_unique<Model>(buffer, size); - *model = reinterpret_cast<mmdeploy_model_t>(ptr.release()); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("failed to create model: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; +int mmdeploy_model_create(const void* buffer, int size, mmdeploy_model_t* model) +{ + try + { + auto ptr = std::make_unique<Model>(buffer, size); + *model = reinterpret_cast<mmdeploy_model_t>(ptr.release()); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("failed to create model: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_model_destroy(mmdeploy_model_t model) { delete reinterpret_cast<Model*>(model); } +void mmdeploy_model_destroy(mmdeploy_model_t model) +{ + delete reinterpret_cast<Model*>(model); +} diff --git a/csrc/mmdeploy/apis/c/mmdeploy/model.h b/csrc/mmdeploy/apis/c/mmdeploy/model.h index 394d2902c2..ddea967f1a 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/model.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/model.h @@ -11,34 +11,35 @@ #include "mmdeploy/common.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_model* mmdeploy_model_t; - -/** - * @brief Create SDK Model instance from given model path - * @param[in] path model path - * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model); - -/** - * @brief Create SDK Model instance from memory - * @param[in] buffer a linear buffer contains the model information - * @param[in] size size of \p buffer in bytes - * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_model_create(const void* buffer, int size, mmdeploy_model_t* model); - -/** - * @brief Destroy model instance - * @param[in] model sdk model instance created by \ref mmdeploy_model_create_by_path or \ref - * mmdeploy_model_create - */ -MMDEPLOY_API void mmdeploy_model_destroy(mmdeploy_model_t model); + typedef struct mmdeploy_model* mmdeploy_model_t; + + /** + * @brief Create SDK Model instance from given model path + * @param[in] path model path + * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model); + + /** + * @brief Create SDK Model instance from memory + * @param[in] buffer a linear buffer containing the model information + * @param[in] size size of \p buffer in bytes + * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_model_create(const void* buffer, int size, mmdeploy_model_t* model); + + /** + * @brief Destroy model instance + * @param[in] model sdk model instance created by \ref mmdeploy_model_create_by_path or \ref + * mmdeploy_model_create + */ + MMDEPLOY_API void mmdeploy_model_destroy(mmdeploy_model_t model); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp 
b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp index a9a02807ee..9e0fcf011e 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp @@ -6,73 +6,95 @@ #include "mmdeploy/executor_internal.h" #include "mmdeploy/handle.h" -int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline) { - try { - auto _config = *Cast(config); - if (context) { - if (!_config.contains("context")) { - _config["context"] = Value::Object(); - } - update(_config["context"].object(), Cast(context)->object(), 2); +int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, mmdeploy_pipeline_t* pipeline) +{ + try + { + auto _config = *Cast(config); + if (context) + { + if (!_config.contains("context")) + { + _config["context"] = Value::Object(); + } + update(_config["context"].object(), Cast(context)->object(), 2); + } + auto _handle = std::make_unique<AsyncHandle>(std::move(_config)); + *pipeline = Cast(_handle.release()); + return MMDEPLOY_SUCCESS; } - auto _handle = std::make_unique<AsyncHandle>(std::move(_config)); - *pipeline = Cast(_handle.release()); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline) { - auto config = Cast(model)->ReadConfig("pipeline.json"); - auto _context = *Cast(context); - _context["model"] = *Cast(model); - return mmdeploy_pipeline_create_v3(Cast(&config.value()), (mmdeploy_context_t)&_context, - pipeline); +int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_pipeline_t* pipeline) +{ + auto config = Cast(model)->ReadConfig("pipeline.json"); + auto _context = *Cast(context); + _context["model"] = *Cast(model); + return mmdeploy_pipeline_create_v3(Cast(&config.value()), (mmdeploy_context_t)&_context, pipeline); } -int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - if (!pipeline || !input || !output) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - auto h = Cast(pipeline); - *output = Take(h->Process(Take(input))); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; +int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + if (!pipeline || !input || !output) + { + return MMDEPLOY_E_INVALID_ARG; + } + + try + { + auto h = Cast(pipeline); + *output = Take(h->Process(Take(input))); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + + return MMDEPLOY_E_FAIL; }
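For reference, the synchronous and asynchronous entry points above compose as follows; a minimal sketch assuming an already-created model, context, and packed input value:

#include "mmdeploy/pipeline.h"

/* build a pipeline from a model's pipeline.json and run it once */
int run_pipeline(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_value_t input)
{
    mmdeploy_pipeline_t pipeline = NULL;
    int ec = mmdeploy_pipeline_create_from_model(model, context, &pipeline);
    if (ec != MMDEPLOY_SUCCESS)
    {
        return ec;
    }
    mmdeploy_value_t output = NULL;
    ec = mmdeploy_pipeline_apply(pipeline, input, &output);
    if (ec == MMDEPLOY_SUCCESS)
    {
        mmdeploy_value_destroy(output);
    }
    mmdeploy_pipeline_destroy(pipeline);
    return ec;
}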
-void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline) { - if (pipeline != nullptr) { - delete Cast(pipeline); - } +void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline) +{ + if (pipeline != nullptr) + { + delete Cast(pipeline); + } } -int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, - mmdeploy_value_t* output) { - auto input_sender = mmdeploy_executor_just(input); - if (!input_sender) { - return MMDEPLOY_E_FAIL; - } - mmdeploy_sender_t output_sender{}; - if (auto ec = mmdeploy_pipeline_apply_async(pipeline, input_sender, &output_sender)) { - return ec; - } - auto _output = mmdeploy_executor_sync_wait(output_sender); - if (!_output) { - return MMDEPLOY_E_FAIL; - } - *output = _output; - return MMDEPLOY_SUCCESS; +int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + auto input_sender = mmdeploy_executor_just(input); + if (!input_sender) + { + return MMDEPLOY_E_FAIL; + } + + mmdeploy_sender_t output_sender{}; + if (auto ec = mmdeploy_pipeline_apply_async(pipeline, input_sender, &output_sender)) + { + return ec; + } + + auto _output = mmdeploy_executor_sync_wait(output_sender); + if (!_output) + { + return MMDEPLOY_E_FAIL; + } + + *output = _output; + return MMDEPLOY_SUCCESS; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h index 55ccf1e67c..faf523863f 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h @@ -8,59 +8,59 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -/****************************************************************************** - * Experimental pipeline APIs */ + /****************************************************************************** + * Experimental pipeline APIs */ -typedef struct mmdeploy_pipeline* mmdeploy_pipeline_t; + typedef struct mmdeploy_pipeline* mmdeploy_pipeline_t; -/** - * Create pipeline - * @param config - * @param context - * @param pipeline - * @return - */ -MMDEPLOY_API int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline); -/** - * Create pipeline from internal pipeline config of the model - * @param model - * @param context - * @param pipeline - * @return - */ -MMDEPLOY_API int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline); + /** + * Create pipeline + * @param config pipeline config + * @param context execution context + * @param pipeline handle of the created pipeline + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, mmdeploy_pipeline_t* pipeline); + /** + * Create pipeline from internal pipeline config of the model + * @param model sdk model with a built-in pipeline config + * @param context execution context + * @param pipeline handle of the created pipeline + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, + mmdeploy_context_t context, + mmdeploy_pipeline_t* pipeline); -/** - * @brief Apply pipeline - * @param[in] pipeline handle of the pipeline - * @param[in] input input value - * @param[out] output output value - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, - mmdeploy_value_t* output); + /** + * @brief Apply pipeline + * @param[in] pipeline handle of the pipeline + * @param[in] input input value + * @param[out] output output value + * 
@return status of the operation + */ + MMDEPLOY_API int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, mmdeploy_value_t* output); -/** - * Apply pipeline asynchronously - * @param pipeline handle of the pipeline - * @param input input sender that will be consumed by the operation - * @param output output sender - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, - mmdeploy_sender_t input, mmdeploy_sender_t* output); + /** + * Apply pipeline asynchronously + * @param pipeline handle of the pipeline + * @param input input sender that will be consumed by the operation + * @param output output sender + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); -/** - * @brief destroy pipeline - * @param[in] pipeline - */ -MMDEPLOY_API void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline); + /** + * @brief destroy pipeline + * @param[in] pipeline + */ + MMDEPLOY_API void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp index 46f9921e62..ee0cc0c564 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp @@ -16,164 +16,197 @@ using namespace std; using namespace mmdeploy; -int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_pose_detector_t* detector) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_pose_detector_t* detector) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_pose_detector_create_v2(model, context, detector); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_pose_detector_create_v2(model, context, detector); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_pose_detector_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_pose_detector_t* detector) { - mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { +int mmdeploy_pose_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_pose_detector_t* detector) +{ + mmdeploy_model_t model{}; + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_pose_detector_create(model, device_name, device_id, detector); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_pose_detector_create(model, device_name, device_id, detector); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_pose_detection_t** results) { - return mmdeploy_pose_detector_apply_bbox(detector, mats, mat_count, nullptr, nullptr, results); +int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_pose_detection_t** results) +{ + return mmdeploy_pose_detector_apply_bbox(detector, 
mats, mat_count, nullptr, nullptr, results); } -int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, - int mat_count, const mmdeploy_rect_t* bboxes, - const int* bbox_count, mmdeploy_pose_detection_t** results) { - wrapped<mmdeploy_value_t> input; - if (auto ec = - mmdeploy_pose_detector_create_input(mats, mat_count, bboxes, bbox_count, input.ptr())) { - return ec; - } - wrapped<mmdeploy_value_t> output; - if (auto ec = mmdeploy_pose_detector_apply_v2(detector, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_pose_detector_get_result(output, results)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, const mmdeploy_rect_t* bboxes, const int* bbox_count, mmdeploy_pose_detection_t** results) +{ + wrapped<mmdeploy_value_t> input; + if (auto ec = + mmdeploy_pose_detector_create_input(mats, mat_count, bboxes, bbox_count, input.ptr())) + { + return ec; + } + wrapped<mmdeploy_value_t> output; + if (auto ec = mmdeploy_pose_detector_apply_v2(detector, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_pose_detector_get_result(output, results)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, int count) { - if (results == nullptr) { - return; - } - for (int i = 0; i < count; ++i) { - delete[] results[i].point; - delete[] results[i].score; - } - delete[] results; +void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, int count) +{ + if (results == nullptr) + { + return; + } + for (int i = 0; i < count; ++i) + { + delete[] results[i].point; + delete[] results[i].score; + } + delete[] results; } -void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); +void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); } -int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_pose_detector_t* detector) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); +int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_pose_detector_t* detector) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); }
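The bbox-based entry point above is typically fed boxes from a detector; the sketch below restricts pose estimation to one hand-written box and is illustrative only (the zero-filled frame is a stand-in for a real image):

#include <stdint.h>
#include "mmdeploy/pose_detector.h"

int run_pose(mmdeploy_pose_detector_t detector)
{
    static uint8_t pixels[256 * 256 * 3]; /* dummy BGR frame */
    mmdeploy_mat_t mat = {pixels, 256, 256, 3, MMDEPLOY_PIXEL_FORMAT_BGR, MMDEPLOY_DATA_TYPE_UINT8};
    mmdeploy_rect_t box = {16.f, 16.f, 240.f, 240.f}; /* left, top, right, bottom */
    int box_count = 1; /* one box for the single input image */
    mmdeploy_pose_detection_t* poses = NULL;
    int ec = mmdeploy_pose_detector_apply_bbox(detector, &mat, 1, &box, &box_count, &poses);
    if (ec == MMDEPLOY_SUCCESS)
    {
        /* one result per input box, each holding `length` keypoints */
        mmdeploy_pose_detector_release_result(poses, box_count);
    }
    return ec;
}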
-int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - const mmdeploy_rect_t* bboxes, const int* bbox_count, - mmdeploy_value_t* value) { - if (mat_count && mats == nullptr) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value::Array input_images; - - auto add_bbox = [&](const Mat& img, const mmdeploy_rect_t* bbox) { - Value::Array b; - if (bbox) { - float width = bbox->right - bbox->left + 1; - float height = bbox->bottom - bbox->top + 1; - b = {bbox->left, bbox->top, width, height, 1.0}; - } else { - b = {0, 0, img.width(), img.height(), 1.0}; - } - input_images.push_back({{"ori_img", img}, {"bbox", std::move(b)}}); - }; - - for (int i = 0; i < mat_count; ++i) { - auto _mat = Cast(mats[i]); - if (bboxes && bbox_count) { - for (int j = 0; j < bbox_count[i]; ++j) { - add_bbox(_mat, bboxes++); - } - } else { // inference whole image - add_bbox(_mat, nullptr); - } } +int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, const mmdeploy_rect_t* bboxes, const int* bbox_count, mmdeploy_value_t* value) +{ + if (mat_count && mats == nullptr) + { + return MMDEPLOY_E_INVALID_ARG; + } + try + { + Value::Array input_images; + + auto add_bbox = [&](const Mat& img, const mmdeploy_rect_t* bbox) + { + Value::Array b; + if (bbox) + { + float width = bbox->right - bbox->left + 1; + float height = bbox->bottom - bbox->top + 1; + b = {bbox->left, bbox->top, width, height, 1.0}; + } + else + { + b = {0, 0, img.width(), img.height(), 1.0}; + } + input_images.push_back({{"ori_img", img}, {"bbox", std::move(b)}}); + }; + + for (int i = 0; i < mat_count; ++i) + { + auto _mat = Cast(mats[i]); + if (bboxes && bbox_count) + { + for (int j = 0; j < bbox_count[i]; ++j) + { + add_bbox(_mat, bboxes++); + } + } + else + { // inference whole image + add_bbox(_mat, nullptr); + } + } - *value = Take(Value{std::move(input_images)}); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + *value = Take(Value{std::move(input_images)}); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, - mmdeploy_pose_detection_t** results) { - if (!output || !results) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - std::vector<mmpose::PoseDetectorOutput> detections; - from_value(Cast(output)->front(), detections); - - size_t count = detections.size(); - - auto deleter = [&](mmdeploy_pose_detection_t* p) { - mmdeploy_pose_detector_release_result(p, static_cast<int>(count)); - }; - - std::unique_ptr<mmdeploy_pose_detection_t[], decltype(deleter)> _results( - new mmdeploy_pose_detection_t[count]{}, deleter); - - size_t result_idx = 0; - for (const auto& bbox_result : detections) { - auto& res = _results[result_idx++]; - auto size = bbox_result.key_points.size(); - - res.point = new mmdeploy_point_t[size]; - res.score = new float[size]; - res.length = static_cast<int>(size); - - for (int k = 0; k < size; k++) { - res.point[k].x = bbox_result.key_points[k].bbox[0]; - res.point[k].y = bbox_result.key_points[k].bbox[1]; - res.score[k] = bbox_result.key_points[k].score; - } } +int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, + mmdeploy_pose_detection_t** results) +{ + if (!output || !results) + { + return MMDEPLOY_E_INVALID_ARG; + } + try + { + std::vector<mmpose::PoseDetectorOutput> detections; + from_value(Cast(output)->front(), detections); + + size_t count = detections.size(); + + auto deleter = [&](mmdeploy_pose_detection_t* p) + { + mmdeploy_pose_detector_release_result(p, static_cast<int>(count)); + }; + + std::unique_ptr<mmdeploy_pose_detection_t[], decltype(deleter)> _results( + new 
mmdeploy_pose_detection_t[count]{}, + deleter); + + size_t result_idx = 0; + for (const auto& bbox_result : detections) + { + auto& res = _results[result_idx++]; + auto size = bbox_result.key_points.size(); + + res.point = new mmdeploy_point_t[size]; + res.score = new float[size]; + res.length = static_cast<int>(size); + + for (int k = 0; k < size; k++) + { + res.point[k].x = bbox_result.key_points[k].bbox[0]; + res.point[k].y = bbox_result.key_points[k].bbox[1]; + res.score[k] = bbox_result.key_points[k].score; + } + } - *results = _results.release(); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + *results = _results.release(); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h index ff0987cee4..6fceb99f72 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h @@ -13,111 +13,113 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_pose_detection_t { - mmdeploy_point_t* point; ///< keypoint - float* score; ///< keypoint score - int length; ///< number of keypoint -} mmdeploy_pose_detection_t; - -typedef struct mmdeploy_pose_detector* mmdeploy_pose_detector_t; - -/** - * @brief Create a pose detector instance - * @param[in] model an instance of mmpose model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. 
- * @param[out] detector handle of the created pose detector, which must be destroyed - * by \ref mmdeploy_pose_detector_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_pose_detector_t* detector); - -/** - * @brief Create a pose detector instance - * @param[in] model_path path to pose detection model - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector handle of the created pose detector, which must be destroyed - * by \ref mmdeploy_pose_detector_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_create_by_path(const char* model_path, - const char* device_name, int device_id, - mmdeploy_pose_detector_t* detector); - -/** - * @brief Apply pose detector to a batch of images with full image roi - * @param[in] detector pose detector's handle created by \ref - * mmdeploy_pose_detector_create_by_path - * @param[in] images a batch of images - * @param[in] count number of images in the batch - * @param[out] results a linear buffer contains the pose result, must be release - * by \ref mmdeploy_pose_detector_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_pose_detection_t** results); - -/** - * @brief Apply pose detector to a batch of images supplied with bboxes(roi) - * @param[in] detector pose detector's handle created by \ref - * mmdeploy_pose_detector_create_by_path - * @param[in] images a batch of images - * @param[in] image_count number of images in the batch - * @param[in] bboxes bounding boxes(roi) detected by mmdet - * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images - * @param[out] results a linear buffer contains the pose result, which has the same length as \p - * bboxes, must be release by \ref mmdeploy_pose_detector_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector, - const mmdeploy_mat_t* mats, int mat_count, - const mmdeploy_rect_t* bboxes, - const int* bbox_count, - mmdeploy_pose_detection_t** results); - -/** @brief Release result buffer returned by \ref mmdeploy_pose_detector_apply or \ref - * mmdeploy_pose_detector_apply_bbox - * @param[in] results result buffer by pose detector - * @param[in] count length of \p result - */ -MMDEPLOY_API void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, - int count); - -/** - * @brief destroy pose_detector - * @param[in] detector handle of pose_detector created by \ref - * mmdeploy_pose_detector_create_by_path or \ref mmdeploy_pose_detector_create - */ -MMDEPLOY_API void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -MMDEPLOY_API int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_pose_detector_t* detector); - -MMDEPLOY_API int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - const mmdeploy_rect_t* bboxes, - const int* bbox_count, - mmdeploy_value_t* value); - -MMDEPLOY_API int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, - mmdeploy_value_t input, mmdeploy_value_t* output); - -MMDEPLOY_API int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, - mmdeploy_pose_detection_t** results); + typedef struct mmdeploy_pose_detection_t + { + mmdeploy_point_t* point; ///< keypoint + float* score; ///< keypoint score + int length; ///< number of keypoints + } mmdeploy_pose_detection_t; + + typedef struct mmdeploy_pose_detector* 
mmdeploy_pose_detector_t; + + /** + * @brief Create a pose detector instance + * @param[in] model an instance of mmpose model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] detector handle of the created pose detector, which must be destroyed + * by \ref mmdeploy_pose_detector_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_pose_detector_t* detector); + + /** + * @brief Create a pose detector instance + * @param[in] model_path path to pose detection model + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] detector handle of the created pose detector, which must be destroyed + * by \ref mmdeploy_pose_detector_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_pose_detector_create_by_path(const char* model_path, + const char* device_name, + int device_id, + mmdeploy_pose_detector_t* detector); + + /** + * @brief Apply pose detector to a batch of images with full image roi + * @param[in] detector pose detector's handle created by \ref + * mmdeploy_pose_detector_create_by_path + * @param[in] images a batch of images + * @param[in] count number of images in the batch + * @param[out] results a linear buffer contains the pose result, must be release + * by \ref mmdeploy_pose_detector_release_result + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector, + const mmdeploy_mat_t* mats, + int mat_count, + mmdeploy_pose_detection_t** results); + + /** + * @brief Apply pose detector to a batch of images supplied with bboxes(roi) + * @param[in] detector pose detector's handle created by \ref + * mmdeploy_pose_detector_create_by_path + * @param[in] images a batch of images + * @param[in] image_count number of images in the batch + * @param[in] bboxes bounding boxes(roi) detected by mmdet + * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images + * @param[out] results a linear buffer contains the pose result, which has the same length as \p + * bboxes, must be release by \ref mmdeploy_pose_detector_release_result + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector, + const mmdeploy_mat_t* mats, + int mat_count, + const mmdeploy_rect_t* bboxes, + const int* bbox_count, + mmdeploy_pose_detection_t** results); + + /** @brief Release result buffer returned by \ref mmdeploy_pose_detector_apply or \ref + * mmdeploy_pose_detector_apply_bbox + * @param[in] results result buffer by pose detector + * @param[in] count length of \p result + */ + MMDEPLOY_API void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, + int count); + + /** + * @brief destroy pose_detector + * @param[in] detector handle of pose_detector created by \ref + * mmdeploy_pose_detector_create_by_path or \ref mmdeploy_pose_detector_create + */ + MMDEPLOY_API void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + MMDEPLOY_API int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, + 
mmdeploy_context_t context, + mmdeploy_pose_detector_t* detector); + + MMDEPLOY_API int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, const mmdeploy_rect_t* bboxes, const int* bbox_count, mmdeploy_value_t* value); + + MMDEPLOY_API int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + MMDEPLOY_API int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, + mmdeploy_pose_detection_t** results); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp index 113b520c39..d2587b1949 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp @@ -9,18 +9,21 @@ #include "mmdeploy/core/mpl/structure.h" #include "mmdeploy/pipeline.h" -namespace mmdeploy { +namespace mmdeploy +{ -using namespace framework; + using namespace framework; } // namespace mmdeploy using namespace mmdeploy; -namespace { +namespace +{ -Value config_template() { - static const auto json = R"( + Value config_template() + { + static const auto json = R"( { "type": "Pipeline", "input": ["img", "force_det", "state"], @@ -77,149 +80,184 @@ Value config_template() { ] } )"_json; - static const auto config = from_json(json); - return config; -} + static const auto config = from_json(json); + return config; + } } // namespace -int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params) { - mmpose::_pose_tracker::SetDefaultParams(*params); - return 0; +int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params) +{ + mmpose::_pose_tracker::SetDefaultParams(*params); + return 0; } -int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, mmdeploy_model_t pose_model, - mmdeploy_context_t context, mmdeploy_pose_tracker_t* pipeline) { - mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "detection", det_model); - mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "pose", pose_model); - auto config = config_template(); - return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)pipeline); +int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, mmdeploy_model_t pose_model, mmdeploy_context_t context, mmdeploy_pose_tracker_t* pipeline) +{ + mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "detection", det_model); + mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "pose", pose_model); + auto config = config_template(); + return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)pipeline); } -void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)pipeline); +void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)pipeline); } -int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, +int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, const mmdeploy_pose_tracker_param_t* params, - mmdeploy_pose_tracker_state_t* state) { - try { - auto create_fn = gRegistry<Module>().Create("pose_tracker::Create", Value()).value(); - *state = reinterpret_cast<mmdeploy_pose_tracker_state_t>(new Value( - create_fn->Process({const_cast<mmdeploy_pose_tracker_param_t*>(params)}).value()[0])); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + mmdeploy_pose_tracker_state_t* state) +{ + try + { + auto create_fn = gRegistry<Module>().Create("pose_tracker::Create", Value()).value(); + *state = reinterpret_cast<mmdeploy_pose_tracker_state_t>(new Value( + create_fn->Process({const_cast<mmdeploy_pose_tracker_param_t*>(params)}).value()[0])); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state) { - delete reinterpret_cast<Value*>(state); +void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state) +{ + delete reinterpret_cast<Value*>(state); } int mmdeploy_pose_tracker_create_input(mmdeploy_pose_tracker_state_t* states, - const mmdeploy_mat_t* frames, const int32_t* use_detect, - int batch_size, mmdeploy_value_t* value) { - try { - Value::Array images; - Value::Array use_dets; - Value::Array trackers; - for (int i = 0; i < batch_size; ++i) { - images.push_back({{"ori_img", Cast(frames[i])}}); - use_dets.emplace_back(use_detect ? use_detect[i] : -1); - trackers.push_back(*reinterpret_cast<Value*>(states[i])); + const mmdeploy_mat_t* frames, + const int32_t* use_detect, + int batch_size, + mmdeploy_value_t* value) +{ + try + { + Value::Array images; + Value::Array use_dets; + Value::Array trackers; + for (int i = 0; i < batch_size; ++i) + { + images.push_back({{"ori_img", Cast(frames[i])}}); + use_dets.emplace_back(use_detect ? use_detect[i] : -1); + trackers.push_back(*reinterpret_cast<Value*>(states[i])); + } + *value = Take(Value{std::move(images), std::move(use_dets), std::move(trackers)}); + return MMDEPLOY_SUCCESS; } - *value = Take(Value{std::move(images), std::move(use_dets), std::move(trackers)}); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -using ResultType = mmdeploy::Structure<mmdeploy_pose_tracker_target_t, std::vector<int32_t>, - std::vector<mmpose::_pose_tracker::TrackerResult>>; +using ResultType = mmdeploy::Structure<mmdeploy_pose_tracker_target_t, std::vector<int32_t>, std::vector<mmpose::_pose_tracker::TrackerResult>>; -int mmdeploy_pose_tracker_get_result(mmdeploy_value_t output, +int mmdeploy_pose_tracker_get_result(mmdeploy_value_t output, mmdeploy_pose_tracker_target_t** results, - int32_t** result_count) { - if (!output || !results) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - // convert result from Values - std::vector<mmpose::_pose_tracker::TrackerResult> res; - from_value(Cast(output)->front(), res); - - size_t total = 0; - for (const auto& r : res) { - total += r.bboxes.size(); + int32_t** result_count) +{ + if (!output || !results) + { + return MMDEPLOY_E_INVALID_ARG; } + try + { + // convert result from Values + std::vector<mmpose::_pose_tracker::TrackerResult> res; + from_value(Cast(output)->front(), res); - // preserve space for the output structure - ResultType result_type({total, 1, 1}); - auto [result_data, result_cnt, result_holder] = result_type.pointers(); + size_t total = 0; + for (const auto& r : res) + { + total += r.bboxes.size(); + } - auto result_ptr = result_data; + // preserve space for the output structure + ResultType result_type({total, 1, 1}); + auto [result_data, result_cnt, result_holder] = result_type.pointers(); - result_holder->swap(res); + auto result_ptr = result_data; - // build output structure - for (auto& r : *result_holder) { - for (int j = 0; j < r.bboxes.size(); ++j) { - auto& p = *result_ptr++; - p.keypoint_count = static_cast<int32_t>(r.keypoints[j].size()); - p.keypoints = r.keypoints[j].data(); - p.scores = r.scores[j].data(); - p.bbox = r.bboxes[j]; - p.target_id = r.track_ids[j]; - } - result_cnt->push_back(r.bboxes.size()); - // debug info - // p.reserved0 = new std::vector(r.pose_input_bboxes); - // p.reserved1 = new std::vector(r.pose_output_bboxes); - } + result_holder->swap(res); - *results = result_data; - *result_count = result_cnt->data(); - result_type.release(); + // build output structure + for (auto& r : *result_holder) + { + for (int j = 0; j < r.bboxes.size(); ++j) + { + auto& p = *result_ptr++; + p.keypoint_count = static_cast<int32_t>(r.keypoints[j].size()); + p.keypoints = r.keypoints[j].data(); + p.scores = r.scores[j].data(); + p.bbox = r.bboxes[j]; + p.target_id = r.track_ids[j]; + } + result_cnt->push_back(r.bboxes.size()); + // debug info + // p.reserved0 = new std::vector(r.pose_input_bboxes); + // p.reserved1 = new std::vector(r.pose_output_bboxes); + } - return MMDEPLOY_SUCCESS; + *results = result_data; + *result_count = result_cnt->data(); + result_type.release(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline, - mmdeploy_pose_tracker_state_t* states, const mmdeploy_mat_t* frames, - const int32_t* use_detect, int32_t count, - mmdeploy_pose_tracker_target_t** results, int32_t** result_count) { - wrapped<mmdeploy_value_t> input; - if (auto ec = - mmdeploy_pose_tracker_create_input(states, frames, use_detect, count, input.ptr())) { - return ec; - } - wrapped<mmdeploy_value_t> output; - if (auto ec = mmdeploy_pipeline_apply((mmdeploy_pipeline_t)pipeline, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_pose_tracker_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline, + mmdeploy_pose_tracker_state_t* states, + const mmdeploy_mat_t* frames, + const int32_t* use_detect, + int32_t count, + mmdeploy_pose_tracker_target_t** results, + int32_t** result_count) +{ + wrapped<mmdeploy_value_t> input; + if (auto ec = + mmdeploy_pose_tracker_create_input(states, frames, use_detect, count, input.ptr())) + { + return ec; + } + wrapped<mmdeploy_value_t> output; + if (auto ec = mmdeploy_pipeline_apply((mmdeploy_pipeline_t)pipeline, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_pose_tracker_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } void mmdeploy_pose_tracker_release_result(mmdeploy_pose_tracker_target_t* results, - const int32_t* result_count, int count) { - auto total = std::accumulate(result_count, result_count + count, 0); - ResultType deleter({static_cast<size_t>(total), 1, 1}, results); + const int32_t* result_count, + int count) +{ + auto total = std::accumulate(result_count, result_count + count, 0); + ResultType deleter({static_cast<size_t>(total), 1, 1}, results); }
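For reference, the pose tracker C API reformatted above (and declared in the header that follows) is driven in a fixed order: fill defaults, create the pipeline, create one state per video stream, then apply per frame and release the results. A minimal sketch follows; it is illustration only, not part of the patch, and it assumes a valid context, both model handles, and a single frame already wrapped in an mmdeploy_mat_t (cleanup on early-error paths is elided):

#include "mmdeploy/pose_tracker.h"

int track_one_frame(mmdeploy_model_t det_model, mmdeploy_model_t pose_model, mmdeploy_context_t context, const mmdeploy_mat_t* frame)
{
    mmdeploy_pose_tracker_t tracker{};
    if (mmdeploy_pose_tracker_create(det_model, pose_model, context, &tracker)) return -1;

    mmdeploy_pose_tracker_param_t params{};
    mmdeploy_pose_tracker_default_params(&params);  // fill defaults first, then override selectively
    params.det_interval = 5;                        // hypothetical choice: run the detector every 5th frame

    mmdeploy_pose_tracker_state_t state{};          // one state per video stream
    if (mmdeploy_pose_tracker_create_state(tracker, &params, &state)) return -1;

    int32_t use_detect = -1;  // -1: let det_interval decide, per the header docs
    mmdeploy_pose_tracker_target_t* targets{};
    int32_t* target_count{};
    // batch of one frame, so all arrays have size 1
    if (mmdeploy_pose_tracker_apply(tracker, &state, frame, &use_detect, 1, &targets, &target_count)) return -1;

    // consume targets[0 .. target_count[0]): keypoints, scores, bbox, target_id

    mmdeploy_pose_tracker_release_result(targets, target_count, 1);
    mmdeploy_pose_tracker_destroy_state(state);
    mmdeploy_pose_tracker_destroy(tracker);
    return 0;
}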
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h index 4b27fbab8a..c8191b40fa 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h @@ -14,142 +14,147 @@ #include "mmdeploy/pose_detector.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_pose_tracker* mmdeploy_pose_tracker_t; -typedef struct mmdeploy_pose_tracker_state* mmdeploy_pose_tracker_state_t; - -typedef struct mmdeploy_pose_tracker_param_t { - // detection interval, default = 1 - int32_t det_interval; - // detection label use for pose estimation, default = 0 - int32_t det_label; - // detection score threshold, default = 0.5 - float det_thr; - // detection minimum bbox size (compute as sqrt(area)), default = -1 - float det_min_bbox_size; - // nms iou threshold for merging detected bboxes and bboxes from tracked targets, default = 0.7 - float det_nms_thr; - - // max number of bboxes used for pose estimation per frame, default = -1 - int32_t pose_max_num_bboxes; - // threshold for visible key-points, default = 0.5 - float pose_kpt_thr; - // min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1 - int32_t pose_min_keypoints; - // scale for expanding key-points to bbox, default = 1.25 - float pose_bbox_scale; - // min pose bbox size, tracks with bbox size smaller than the threshold will be dropped, - // default = -1 - float pose_min_bbox_size; - // nms oks/iou threshold for suppressing overlapped poses, useful when multiple pose estimations - // collapse to the same target, default = 0.5 - float pose_nms_thr; - // keypoint sigmas for computing OKS, will use IOU if not set, default = nullptr - float* keypoint_sigmas; - // size of keypoint sigma array, must be consistent with the number of key-points, default = 0 - int32_t keypoint_sigmas_size; - - // iou threshold for associating missing tracks, default = 0.4 - float track_iou_thr; - // max number of missing frames before a missing tracks is removed, default = 10 - int32_t track_max_missing; - // track history size, default = 1 - int32_t track_history_size; - - // weight of position for setting covariance matrices of kalman filters, default = 0.05 - float std_weight_position; - // weight of velocity for setting covariance matrices of kalman filters, default = 0.00625 - float std_weight_velocity; - - // params for the one-euro filter for smoothing the outputs - (beta, fc_min, fc_derivative) - // default = (0.007, 1, 1) - float smooth_params[3]; -} mmdeploy_pose_tracker_param_t; - -typedef struct mmdeploy_pose_tracker_target_t { - mmdeploy_point_t* keypoints; // key-points of the target - int32_t keypoint_count; // size of `keypoints` array - float* scores; // scores of each key-point - mmdeploy_rect_t bbox; // estimated bbox from key-points - uint32_t target_id; // target id from internal tracker -} mmdeploy_pose_tracker_target_t; - -/** - * @brief Fill params with default parameters - * @param[in,out] params - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params); - -/** - * @brief Create pose tracker pipeline - * @param[in] det_model detection model object, created by \ref mmdeploy_model_create - * @param[in] pose_model pose model object - * @param[in] context context object describing execution environment (device, profiler, etc...), - * created by \ref mmdeploy_context_create - * @param[out] pipeline handle of the created pipeline - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, - mmdeploy_model_t pose_model, - mmdeploy_context_t context, - mmdeploy_pose_tracker_t* pipeline); - -/** - * @brief Destroy pose tracker pipeline - * @param[in] pipeline - */ -MMDEPLOY_API void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline); - -/** - * @brief Create a tracker state handle corresponds to a video stream - * @param[in] pipeline handle of a pose tracker pipeline - * @param[in] params params for creating the tracker state - * @param[out] state handle of the created tracker state - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, - const mmdeploy_pose_tracker_param_t* params, - mmdeploy_pose_tracker_state_t* state); - -/** - * @brief Destroy tracker state - * @param[in] state handle of the tracker state - */ -MMDEPLOY_API void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state); - -/** - * @brief Apply pose tracker pipeline, notice that this function supports batch operation by feeding - * arrays of size \p count to \p states, \p frames and \p use_detect - * @param[in] pipeline handle of a pose tracker pipeline - * @param[in] states tracker states handles, array of size \p count - * @param[in] frames input frames of size \p count - * @param[in] use_detect control the use of detector, array of size \p count - * -1: use params.det_interval, 0: don't use detector, 1: force use detector - * @param[in] count batch size - * @param[out] results a linear buffer contains the tracked targets of input frames.
Should be - * released by \ref mmdeploy_pose_tracker_release_result - * @param[out] result_count a linear buffer of size \p count contains the number of tracked - * targets of the frames. Should be released by \ref mmdeploy_pose_tracker_release_result - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline, - mmdeploy_pose_tracker_state_t* states, - const mmdeploy_mat_t* frames, - const int32_t* use_detect, int32_t count, - mmdeploy_pose_tracker_target_t** results, - int32_t** result_count); - -/** - * @brief Release result objects - * @param[in] results - * @param[in] result_count - * @param[in] count - */ -MMDEPLOY_API void mmdeploy_pose_tracker_release_result(mmdeploy_pose_tracker_target_t* results, - const int32_t* result_count, int count); + typedef struct mmdeploy_pose_tracker* mmdeploy_pose_tracker_t; + typedef struct mmdeploy_pose_tracker_state* mmdeploy_pose_tracker_state_t; + + typedef struct mmdeploy_pose_tracker_param_t + { + // detection interval, default = 1 + int32_t det_interval; + // detection label use for pose estimation, default = 0 + int32_t det_label; + // detection score threshold, default = 0.5 + float det_thr; + // detection minimum bbox size (compute as sqrt(area)), default = -1 + float det_min_bbox_size; + // nms iou threshold for merging detected bboxes and bboxes from tracked targets, default = 0.7 + float det_nms_thr; + + // max number of bboxes used for pose estimation per frame, default = -1 + int32_t pose_max_num_bboxes; + // threshold for visible key-points, default = 0.5 + float pose_kpt_thr; + // min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1 + int32_t pose_min_keypoints; + // scale for expanding key-points to bbox, default = 1.25 + float pose_bbox_scale; + // min pose bbox size, tracks with bbox size smaller than the threshold will be dropped, + // default = -1 + float pose_min_bbox_size; + // nms oks/iou threshold for suppressing overlapped poses, useful when multiple pose estimations + // collapse to the same target, default = 0.5 + float pose_nms_thr; + // keypoint sigmas for computing OKS, will use IOU if not set, default = nullptr + float* keypoint_sigmas; + // size of keypoint sigma array, must be consistent with the number of key-points, default = 0 + int32_t keypoint_sigmas_size; + + // iou threshold for associating missing tracks, default = 0.4 + float track_iou_thr; + // max number of missing frames before a missing tracks is removed, default = 10 + int32_t track_max_missing; + // track history size, default = 1 + int32_t track_history_size; + + // weight of position for setting covariance matrices of kalman filters, default = 0.05 + float std_weight_position; + // weight of velocity for setting covariance matrices of kalman filters, default = 0.00625 + float std_weight_velocity; + + // params for the one-euro filter for smoothing the outputs - (beta, fc_min, fc_derivative) + // default = (0.007, 1, 1) + float smooth_params[3]; + } mmdeploy_pose_tracker_param_t; + + typedef struct mmdeploy_pose_tracker_target_t + { + mmdeploy_point_t* keypoints; // key-points of the target + int32_t keypoint_count; // size of `keypoints` array + float* scores; // scores of each key-point + mmdeploy_rect_t bbox; // estimated bbox from key-points + uint32_t target_id; // target id from internal tracker + } mmdeploy_pose_tracker_target_t; + + /** + * @brief Fill params with default parameters + * @param[in,out] params + * @return status of the operation + */ + 
MMDEPLOY_API int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params); + + /** + * @brief Create pose tracker pipeline + * @param[in] det_model detection model object, created by \ref mmdeploy_model_create + * @param[in] pose_model pose model object + * @param[in] context context object describing execution environment (device, profiler, etc...), + * created by \ref mmdeploy_context_create + * @param[out] pipeline handle of the created pipeline + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, + mmdeploy_model_t pose_model, + mmdeploy_context_t context, + mmdeploy_pose_tracker_t* pipeline); + + /** + * @brief Destroy pose tracker pipeline + * @param[in] pipeline + */ + MMDEPLOY_API void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline); + + /** + * @brief Create a tracker state handle corresponds to a video stream + * @param[in] pipeline handle of a pose tracker pipeline + * @param[in] params params for creating the tracker state + * @param[out] state handle of the created tracker state + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, + const mmdeploy_pose_tracker_param_t* params, + mmdeploy_pose_tracker_state_t* state); + + /** + * @brief Destroy tracker state + * @param[in] state handle of the tracker state + */ + MMDEPLOY_API void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state); + + /** + * @brief Apply pose tracker pipeline, notice that this function supports batch operation by feeding + * arrays of size \p count to \p states, \p frames and \p use_detect + * @param[in] pipeline handle of a pose tracker pipeline + * @param[in] states tracker states handles, array of size \p count + * @param[in] frames input frames of size \p count + * @param[in] use_detect control the use of detector, array of size \p count + * -1: use params.det_interval, 0: don't use detector, 1: force use detector + * @param[in] count batch size + * @param[out] results a linear buffer contains the tracked targets of input frames. Should be + * released by \ref mmdeploy_pose_tracker_release_result + * @param[out] result_count a linear buffer of size \p count contains the number of tracked + * targets of the frames. 
+ Should be released by \ref mmdeploy_pose_tracker_release_result + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline, + mmdeploy_pose_tracker_state_t* states, + const mmdeploy_mat_t* frames, + const int32_t* use_detect, + int32_t count, + mmdeploy_pose_tracker_target_t** results, + int32_t** result_count); + + /** + * @brief Release result objects + * @param[in] results + * @param[in] result_count + * @param[in] count + */ + MMDEPLOY_API void mmdeploy_pose_tracker_release_result(mmdeploy_pose_tracker_target_t* results, + const int32_t* result_count, + int count); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp b/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp index 9ca2ca65f7..49f8487d12 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp @@ -16,106 +16,121 @@ using namespace mmdeploy; using ResultType = mmdeploy::Structure<mmdeploy_mat_t, framework::Buffer>; -int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_restorer_t* restorer) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_restorer_t* restorer) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_restorer_create_v2(model, context, restorer); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_restorer_create_v2(model, context, restorer); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name, int device_id, - mmdeploy_restorer_t* restorer) { - mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { +int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_restorer_t* restorer) +{ + mmdeploy_model_t model{}; + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_restorer_create(model, device_name, device_id, restorer); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_restorer_create(model, device_name, device_id, restorer); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images, int count, - mmdeploy_mat_t** results) { - wrapped<mmdeploy_value_t> input; - if (auto ec = mmdeploy_restorer_create_input(images, count, input.ptr())) { - return ec; - } - wrapped<mmdeploy_value_t> output; - if (auto ec = mmdeploy_restorer_apply_v2(restorer, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_restorer_get_result(output, results)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images, int count, mmdeploy_mat_t** results) +{ + wrapped<mmdeploy_value_t> input; + if (auto ec = mmdeploy_restorer_create_input(images, count, input.ptr())) + { + return ec; + } + wrapped<mmdeploy_value_t> output; + if (auto ec = mmdeploy_restorer_apply_v2(restorer, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_restorer_get_result(output, results)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count) { - ResultType deleter{static_cast<size_t>(count), results}; +void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count) +{ + ResultType deleter{static_cast<size_t>(count), results}; } -void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)restorer); +void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)restorer); } -int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_restorer_t* restorer) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)restorer); +int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_restorer_t* restorer) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)restorer); } -int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value) { - return mmdeploy_common_create_input(mats, mat_count, value); +int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value) +{ + return mmdeploy_common_create_input(mats, mat_count, value); } -int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)restorer, input, output); +int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)restorer, input, output); } -int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)restorer, input, output); +int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)restorer, input, output); } -int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results) { - if (!output || !results) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - const Value& value = Cast(output)->front(); - - auto restorer_output = from_value<std::vector<framework::Mat>>(value); - auto count = restorer_output.size(); - - ResultType r(count); - auto [_results, buffers] = r.pointers(); - - for (int i = 0; i < count; ++i) { - auto upscale = restorer_output[i]; - auto& res = _results[i]; - res.data = upscale.data(); - buffers[i] = upscale.buffer(); - res.format = (mmdeploy_pixel_format_t)upscale.pixel_format(); - res.height = upscale.height(); - res.width = upscale.width(); - res.channel = upscale.channel(); - res.type = (mmdeploy_data_type_t)upscale.type(); +int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results) +{ + if (!output || !results) + { + return MMDEPLOY_E_INVALID_ARG; } - - *results = _results; - r.release(); - - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + try + { + const Value& value = Cast(output)->front(); + + auto restorer_output = from_value<std::vector<framework::Mat>>(value); + auto count = restorer_output.size(); + + ResultType r(count); + auto [_results, buffers] = r.pointers(); + + for (int i = 0; i < count; ++i) + { + auto upscale = restorer_output[i]; + auto& res = _results[i]; + res.data = upscale.data(); + buffers[i] = upscale.buffer(); + res.format = (mmdeploy_pixel_format_t)upscale.pixel_format(); + res.height = upscale.height(); + res.width = upscale.width(); + res.channel = upscale.channel(); + res.type = (mmdeploy_data_type_t)upscale.type(); + } + + *results = _results; + r.release(); + + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; }
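The restorer's synchronous path follows the same create/apply/release/destroy shape as the other tasks in this patch. A short sketch, again illustration only and not part of the patch; the model path and device name are placeholders:

#include "mmdeploy/restorer.h"

int restore_batch(const mmdeploy_mat_t* images, int count)
{
    mmdeploy_restorer_t restorer{};
    // "../sr_model" and "cpu" are placeholder arguments
    if (mmdeploy_restorer_create_by_path("../sr_model", "cpu", 0, &restorer)) return -1;

    mmdeploy_mat_t* restored{};
    if (mmdeploy_restorer_apply(restorer, images, count, &restored))
    {
        mmdeploy_restorer_destroy(restorer);
        return -1;
    }

    // restored[i] describes the i-th output image (data, width, height, channel, format, type)

    mmdeploy_restorer_release_result(restored, count);
    mmdeploy_restorer_destroy(restorer);
    return 0;
}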
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/restorer.h b/csrc/mmdeploy/apis/c/mmdeploy/restorer.h index 9ab529850f..5c8533102f 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/restorer.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/restorer.h @@ -13,76 +13,72 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_restorer* mmdeploy_restorer_t; - -/** - * @brief Create a restorer instance - * @param[in] model an instance of image restoration model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] restorer handle of the created restorer, which must be destroyed - * by \ref mmdeploy_restorer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_restorer_t* restorer); - -/** - * @brief Create a restorer instance - * @param[in] model_path path to image restoration model - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] restorer handle of the created restorer, which must be destroyed - * by \ref mmdeploy_restorer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_restorer_t* restorer); - -/** - * @brief Apply restorer to a batch of images - * @param[in] restorer restorer's handle created by \ref mmdeploy_restorer_create_by_path - * @param[in] images a batch of images - * @param[in] count number of images in the batch - * @param[out] results a linear buffer contains the restored images, must be release - * by \ref mmdeploy_restorer_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images, - int count, mmdeploy_mat_t** results); - -/** @brief Release result buffer returned by \ref mmdeploy_restorer_apply - * @param[in] results result buffer by restorer - * @param[in] count length of \p result - */ -MMDEPLOY_API void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count); - -/** - * @brief destroy restorer - * @param[in] restorer handle of restorer created by \ref mmdeploy_restorer_create_by_path - */ -MMDEPLOY_API void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -MMDEPLOY_API int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_restorer_t* restorer); - -MMDEPLOY_API int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value); - -MMDEPLOY_API int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input, - mmdeploy_value_t* output); - -MMDEPLOY_API int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer, - mmdeploy_sender_t input, mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results); + typedef struct mmdeploy_restorer* mmdeploy_restorer_t; + + /** + * @brief Create a restorer instance + * @param[in] model an instance of image restoration model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] restorer handle of the created restorer, which must be destroyed + * by \ref mmdeploy_restorer_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_restorer_t* restorer); + + /** + * @brief Create a restorer instance + * @param[in] model_path path to image restoration model + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device.
+ * @param[out] restorer handle of the created restorer, which must be destroyed + * by \ref mmdeploy_restorer_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_restorer_t* restorer); + + /** + * @brief Apply restorer to a batch of images + * @param[in] restorer restorer's handle created by \ref mmdeploy_restorer_create_by_path + * @param[in] images a batch of images + * @param[in] count number of images in the batch + * @param[out] results a linear buffer contains the restored images, must be release + * by \ref mmdeploy_restorer_release_result + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images, int count, mmdeploy_mat_t** results); + + /** @brief Release result buffer returned by \ref mmdeploy_restorer_apply + * @param[in] results result buffer by restorer + * @param[in] count length of \p result + */ + MMDEPLOY_API void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count); + + /** + * @brief destroy restorer + * @param[in] restorer handle of restorer created by \ref mmdeploy_restorer_create_by_path + */ + MMDEPLOY_API void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + MMDEPLOY_API int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_restorer_t* restorer); + + MMDEPLOY_API int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value); + + MMDEPLOY_API int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input, mmdeploy_value_t* output); + + MMDEPLOY_API int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp index d2172c54b8..04d537a376 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp @@ -15,124 +15,146 @@ using namespace std; using namespace mmdeploy; -int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_rotated_detector_t* detector) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_rotated_detector_t* detector) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_rotated_detector_create_v2(model, context, detector); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_rotated_detector_create_v2(model, context, detector); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_rotated_detector_t* detector) { - mmdeploy_model_t model{}; +int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* 
device_name, int device_id, mmdeploy_rotated_detector_t* detector) +{ + mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_rotated_detector_create(model, device_name, device_id, detector); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_rotated_detector_create(model, device_name, device_id, detector); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_rotated_detection_t** results, int** result_count) { - wrapped<mmdeploy_value_t> input; - if (auto ec = mmdeploy_rotated_detector_create_input(mats, mat_count, input.ptr())) { - return ec; - } - wrapped<mmdeploy_value_t> output; - if (auto ec = mmdeploy_rotated_detector_apply_v2(detector, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_rotated_detector_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector, + const mmdeploy_mat_t* mats, + int mat_count, + mmdeploy_rotated_detection_t** results, + int** result_count) +{ + wrapped<mmdeploy_value_t> input; + if (auto ec = mmdeploy_rotated_detector_create_input(mats, mat_count, input.ptr())) + { + return ec; + } + wrapped<mmdeploy_value_t> output; + if (auto ec = mmdeploy_rotated_detector_apply_v2(detector, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_rotated_detector_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } void mmdeploy_rotated_detector_release_result(mmdeploy_rotated_detection_t* results, - const int* result_count) { - delete[] results; - delete[] result_count; + const int* result_count) +{ + delete[] results; + delete[] result_count; } -void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); +void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); } -int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_rotated_detector_t* detector) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); +int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_rotated_detector_t* detector) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); } -int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input) { - return mmdeploy_common_create_input(mats, mat_count, input); +int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input) +{ + return mmdeploy_common_create_input(mats, mat_count, input); } -int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); } int mmdeploy_rotated_detector_apply_async(mmdeploy_rotated_detector_t detector, - mmdeploy_sender_t input, mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); + mmdeploy_sender_t input, + mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output, +int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output, mmdeploy_rotated_detection_t** results, - int** result_count) { - if (!output || !results || !result_count) { - return MMDEPLOY_E_INVALID_ARG; - } - - try { - Value& value = Cast(output)->front(); - auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(value); - - vector<int> _result_count; - _result_count.reserve(detector_outputs.size()); - for (const auto& det_output : detector_outputs) { - _result_count.push_back((int)det_output.detections.size()); + int** result_count) +{ + if (!output || !results || !result_count) + { + return MMDEPLOY_E_INVALID_ARG; } - auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0); + try + { + Value& value = Cast(output)->front(); + auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(value); - std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{}); - std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); - - std::unique_ptr<mmdeploy_rotated_detection_t[]> result_data( - new mmdeploy_rotated_detection_t[total]{}); - auto result_ptr = result_data.get(); - - for (const auto& det_output : detector_outputs) { - for (const auto& detection : det_output.detections) { - result_ptr->label_id = detection.label_id; - result_ptr->score = detection.score; - const auto& rbbox = detection.rbbox; - for (int i = 0; i < 5; i++) { - result_ptr->rbbox[i] = rbbox[i]; + vector<int> _result_count; + _result_count.reserve(detector_outputs.size()); + for (const auto& det_output : detector_outputs) + { + _result_count.push_back((int)det_output.detections.size()); } - ++result_ptr; - } - } - *result_count = result_count_data.release(); - *results = result_data.release(); + auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0); + + std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{}); + std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); + + std::unique_ptr<mmdeploy_rotated_detection_t[]> result_data( + new mmdeploy_rotated_detection_t[total]{}); + auto result_ptr = result_data.get(); + + for (const auto& det_output : detector_outputs) + { + for (const auto& detection : det_output.detections) + { + result_ptr->label_id = detection.label_id; + result_ptr->score = detection.score; + const auto& rbbox = detection.rbbox; + for (int i = 0; i < 5; i++) + { + result_ptr->rbbox[i] = rbbox[i]; + } + ++result_ptr; + } + } - return MMDEPLOY_SUCCESS; + *result_count = result_count_data.release(); + *results = result_data.release(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; }
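Before the matching header below: note that mmdeploy_rotated_detector_get_result packs the detections of all images into one linear buffer, with result_count[i] giving the per-image counts, and that each rbbox is laid out as cx, cy, w, h, angle (per the struct comment). A small sketch of consuming that layout, illustration only; the detector handle and input mats are assumed to exist:

#include "mmdeploy/rotated_detector.h"

#include <cstdio>

int print_rotated_detections(mmdeploy_rotated_detector_t detector, const mmdeploy_mat_t* mats, int mat_count)
{
    mmdeploy_rotated_detection_t* dets{};
    int* det_count{};
    if (mmdeploy_rotated_detector_apply(detector, mats, mat_count, &dets, &det_count)) return -1;

    const mmdeploy_rotated_detection_t* p = dets;  // walk the linear buffer image by image
    for (int i = 0; i < mat_count; ++i)
    {
        for (int j = 0; j < det_count[i]; ++j, ++p)
        {
            std::printf("img %d: label=%d score=%.2f cx=%.1f cy=%.1f w=%.1f h=%.1f angle=%.3f\n",
                        i, p->label_id, p->score, p->rbbox[0], p->rbbox[1], p->rbbox[2], p->rbbox[3], p->rbbox[4]);
        }
    }

    mmdeploy_rotated_detector_release_result(dets, det_count);
    return 0;
}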
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h index 35125a74ff..1d745debae 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h @@ -13,125 +13,126 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_rotated_detection_t { - int label_id; - float score; - float rbbox[5]; // cx, cy, w, h, angle -} mmdeploy_rotated_detection_t; - -typedef struct mmdeploy_rotated_detector* mmdeploy_rotated_detector_t; - -/** - * @brief Create rotated detector's handle - * @param[in] model an instance of mmrotate sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector instance of a rotated detector - * @return status of creating rotated detector's handle - */ -MMDEPLOY_API int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, - int device_id, - mmdeploy_rotated_detector_t* detector); - -/** - * @brief Create rotated detector's handle - * @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector instance of a rotated detector - * @return status of creating rotated detector's handle - */ -MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path, - const char* device_name, int device_id, - mmdeploy_rotated_detector_t* detector); - -/** - * @brief Apply rotated detector to batch images and get their inference results - * @param[in] detector rotated detector's handle created by \ref - * mmdeploy_rotated_detector_create_by_path - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] results a linear buffer to save detection results of each image. It must be released - * by \ref mmdeploy_rotated_detector_release_result - * @param[out] result_count a linear buffer with length being \p mat_count to save the number of - * detection results of each image. And it must be released by \ref - * mmdeploy_rotated_detector_release_result - * @return status of inference - */ -MMDEPLOY_API int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_rotated_detection_t** results, - int** result_count); - -/** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply - * @param[in] results rotated detection results buffer - * @param[in] result_count \p results size buffer - */ -MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mmdeploy_rotated_detection_t* results, - const int* result_count); - -/** - * @brief Destroy rotated detector's handle - * @param[in] detector rotated detector's handle created by \ref - * mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create - */ -MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_detector_create, but allows to control execution context of tasks - * via context - */ -MMDEPLOY_API int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_rotated_detector_t* detector); - -/** - * @brief Pack rotated detector inputs into mmdeploy_value_t - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @return the created value - */ -MMDEPLOY_API int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input); - -/** - * @brief Same as \ref mmdeploy_rotated_detector_apply, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector, - mmdeploy_value_t input, - mmdeploy_value_t* output); - -/** - * @brief Apply rotated detector asynchronously - * @param[in] detector handle to the detector - * @param[in] input input sender - * @return output sender - */ -MMDEPLOY_API int mmdeploy_rotated_detector_apply_async(mmdeploy_rotated_detector_t detector, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -/** - * @brief Unpack rotated detector output from a mmdeploy_value_t - * @param[in] output output obtained by applying a detector - * @param[out] results a linear buffer to save detection results of each image. It must be released - * by \ref mmdeploy_detector_release_result - * @param[out] result_count a linear buffer with length number of input images to save the number of - * detection results of each image. Must be released by \ref - * mmdeploy_detector_release_result - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output, - mmdeploy_rotated_detection_t** results, - int** result_count); + typedef struct mmdeploy_rotated_detection_t + { + int label_id; + float score; + float rbbox[5]; // cx, cy, w, h, angle + } mmdeploy_rotated_detection_t; + + typedef struct mmdeploy_rotated_detector* mmdeploy_rotated_detector_t; + + /** + * @brief Create rotated detector's handle + * @param[in] model an instance of mmrotate sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device.
+ * @param[out] detector instance of a rotated detector + * @return status of creating rotated detector's handle + */ + MMDEPLOY_API int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_rotated_detector_t* detector); + + /** + * @brief Create rotated detector's handle + * @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] detector instance of a rotated detector + * @return status of creating rotated detector's handle + */ + MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path, + const char* device_name, + int device_id, + mmdeploy_rotated_detector_t* detector); + + /** + * @brief Apply rotated detector to batch images and get their inference results + * @param[in] detector rotated detector's handle created by \ref + * mmdeploy_rotated_detector_create_by_path + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @param[out] results a linear buffer to save detection results of each image. It must be released + * by \ref mmdeploy_rotated_detector_release_result + * @param[out] result_count a linear buffer with length being \p mat_count to save the number of + * detection results of each image. And it must be released by \ref + * mmdeploy_rotated_detector_release_result + * @return status of inference + */ + MMDEPLOY_API int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector, + const mmdeploy_mat_t* mats, + int mat_count, + mmdeploy_rotated_detection_t** results, + int** result_count); + + /** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply + * @param[in] results rotated detection results buffer + * @param[in] result_count \p results size buffer + */ + MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mmdeploy_rotated_detection_t* results, + const int* result_count); + + /** + * @brief Destroy rotated detector's handle + * @param[in] detector rotated detector's handle created by \ref + * mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create + */ + MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + /** + * @brief Same as \ref mmdeploy_detector_create, but allows to control execution context of tasks + * via context + */ + MMDEPLOY_API int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model, + mmdeploy_context_t context, + mmdeploy_rotated_detector_t* detector); + + /** + * @brief Pack rotated detector inputs into mmdeploy_value_t + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @return the created value + */ + MMDEPLOY_API int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input); + + /** + * @brief Same as \ref mmdeploy_rotated_detector_apply, but input and output are packed in \ref + * mmdeploy_value_t. 
+ */ + MMDEPLOY_API int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + /** + * @brief Apply rotated detector asynchronously + * @param[in] detector handle to the detector + * @param[in] input input sender + * @return output sender + */ + MMDEPLOY_API int mmdeploy_rotated_detector_apply_async(mmdeploy_rotated_detector_t detector, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + /** + * @brief Unpack rotated detector output from a mmdeploy_value_t + * @param[in] output output obtained by applying a detector + * @param[out] results a linear buffer to save detection results of each image. It must be released + * by \ref mmdeploy_detector_release_result + * @param[out] result_count a linear buffer with length number of input images to save the number of + * detection results of each image. Must be released by \ref + * mmdeploy_detector_release_result + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output, + mmdeploy_rotated_detection_t** results, + int** result_count); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp index c982df39e5..9ec8ae366c 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp @@ -18,111 +18,128 @@ using namespace mmdeploy; using ResultType = mmdeploy::Structure; -int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_segmentor_t* segmentor) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_segmentor_create_v2(model, context, segmentor); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_segmentor_create_v2(model, context, segmentor); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_segmentor_t* segmentor) { - mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { +int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor) +{ + mmdeploy_model_t model{}; + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_segmentor_create(model, device_name, device_id, segmentor); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_segmentor_create(model, device_name, device_id, segmentor); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_segmentation_t** results) { - wrapped input; - if (auto ec = mmdeploy_segmentor_create_input(mats, mat_count, input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_segmentor_apply_v2(segmentor, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_segmentor_get_result(output, results)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int 
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp
index c982df39e5..9ec8ae366c 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp
+++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp
@@ -18,111 +18,128 @@
 using namespace mmdeploy;

 using ResultType = mmdeploy::Structure;

-int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id,
-                              mmdeploy_segmentor_t* segmentor) {
-  mmdeploy_context_t context{};
-  auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
-  if (ec != MMDEPLOY_SUCCESS) {
+int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor)
+{
+    mmdeploy_context_t context{};
+    auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
+    if (ec != MMDEPLOY_SUCCESS)
+    {
+        return ec;
+    }
+    ec = mmdeploy_segmentor_create_v2(model, context, segmentor);
+    mmdeploy_context_destroy(context);
     return ec;
-  }
-  ec = mmdeploy_segmentor_create_v2(model, context, segmentor);
-  mmdeploy_context_destroy(context);
-  return ec;
 }

-int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name,
-                                      int device_id, mmdeploy_segmentor_t* segmentor) {
-  mmdeploy_model_t model{};
-  if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) {
+int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor)
+{
+    mmdeploy_model_t model{};
+    if (auto ec = mmdeploy_model_create_by_path(model_path, &model))
+    {
+        return ec;
+    }
+    auto ec = mmdeploy_segmentor_create(model, device_name, device_id, segmentor);
+    mmdeploy_model_destroy(model);
     return ec;
-  }
-  auto ec = mmdeploy_segmentor_create(model, device_name, device_id, segmentor);
-  mmdeploy_model_destroy(model);
-  return ec;
 }

-int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor, const mmdeploy_mat_t* mats,
-                             int mat_count, mmdeploy_segmentation_t** results) {
-  wrapped<mmdeploy_value_t> input;
-  if (auto ec = mmdeploy_segmentor_create_input(mats, mat_count, input.ptr())) {
-    return ec;
-  }
-  wrapped<mmdeploy_value_t> output;
-  if (auto ec = mmdeploy_segmentor_apply_v2(segmentor, input, output.ptr())) {
-    return ec;
-  }
-  if (auto ec = mmdeploy_segmentor_get_result(output, results)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_segmentation_t** results)
+{
+    wrapped<mmdeploy_value_t> input;
+    if (auto ec = mmdeploy_segmentor_create_input(mats, mat_count, input.ptr()))
+    {
+        return ec;
+    }
+    wrapped<mmdeploy_value_t> output;
+    if (auto ec = mmdeploy_segmentor_apply_v2(segmentor, input, output.ptr()))
+    {
+        return ec;
+    }
+    if (auto ec = mmdeploy_segmentor_get_result(output, results))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

-void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count) {
-  ResultType deleter(static_cast<size_t>(count), results);
+void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count)
+{
+    ResultType deleter(static_cast<size_t>(count), results);
 }

-void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor) {
-  mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)segmentor);
+void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor)
+{
+    mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)segmentor);
 }

-int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                 mmdeploy_segmentor_t* segmentor) {
-  return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)segmentor);
+int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_segmentor_t* segmentor)
+{
+    return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)segmentor);
 }

-int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                    mmdeploy_value_t* value) {
-  return mmdeploy_common_create_input(mats, mat_count, value);
+int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value)
+{
+    return mmdeploy_common_create_input(mats, mat_count, value);
 }

-int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input,
-                                mmdeploy_value_t* output) {
-  return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)segmentor, input, output);
+int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input, mmdeploy_value_t* output)
+{
+    return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)segmentor, input, output);
 }

-int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor, mmdeploy_sender_t input,
-                                   mmdeploy_sender_t* output) {
-  return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)segmentor, input, output);
+int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor, mmdeploy_sender_t input, mmdeploy_sender_t* output)
+{
+    return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)segmentor, input, output);
 }

-int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation_t** results) {
-  try {
-    const auto& value = Cast(output)->front();
-    size_t image_count = value.size();
-
-    ResultType r(image_count);
-    auto [results_data, buffers] = r.pointers();
-
-    auto results_ptr = results_data;
-
-    for (auto i = 0; i < image_count; ++i, ++results_ptr) {
-      auto& output_item = value[i];
-      MMDEPLOY_DEBUG("the {}-th item in output: {}", i, output_item);
-      auto segmentor_output = from_value(output_item);
-      results_ptr->height = segmentor_output.height;
-      results_ptr->width = segmentor_output.width;
-      results_ptr->classes = segmentor_output.classes;
-      auto& mask = segmentor_output.mask;
-      auto& score = segmentor_output.score;
-      results_ptr->mask = nullptr;
-      results_ptr->score = nullptr;
-      if (mask.shape().size()) {
-        results_ptr->mask = mask.data();
-        buffers[i] = mask.buffer();
-      } else {
-        results_ptr->score 
= score.data(); - buffers[i] = score.buffer(); - } +int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation_t** results) +{ + try + { + const auto& value = Cast(output)->front(); + size_t image_count = value.size(); + + ResultType r(image_count); + auto [results_data, buffers] = r.pointers(); + + auto results_ptr = results_data; + + for (auto i = 0; i < image_count; ++i, ++results_ptr) + { + auto& output_item = value[i]; + MMDEPLOY_DEBUG("the {}-th item in output: {}", i, output_item); + auto segmentor_output = from_value(output_item); + results_ptr->height = segmentor_output.height; + results_ptr->width = segmentor_output.width; + results_ptr->classes = segmentor_output.classes; + auto& mask = segmentor_output.mask; + auto& score = segmentor_output.score; + results_ptr->mask = nullptr; + results_ptr->score = nullptr; + if (mask.shape().size()) + { + results_ptr->mask = mask.data(); + buffers[i] = mask.buffer(); + } + else + { + results_ptr->score = score.data(); + buffers[i] = score.buffer(); + } + } + + *results = results_data; + r.release(); + + return MMDEPLOY_SUCCESS; } - - *results = results_data; - r.release(); - - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h index 65bcfd03f3..8d885a275b 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h @@ -13,91 +13,90 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_segmentation_t { - int height; ///< height of \p mask that equals to the input image's height - int width; ///< width of \p mask that equals to the input image's width - int classes; ///< the number of labels in \p mask - int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates - ///< the label id of pixel at (i, j), this field might be null - float* score; ///< segmentation score map of the input image in CHW format, in which - ///< score[height * width * k + i * width + j] indicates the score - ///< of class k at pixel (i, j), this field might be null -} mmdeploy_segmentation_t; - -typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t; - -/** - * @brief Create segmentor's handle - * @param[in] model an instance of mmsegmentation sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] segmentor instance of a segmentor, which must be destroyed - * by \ref mmdeploy_segmentor_destroy - * @return status of creating segmentor's handle - */ -MMDEPLOY_API int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_segmentor_t* segmentor); - -/** - * @brief Create segmentor's handle - * @param[in] model_path path of mmsegmentation sdk model exported by mmdeploy model converter - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. 
- * @param[out] segmentor instance of a segmentor, which must be destroyed - * by \ref mmdeploy_segmentor_destroy - * @return status of creating segmentor's handle - */ -MMDEPLOY_API int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_segmentor_t* segmentor); - -/** - * @brief Apply segmentor to batch images and get their inference results - * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path or \ref - * mmdeploy_segmentor_create - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] results a linear buffer of length \p mat_count to save segmentation result of each - * image. It must be released by \ref mmdeploy_segmentor_release_result - * @return status of inference - */ -MMDEPLOY_API int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_segmentation_t** results); - -/** - * @brief Release result buffer returned by \ref mmdeploy_segmentor_apply - * @param[in] results result buffer - * @param[in] count length of \p results - */ -MMDEPLOY_API void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count); - -/** - * @brief Destroy segmentor's handle - * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path - */ -MMDEPLOY_API void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -MMDEPLOY_API int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_segmentor_t* segmentor); - -MMDEPLOY_API int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value); - -MMDEPLOY_API int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input, - mmdeploy_value_t* output); - -MMDEPLOY_API int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor, - mmdeploy_sender_t input, mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_segmentor_get_result(mmdeploy_value_t output, - mmdeploy_segmentation_t** results); + typedef struct mmdeploy_segmentation_t + { + int height; ///< height of \p mask that equals to the input image's height + int width; ///< width of \p mask that equals to the input image's width + int classes; ///< the number of labels in \p mask + int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates + ///< the label id of pixel at (i, j), this field might be null + float* score; ///< segmentation score map of the input image in CHW format, in which + ///< score[height * width * k + i * width + j] indicates the score + ///< of class k at pixel (i, j), this field might be null + } mmdeploy_segmentation_t; + + typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t; + + /** + * @brief Create segmentor's handle + * @param[in] model an instance of mmsegmentation sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+ * @param[out] segmentor instance of a segmentor, which must be destroyed + * by \ref mmdeploy_segmentor_destroy + * @return status of creating segmentor's handle + */ + MMDEPLOY_API int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor); + + /** + * @brief Create segmentor's handle + * @param[in] model_path path of mmsegmentation sdk model exported by mmdeploy model converter + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] segmentor instance of a segmentor, which must be destroyed + * by \ref mmdeploy_segmentor_destroy + * @return status of creating segmentor's handle + */ + MMDEPLOY_API int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor); + + /** + * @brief Apply segmentor to batch images and get their inference results + * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path or \ref + * mmdeploy_segmentor_create + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @param[out] results a linear buffer of length \p mat_count to save segmentation result of each + * image. It must be released by \ref mmdeploy_segmentor_release_result + * @return status of inference + */ + MMDEPLOY_API int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor, + const mmdeploy_mat_t* mats, + int mat_count, + mmdeploy_segmentation_t** results); + + /** + * @brief Release result buffer returned by \ref mmdeploy_segmentor_apply + * @param[in] results result buffer + * @param[in] count length of \p results + */ + MMDEPLOY_API void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count); + + /** + * @brief Destroy segmentor's handle + * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path + */ + MMDEPLOY_API void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + MMDEPLOY_API int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_segmentor_t* segmentor); + + MMDEPLOY_API int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value); + + MMDEPLOY_API int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input, mmdeploy_value_t* output); + + MMDEPLOY_API int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_segmentor_get_result(mmdeploy_value_t output, + mmdeploy_segmentation_t** results); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp index 576af07762..44b124187f 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp @@ -16,158 +16,186 @@ using namespace std; using namespace mmdeploy; -int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_text_detector_t* detector) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, 
mmdeploy_text_detector_t* detector) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_text_detector_create_v2(model, context, detector); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_text_detector_create_v2(model, context, detector); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_text_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_text_detector_t* detector) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); +int mmdeploy_text_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_text_detector_t* detector) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); } -int mmdeploy_text_detector_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_text_detector_t* detector) { - mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { +int mmdeploy_text_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_text_detector_t* detector) +{ + mmdeploy_model_t model{}; + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_text_detector_create(model, device_name, device_id, detector); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_text_detector_create(model, device_name, device_id, detector); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input) { - return mmdeploy_common_create_input(mats, mat_count, input); +int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input) +{ + return mmdeploy_common_create_input(mats, mat_count, input); } -int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_text_detection_t** results, - int** result_count) { - wrapped input; - if (auto ec = mmdeploy_text_detector_create_input(mats, mat_count, input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_text_detector_apply_v2(detector, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_text_detector_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_text_detection_t** results, int** result_count) +{ + wrapped input; + if (auto ec = mmdeploy_text_detector_create_input(mats, mat_count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = mmdeploy_text_detector_apply_v2(detector, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_text_detector_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); } -int 
mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results, - int** result_count) { - if (!output || !results || !result_count) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value& value = reinterpret_cast(output)->front(); - auto detector_outputs = from_value>(value); - - vector _result_count; - _result_count.reserve(detector_outputs.size()); - for (const auto& det_output : detector_outputs) { - _result_count.push_back((int)det_output.size()); +int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results, int** result_count) +{ + if (!output || !results || !result_count) + { + return MMDEPLOY_E_INVALID_ARG; } - - auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0); - - std::unique_ptr result_count_data(new int[_result_count.size()]{}); - std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); - - std::unique_ptr result_data( - new mmdeploy_text_detection_t[total]{}); - auto result_ptr = result_data.get(); - - for (const auto& det_output : detector_outputs) { - for (auto i = 0; i < det_output.size(); ++i, ++result_ptr) { - result_ptr->score = det_output[i].score; - auto& bbox = det_output[i].bbox; - for (auto j = 0; j < bbox.size(); j += 2) { - result_ptr->bbox[j / 2].x = bbox[j]; - result_ptr->bbox[j / 2].y = bbox[j + 1]; + try + { + Value& value = reinterpret_cast(output)->front(); + auto detector_outputs = from_value>(value); + + vector _result_count; + _result_count.reserve(detector_outputs.size()); + for (const auto& det_output : detector_outputs) + { + _result_count.push_back((int)det_output.size()); } - } - } - *result_count = result_count_data.release(); - *results = result_data.release(); + auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0); + + std::unique_ptr result_count_data(new int[_result_count.size()]{}); + std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); + + std::unique_ptr result_data( + new mmdeploy_text_detection_t[total]{}); + auto result_ptr = result_data.get(); + + for (const auto& det_output : detector_outputs) + { + for (auto i = 0; i < det_output.size(); ++i, ++result_ptr) + { + result_ptr->score = det_output[i].score; + auto& bbox = det_output[i].bbox; + for (auto j = 0; j < bbox.size(); j += 2) + { + result_ptr->bbox[j / 2].x = bbox[j]; + result_ptr->bbox[j / 2].y = bbox[j + 1]; + } + } + } - return MMDEPLOY_SUCCESS; + *result_count = result_count_data.release(); + *results = result_data.release(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return 0; + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) 
+    {
+        MMDEPLOY_ERROR("unknown exception caught");
+    }
+    return MMDEPLOY_E_FAIL;
 }

 void mmdeploy_text_detector_release_result(mmdeploy_text_detection_t* results,
-                                           const int* result_count, int count) {
-  delete[] results;
-  delete[] result_count;
+                                           const int* result_count,
+                                           int count)
+{
+    delete[] results;
+    delete[] result_count;
 }

-void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector) {
-  mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector);
+void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector)
+{
+    mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector);
 }

-int mmdeploy_text_detector_apply_async_v2(mmdeploy_text_detector_t detector,
-                                          const mmdeploy_mat_t* imgs, int img_count,
-                                          mmdeploy_text_detector_continue_t cont, void* context,
-                                          mmdeploy_sender_t* output) {
-  mmdeploy_sender_t result_sender{};
-  if (auto ec = mmdeploy_text_detector_apply_async_v3(detector, imgs, img_count, &result_sender)) {
-    return ec;
-  }
-  if (auto ec = mmdeploy_text_detector_continue_async(result_sender, cont, context, output)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+int mmdeploy_text_detector_apply_async_v2(mmdeploy_text_detector_t detector,
+                                          const mmdeploy_mat_t* imgs,
+                                          int img_count,
+                                          mmdeploy_text_detector_continue_t cont,
+                                          void* context,
+                                          mmdeploy_sender_t* output)
+{
+    mmdeploy_sender_t result_sender{};
+    if (auto ec = mmdeploy_text_detector_apply_async_v3(detector, imgs, img_count, &result_sender))
+    {
+        return ec;
+    }
+    if (auto ec = mmdeploy_text_detector_continue_async(result_sender, cont, context, output))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

 int mmdeploy_text_detector_apply_async_v3(mmdeploy_text_detector_t detector,
-                                          const mmdeploy_mat_t* imgs, int img_count,
-                                          mmdeploy_sender_t* output) {
-  wrapped<mmdeploy_value_t> input_val;
-  if (auto ec = mmdeploy_text_detector_create_input(imgs, img_count, input_val.ptr())) {
-    return ec;
-  }
-  mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val);
-  if (auto ec = mmdeploy_text_detector_apply_async(detector, input_sndr, output)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+                                          const mmdeploy_mat_t* imgs,
+                                          int img_count,
+                                          mmdeploy_sender_t* output)
+{
+    wrapped<mmdeploy_value_t> input_val;
+    if (auto ec = mmdeploy_text_detector_create_input(imgs, img_count, input_val.ptr()))
+    {
+        return ec;
+    }
+    mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val);
+    if (auto ec = mmdeploy_text_detector_apply_async(detector, input_sndr, output))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

-int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input,
-                                          mmdeploy_text_detector_continue_t cont, void* context,
-                                          mmdeploy_sender_t* output) {
-  auto sender = Guard([&] {
-    return Take(
-        LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender {
+int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input,
+                                          mmdeploy_text_detector_continue_t cont,
+                                          void* context,
+                                          mmdeploy_sender_t* output)
+{
+    auto sender = Guard([&]
+                        { return Take(
+                              LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender
+                                       {
         mmdeploy_text_detection_t* results{};
         int* result_count{};
         if (auto ec = mmdeploy_text_detector_get_result(Cast(&value), &results, &result_count)) {
@@ -178,12 +206,11 @@ int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input,
       if (auto ec = fn(results, result_count, context, &output); ec || !output) {
         return Just(Value());
       }
-      return Take(output);
-    }));
-  });
-  if (sender) {
-    *output = sender;
-    return MMDEPLOY_SUCCESS;
-  }
-  return MMDEPLOY_E_FAIL;
+            return Take(output); 
})); });
+    if (sender)
+    {
+        *output = sender;
+        return MMDEPLOY_SUCCESS;
+    }
+    return MMDEPLOY_E_FAIL;
 }
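Before the header that declares this API is reformatted below, a short usage sketch may help; it shows the ownership contract that the doc comments describe (detections of all images packed in one linear buffer, released through the dedicated release function). This is illustrative only and not part of the diff.

#include <stdio.h>

#include "mmdeploy/text_detector.h"

/* Print every detected box for a batch of `n` preloaded images. */
int print_text_boxes(mmdeploy_text_detector_t detector, const mmdeploy_mat_t* mats, int n)
{
    mmdeploy_text_detection_t* dets = NULL;
    int* det_count = NULL; /* det_count[i] = number of boxes in image i */
    int ec = mmdeploy_text_detector_apply(detector, mats, n, &dets, &det_count);
    if (ec != MMDEPLOY_SUCCESS)
    {
        return ec;
    }
    const mmdeploy_text_detection_t* p = dets; /* boxes of all images are packed linearly */
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < det_count[i]; ++j, ++p)
        {
            printf("img %d box %d score %.3f first vertex (%.1f, %.1f)\n",
                   i, j, p->score, p->bbox[0].x, p->bbox[0].y);
        }
    }
    mmdeploy_text_detector_release_result(dets, det_count, n);
    return MMDEPLOY_SUCCESS;
}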
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/text_detector.h b/csrc/mmdeploy/apis/c/mmdeploy/text_detector.h
index a3c38dc6f6..da363940d7 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/text_detector.h
+++ b/csrc/mmdeploy/apis/c/mmdeploy/text_detector.h
@@ -13,141 +13,147 @@
 #include "mmdeploy/model.h"

 #ifdef __cplusplus
-extern "C" {
+extern "C"
+{
 #endif

-typedef struct mmdeploy_text_detection_t {
-  mmdeploy_point_t bbox[4];  ///< a text bounding box of which the vertex are in clock-wise
-  float score;
-} mmdeploy_text_detection_t;
-
-typedef struct mmdeploy_text_detector* mmdeploy_text_detector_t;
-
-/**
- * @brief Create text-detector's handle
- * @param[in] model an instance of mmocr text detection model created by
- * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device.
- * @param[out] detector instance of a text-detector, which must be destroyed
- * by \ref mmdeploy_text_detector_destroy
- * @return status of creating text-detector's handle
- */
-MMDEPLOY_API int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name,
-                                               int device_id, mmdeploy_text_detector_t* detector);
-
-/**
- * @brief Create text-detector's handle
- * @param[in] model_path path to text detection model
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device
- * @param[out] detector instance of a text-detector, which must be destroyed
- * by \ref mmdeploy_text_detector_destroy
- * @return status of creating text-detector's handle
- */
-MMDEPLOY_API int mmdeploy_text_detector_create_by_path(const char* model_path,
-                                                       const char* device_name, int device_id,
-                                                       mmdeploy_text_detector_t* detector);
-
-/**
- * @brief Apply text-detector to batch images and get their inference results
- * @param[in] detector text-detector's handle created by \ref mmdeploy_text_detector_create_by_path
- * @param[in] mats a batch of images
- * @param[in] mat_count number of images in the batch
- * @param[out] results a linear buffer to save text detection results of each
- * image. It must be released by calling \ref mmdeploy_text_detector_release_result
- * @param[out] result_count a linear buffer of length \p mat_count to save the number of detection
- * results of each image. It must be released by \ref mmdeploy_detector_release_result
- * @return status of inference
- */
-MMDEPLOY_API int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector,
-                                              const mmdeploy_mat_t* mats, int mat_count,
-                                              mmdeploy_text_detection_t** results,
-                                              int** result_count);
-
-/** @brief Release the inference result buffer returned by \ref mmdeploy_text_detector_apply
- * @param[in] results text detection result buffer
- * @param[in] result_count \p results size buffer
- * @param[in] count the length of buffer \p result_count
- */
-MMDEPLOY_API void mmdeploy_text_detector_release_result(mmdeploy_text_detection_t* results,
-                                                        const int* result_count, int count);
-
-/**
- * @brief Destroy text-detector's handle
- * @param[in] detector text-detector's handle created by \ref mmdeploy_text_detector_create_by_path
- * or \ref mmdeploy_text_detector_create
- */
-MMDEPLOY_API void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector);
-
-/******************************************************************************
- * Experimental asynchronous APIs */
-
-/**
- * @brief Same as \ref mmdeploy_text_detector_create, but allows to control execution context of
- * tasks via context
- */
-MMDEPLOY_API int mmdeploy_text_detector_create_v2(mmdeploy_model_t model,
-                                                  mmdeploy_context_t context,
-                                                  mmdeploy_text_detector_t* detector);
-
-/**
- * @brief Pack text-detector inputs into mmdeploy_value_t
- * @param[in] mats a batch of images
- * @param[in] mat_count number of images in the batch
- * @return the created value
- */
-MMDEPLOY_API int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                                     mmdeploy_value_t* input);
-
-/**
- * @brief Same as \ref mmdeploy_text_detector_apply, but input and output are packed in \ref
- * mmdeploy_value_t.
- */
-MMDEPLOY_API int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector,
-                                                 mmdeploy_value_t input, mmdeploy_value_t* output);
-
-/**
- * @brief Apply text-detector asynchronously
- * @param[in] detector handle to the detector
- * @param[in] input input sender that will be consumed by the operation
- * @return output sender
- */
-MMDEPLOY_API int mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector,
-                                                    mmdeploy_sender_t input,
-                                                    mmdeploy_sender_t* output);
-
-/**
- * @brief Unpack detector output from a mmdeploy_value_t
- * @param[in] output output sender returned by applying a detector
- * @param[out] results a linear buffer to save detection results of each image. It must be
- * released by \ref mmdeploy_text_detector_release_result
- * @param[out] result_count a linear buffer with length number of input images to save the
- * number of detection results of each image. 
Must be released by \ref
- * mmdeploy_text_detector_release_result
- * @return status of the operation
- */
-MMDEPLOY_API
-int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results,
-                                      int** result_count);
-
-typedef int (*mmdeploy_text_detector_continue_t)(mmdeploy_text_detection_t* results,
-                                                 int* result_count, void* context,
-                                                 mmdeploy_sender_t* output);
-
-// MMDEPLOY_API int mmdeploy_text_detector_apply_async_v2(mm_handle_t handle, const mm_mat_t* imgs,
-//                                                        int img_count,
-//                                                        mmdeploy_text_detector_continuation_t
-//                                                        cont, void* context, mmdeploy_sender_t*
-//                                                        output);
-
-MMDEPLOY_API int mmdeploy_text_detector_apply_async_v3(mmdeploy_text_detector_t detector,
-                                                       const mmdeploy_mat_t* imgs, int img_count,
-                                                       mmdeploy_sender_t* output);
-
-MMDEPLOY_API int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input,
-                                                       mmdeploy_text_detector_continue_t cont,
-                                                       void* context, mmdeploy_sender_t* output);
+    typedef struct mmdeploy_text_detection_t
+    {
+        mmdeploy_point_t bbox[4];  ///< a text bounding box whose vertices are in clockwise order
+        float score;
+    } mmdeploy_text_detection_t;
+
+    typedef struct mmdeploy_text_detector* mmdeploy_text_detector_t;
+
+    /**
+     * @brief Create text-detector's handle
+     * @param[in] model an instance of mmocr text detection model created by
+     * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] detector instance of a text-detector, which must be destroyed
+     * by \ref mmdeploy_text_detector_destroy
+     * @return status of creating text-detector's handle
+     */
+    MMDEPLOY_API int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_text_detector_t* detector);
+
+    /**
+     * @brief Create text-detector's handle
+     * @param[in] model_path path to text detection model
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device
+     * @param[out] detector instance of a text-detector, which must be destroyed
+     * by \ref mmdeploy_text_detector_destroy
+     * @return status of creating text-detector's handle
+     */
+    MMDEPLOY_API int mmdeploy_text_detector_create_by_path(const char* model_path,
+                                                           const char* device_name,
+                                                           int device_id,
+                                                           mmdeploy_text_detector_t* detector);
+
+    /**
+     * @brief Apply text-detector to batch images and get their inference results
+     * @param[in] detector text-detector's handle created by \ref mmdeploy_text_detector_create_by_path
+     * @param[in] mats a batch of images
+     * @param[in] mat_count number of images in the batch
+     * @param[out] results a linear buffer to save text detection results of each
+     * image. It must be released by calling \ref mmdeploy_text_detector_release_result
+     * @param[out] result_count a linear buffer of length \p mat_count to save the number of detection
+     * results of each image. 
It must be released by \ref mmdeploy_text_detector_release_result
+     * @return status of inference
+     */
+    MMDEPLOY_API int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector,
+                                                  const mmdeploy_mat_t* mats,
+                                                  int mat_count,
+                                                  mmdeploy_text_detection_t** results,
+                                                  int** result_count);
+
+    /** @brief Release the inference result buffer returned by \ref mmdeploy_text_detector_apply
+     * @param[in] results text detection result buffer
+     * @param[in] result_count \p results size buffer
+     * @param[in] count the length of buffer \p result_count
+     */
+    MMDEPLOY_API void mmdeploy_text_detector_release_result(mmdeploy_text_detection_t* results,
+                                                            const int* result_count,
+                                                            int count);
+
+    /**
+     * @brief Destroy text-detector's handle
+     * @param[in] detector text-detector's handle created by \ref mmdeploy_text_detector_create_by_path
+     * or \ref mmdeploy_text_detector_create
+     */
+    MMDEPLOY_API void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector);
+
+    /******************************************************************************
+     * Experimental asynchronous APIs */
+
+    /**
+     * @brief Same as \ref mmdeploy_text_detector_create, but allows controlling the execution
+     * context of tasks via \p context
+     */
+    MMDEPLOY_API int mmdeploy_text_detector_create_v2(mmdeploy_model_t model,
+                                                      mmdeploy_context_t context,
+                                                      mmdeploy_text_detector_t* detector);
+
+    /**
+     * @brief Pack text-detector inputs into mmdeploy_value_t
+     * @param[in] mats a batch of images
+     * @param[in] mat_count number of images in the batch
+     * @param[out] input the packed value
+     * @return status of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input);
+
+    /**
+     * @brief Same as \ref mmdeploy_text_detector_apply, but input and output are packed in \ref
+     * mmdeploy_value_t.
+     */
+    MMDEPLOY_API int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector,
+                                                     mmdeploy_value_t input,
+                                                     mmdeploy_value_t* output);
+
+    /**
+     * @brief Apply text-detector asynchronously
+     * @param[in] detector handle to the detector
+     * @param[in] input input sender that will be consumed by the operation
+     * @param[out] output output sender
+     * @return status of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector,
+                                                        mmdeploy_sender_t input,
+                                                        mmdeploy_sender_t* output);
+
+    /**
+     * @brief Unpack detector output from a mmdeploy_value_t
+     * @param[in] output output sender returned by applying a detector
+     * @param[out] results a linear buffer to save detection results of each image. It must be
+     * released by \ref mmdeploy_text_detector_release_result
+     * @param[out] result_count a linear buffer with length equal to the number of input images
+     * to save the number of detection results of each image. 
Must be released by \ref + * mmdeploy_text_detector_release_result + * @return status of the operation + */ + MMDEPLOY_API + int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results, int** result_count); + + typedef int (*mmdeploy_text_detector_continue_t)(mmdeploy_text_detection_t* results, + int* result_count, + void* context, + mmdeploy_sender_t* output); + + // MMDEPLOY_API int mmdeploy_text_detector_apply_async_v2(mm_handle_t handle, const mm_mat_t* imgs, + // int img_count, + // mmdeploy_text_detector_continuation_t + // cont, void* context, mmdeploy_sender_t* + // output); + + MMDEPLOY_API int mmdeploy_text_detector_apply_async_v3(mmdeploy_text_detector_t detector, + const mmdeploy_mat_t* imgs, + int img_count, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input, + mmdeploy_text_detector_continue_t cont, + void* context, + mmdeploy_sender_t* output); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp index 3c8cfbb5c6..4c94666add 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp @@ -19,10 +19,12 @@ using namespace mmdeploy; -namespace { +namespace +{ -Value config_template(const Model& model) { - // clang-format off + Value config_template(const Model& model) + { + // clang-format off return { {"type", "Pipeline"}, {"input", {"imgs", "bboxes"}}, @@ -44,194 +46,238 @@ Value config_template(const Model& model) { }, {"output", "texts"}, }; - // clang-format on -} + // clang-format on + } } // namespace -int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_text_recognizer_t* recognizer) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_text_recognizer_t* recognizer) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_text_recognizer_create_v2(model, context, recognizer); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_text_recognizer_create_v2(model, context, recognizer); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_text_recognizer_t* recognizer) { - auto config = config_template(*Cast(model)); - return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)recognizer); +int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_text_recognizer_t* recognizer) +{ + auto config = config_template(*Cast(model)); + return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)recognizer); } -int mmdeploy_text_recognizer_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_text_recognizer_t* recognizer) { - mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { +int mmdeploy_text_recognizer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_text_recognizer_t* recognizer) +{ + mmdeploy_model_t model{}; + if (auto ec = 
mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_text_recognizer_create(model, device_name, device_id, recognizer); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_text_recognizer_create(model, device_name, device_id, recognizer); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int count, - mmdeploy_text_recognition_t** results) { - return mmdeploy_text_recognizer_apply_bbox(recognizer, images, count, nullptr, nullptr, results); +int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* images, + int count, + mmdeploy_text_recognition_t** results) +{ + return mmdeploy_text_recognizer_apply_bbox(recognizer, images, count, nullptr, nullptr, results); } -int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images, int image_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, mmdeploy_value_t* output) { - if (image_count && images == nullptr) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value::Array input_images; - Value::Array input_bboxes; - - auto add_bbox = [&](Mat img, const mmdeploy_text_detection_t* det) { - if (det) { - const auto& b = det->bbox; - Value::Array bbox{b[0].x, b[0].y, b[1].x, b[1].y, b[2].x, b[2].y, b[3].x, b[3].y}; - input_bboxes.push_back({{"bbox", std::move(bbox)}}); - } else { - input_bboxes.push_back(nullptr); - } - input_images.push_back({{"ori_img", img}}); - }; - - for (int i = 0; i < image_count; ++i) { - auto _mat = Cast(images[i]); - if (bboxes && bbox_count) { - for (int j = 0; j < bbox_count[i]; ++j) { - add_bbox(_mat, bboxes++); - } - } else { // inference with whole image - add_bbox(_mat, nullptr); - } +int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images, int image_count, const mmdeploy_text_detection_t* bboxes, const int* bbox_count, mmdeploy_value_t* output) +{ + if (image_count && images == nullptr) + { + return MMDEPLOY_E_INVALID_ARG; } + try + { + Value::Array input_images; + Value::Array input_bboxes; - *output = Take(Value{std::move(input_images), std::move(input_bboxes)}); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + auto add_bbox = [&](Mat img, const mmdeploy_text_detection_t* det) + { + if (det) + { + const auto& b = det->bbox; + Value::Array bbox{b[0].x, b[0].y, b[1].x, b[1].y, b[2].x, b[2].y, b[3].x, b[3].y}; + input_bboxes.push_back({{"bbox", std::move(bbox)}}); + } + else + { + input_bboxes.push_back(nullptr); + } + input_images.push_back({{"ori_img", img}}); + }; + + for (int i = 0; i < image_count; ++i) + { + auto _mat = Cast(images[i]); + if (bboxes && bbox_count) + { + for (int j = 0; j < bbox_count[i]; ++j) + { + add_bbox(_mat, bboxes++); + } + } + else + { // inference with whole image + add_bbox(_mat, nullptr); + } + } + + *output = Take(Value{std::move(input_images), std::move(input_bboxes)}); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) 
+ { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int image_count, +int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* images, + int image_count, const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_text_recognition_t** results) { - wrapped input; - if (auto ec = mmdeploy_text_recognizer_create_input(images, image_count, bboxes, bbox_count, - input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_text_recognizer_apply_v2(recognizer, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_text_recognizer_get_result(output, results)) { - return ec; - } - return MMDEPLOY_SUCCESS; + const int* bbox_count, + mmdeploy_text_recognition_t** results) +{ + wrapped input; + if (auto ec = mmdeploy_text_recognizer_create_input(images, image_count, bboxes, bbox_count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = mmdeploy_text_recognizer_apply_v2(recognizer, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_text_recognizer_get_result(output, results)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output); +int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output); } int mmdeploy_text_recognizer_apply_async(mmdeploy_text_recognizer_t recognizer, - mmdeploy_sender_t input, mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)recognizer, input, output); + mmdeploy_sender_t input, + mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)recognizer, input, output); } -MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output, - mmdeploy_text_recognition_t** results) { - if (!output || !results) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - std::vector recognitions; - from_value(Cast(output)->front(), recognitions); +MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output, + mmdeploy_text_recognition_t** results) +{ + if (!output || !results) + { + return MMDEPLOY_E_INVALID_ARG; + } + try + { + std::vector recognitions; + from_value(Cast(output)->front(), recognitions); - size_t count = recognitions.size(); + size_t count = recognitions.size(); - auto deleter = [&](mmdeploy_text_recognition_t* p) { - mmdeploy_text_recognizer_release_result(p, static_cast(count)); - }; + auto deleter = [&](mmdeploy_text_recognition_t* p) + { + mmdeploy_text_recognizer_release_result(p, static_cast(count)); + }; - std::unique_ptr _results( - new mmdeploy_text_recognition_t[count]{}, deleter); + std::unique_ptr _results( + new mmdeploy_text_recognition_t[count]{}, + deleter); - size_t result_idx = 0; - for (const auto& bbox_result : recognitions) { - auto& res = _results[result_idx++]; + size_t result_idx = 0; + for (const auto& bbox_result : recognitions) + { + auto& res = _results[result_idx++]; - auto& score = bbox_result.score; - res.length = static_cast(score.size()); + auto& score = bbox_result.score; + res.length = static_cast(score.size()); - res.score = new 
float[score.size()]; - std::copy_n(score.data(), score.size(), res.score); + res.score = new float[score.size()]; + std::copy_n(score.data(), score.size(), res.score); - auto text = bbox_result.text; - res.text = new char[text.length() + 1]; - std::copy_n(text.data(), text.length() + 1, res.text); - } + auto text = bbox_result.text; + res.text = new char[text.length() + 1]; + std::copy_n(text.data(), text.length() + 1, res.text); + } - *results = _results.release(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_SUCCESS; + *results = _results.release(); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_SUCCESS; } -void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results, int count) { - for (int i = 0; i < count; ++i) { - delete[] results[i].score; - delete[] results[i].text; - } - delete[] results; +void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results, int count) +{ + for (int i = 0; i < count; ++i) + { + delete[] results[i].score; + delete[] results[i].text; + } + delete[] results; } -void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer); +void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer); } -int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* imgs, int img_count, +int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* imgs, + int img_count, const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, mmdeploy_sender_t* output) { - wrapped input_val; - if (auto ec = mmdeploy_text_recognizer_create_input(imgs, img_count, bboxes, bbox_count, - input_val.ptr())) { - return ec; - } - mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val); - if (auto ec = mmdeploy_text_recognizer_apply_async(recognizer, input_sndr, output)) { - return ec; - } - return MMDEPLOY_SUCCESS; + const int* bbox_count, + mmdeploy_sender_t* output) +{ + wrapped input_val; + if (auto ec = mmdeploy_text_recognizer_create_input(imgs, img_count, bboxes, bbox_count, input_val.ptr())) + { + return ec; + } + mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val); + if (auto ec = mmdeploy_text_recognizer_apply_async(recognizer, input_sndr, output)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, - mmdeploy_text_recognizer_continue_t cont, void* context, - mmdeploy_sender_t* output) { - auto sender = Guard([&] { - return Take( - LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender { +int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, + mmdeploy_text_recognizer_continue_t cont, + void* context, + mmdeploy_sender_t* output) +{ + auto sender = Guard([&] + { return Take( + LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender + { mmdeploy_text_recognition_t* results{}; if (auto ec = mmdeploy_text_recognizer_get_result(Cast(&value), &results)) { return Just(Value()); @@ -241,12 +287,11 @@ int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, if (auto ec = fn(results, 
context, &output); ec || !output) { return Just(Value()); } - return Take(output); - })); - }); - if (sender) { - *output = sender; - return MMDEPLOY_SUCCESS; - } - return MMDEPLOY_E_FAIL; + return Take(output); })); }); + if (sender) + { + *output = sender; + return MMDEPLOY_SUCCESS; + } + return MMDEPLOY_E_FAIL; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h index 6c18928242..f20c878028 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h @@ -13,149 +13,155 @@ #include "mmdeploy/text_detector.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_text_recognition_t { - char* text; - float* score; - int length; -} mmdeploy_text_recognition_t; - -typedef struct mmdeploy_text_recognizer* mmdeploy_text_recognizer_t; - -/** - * @brief Create a text recognizer instance - * @param[in] model an instance of mmocr text recognition model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created text recognizer, which must be destroyed - * by \ref mmdeploy_text_recognizer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, - int device_id, - mmdeploy_text_recognizer_t* recognizer); - -/** - * @brief Create a text recognizer instance - * @param[in] model_path path to text recognition model - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created text recognizer, which must be destroyed - * by \ref mmdeploy_text_recognizer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create_by_path(const char* model_path, - const char* device_name, int device_id, - mmdeploy_text_recognizer_t* recognizer); - -/** - * @brief Apply text recognizer to a batch of text images - * @param[in] recognizer text recognizer's handle created by \ref - * mmdeploy_text_recognizer_create_by_path - * @param[in] images a batch of text images - * @param[in] count number of images in the batch - * @param[out] results a linear buffer contains the recognized text, must be release - * by \ref mmdeploy_text_recognizer_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int count, - mmdeploy_text_recognition_t** results); - -/** - * @brief Apply text recognizer to a batch of images supplied with text bboxes - * @param[in] recognizer text recognizer's handle created by \ref - * mmdeploy_text_recognizer_create_by_path - * @param[in] images a batch of text images - * @param[in] image_count number of images in the batch - * @param[in] bboxes bounding boxes detected by text detector - * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images - * @param[out] results a linear buffer contains the recognized text, which has the same length as \p - * bboxes, must be release by \ref mmdeploy_text_recognizer_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int 
image_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_text_recognition_t** results); - -/** @brief Release result buffer returned by \ref mmdeploy_text_recognizer_apply or \ref - * mmdeploy_text_recognizer_apply_bbox - * @param[in] results result buffer by text recognizer - * @param[in] count length of \p result - */ -MMDEPLOY_API void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results, - int count); - -/** - * @brief destroy text recognizer - * @param[in] recognizer handle of text recognizer created by \ref - * mmdeploy_text_recognizer_create_by_path or \ref mmdeploy_text_recognizer_create - */ -MMDEPLOY_API void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_text_recognizer_create, but allows to control execution context of - * tasks via context - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_text_recognizer_t* recognizer); - -/** - * @brief Pack text-recognizer inputs into mmdeploy_value_t - * @param[in] images a batch of images - * @param[in] image_count number of images in the batch - * @param[in] bboxes bounding boxes detected by text detector - * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images - * @return value created - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images, - int image_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_value_t* output); - -MMDEPLOY_API int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer, - mmdeploy_value_t input, - mmdeploy_value_t* output); - -/** - * @brief Same as \ref mmdeploy_text_recognizer_apply_bbox, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_text_recognizer_apply_async(mmdeploy_text_recognizer_t recognizer, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -typedef int (*mmdeploy_text_recognizer_continue_t)(mmdeploy_text_recognition_t* results, - void* context, mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* imgs, int img_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, - mmdeploy_text_recognizer_continue_t cont, - void* context, mmdeploy_sender_t* output); - -/** - * @brief Unpack text-recognizer output from a mmdeploy_value_t - * @param[in] output - * @param[out] results - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output, - mmdeploy_text_recognition_t** results); + typedef struct mmdeploy_text_recognition_t + { + char* text; + float* score; + int length; + } mmdeploy_text_recognition_t; + + typedef struct mmdeploy_text_recognizer* mmdeploy_text_recognizer_t; + + /** + * @brief Create a text recognizer instance + * @param[in] model an instance of mmocr text recognition model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+     * @param[out] recognizer handle of the created text recognizer, which must be destroyed
+     * by \ref mmdeploy_text_recognizer_destroy
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_text_recognizer_t* recognizer);
+
+    /**
+     * @brief Create a text recognizer instance
+     * @param[in] model_path path to text recognition model
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] recognizer handle of the created text recognizer, which must be destroyed
+     * by \ref mmdeploy_text_recognizer_destroy
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_create_by_path(const char* model_path,
+                                                             const char* device_name,
+                                                             int device_id,
+                                                             mmdeploy_text_recognizer_t* recognizer);
+
+    /**
+     * @brief Apply text recognizer to a batch of text images
+     * @param[in] recognizer text recognizer's handle created by \ref
+     * mmdeploy_text_recognizer_create_by_path
+     * @param[in] images a batch of text images
+     * @param[in] count number of images in the batch
+     * @param[out] results a linear buffer containing the recognized text, which must be released
+     * by \ref mmdeploy_text_recognizer_release_result
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer,
+                                                    const mmdeploy_mat_t* images,
+                                                    int count,
+                                                    mmdeploy_text_recognition_t** results);
+
+    /**
+     * @brief Apply text recognizer to a batch of images supplied with text bboxes
+     * @param[in] recognizer text recognizer's handle created by \ref
+     * mmdeploy_text_recognizer_create_by_path
+     * @param[in] images a batch of text images
+     * @param[in] image_count number of images in the batch
+     * @param[in] bboxes bounding boxes detected by text detector
+     * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images
+     * @param[out] results a linear buffer containing the recognized text, which has the same length
+     * as \p bboxes and must be released by \ref mmdeploy_text_recognizer_release_result
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer,
+                                                         const mmdeploy_mat_t* images,
+                                                         int image_count,
+                                                         const mmdeploy_text_detection_t* bboxes,
+                                                         const int* bbox_count,
+                                                         mmdeploy_text_recognition_t** results);
+
+    /** @brief Release result buffer returned by \ref mmdeploy_text_recognizer_apply or \ref
+     * mmdeploy_text_recognizer_apply_bbox
+     * @param[in] results result buffer returned by the text recognizer
+     * @param[in] count length of \p results
+     */
+    MMDEPLOY_API void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results,
+                                                              int count);
+
+    /**
+     * @brief Destroy text recognizer
+     * @param[in] recognizer handle of text recognizer created by \ref
+     * mmdeploy_text_recognizer_create_by_path or \ref mmdeploy_text_recognizer_create
+     */
+    MMDEPLOY_API void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer);
+
+    /******************************************************************************
+     * Experimental asynchronous APIs */
+
+    /**
+     * @brief Same as \ref mmdeploy_text_recognizer_create, but allows controlling the execution
+     * context of tasks via \p context
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model,
+                                                        mmdeploy_context_t context,
+                                                        mmdeploy_text_recognizer_t* recognizer);
+
+    /**
+     * @brief Pack 
+     * @param[in] images a batch of images
+     * @param[in] image_count number of images in the batch
+     * @param[in] bboxes bounding boxes detected by text detector
+     * @param[in] bbox_count number of bboxes of each image in \p images; must have the same length as \p images
+     * @param[out] output the created value
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images,
+                                                           int image_count,
+                                                           const mmdeploy_text_detection_t* bboxes,
+                                                           const int* bbox_count,
+                                                           mmdeploy_value_t* output);
+
+    /**
+     * @brief Same as \ref mmdeploy_text_recognizer_apply_bbox, but input and output are packed in \ref
+     * mmdeploy_value_t.
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer,
+                                                       mmdeploy_value_t input,
+                                                       mmdeploy_value_t* output);
+
+    MMDEPLOY_API int mmdeploy_text_recognizer_apply_async(mmdeploy_text_recognizer_t recognizer,
+                                                          mmdeploy_sender_t input,
+                                                          mmdeploy_sender_t* output);
+
+    typedef int (*mmdeploy_text_recognizer_continue_t)(mmdeploy_text_recognition_t* results,
+                                                       void* context,
+                                                       mmdeploy_sender_t* output);
+
+    MMDEPLOY_API int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer,
+                                                             const mmdeploy_mat_t* imgs,
+                                                             int img_count,
+                                                             const mmdeploy_text_detection_t* bboxes,
+                                                             const int* bbox_count,
+                                                             mmdeploy_sender_t* output);
+
+    MMDEPLOY_API int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input,
+                                                             mmdeploy_text_recognizer_continue_t cont,
+                                                             void* context,
+                                                             mmdeploy_sender_t* output);
+
+    /**
+     * @brief Unpack text-recognizer output from a mmdeploy_value_t
+     * @param[in] output recognizer's inference output
+     * @param[out] results unpacked recognition results
+     * @return status of the operation
+     */
+    MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output,
+                                                         mmdeploy_text_recognition_t** results);

#ifdef __cplusplus
}
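For orientation, a minimal usage sketch of the API above — not part of the patch. The model path is a placeholder, the input images are assumed to be caller-prepared BGR mmdeploy_mat_t buffers, and error handling is reduced to a single status check:

#include <cstdio>

#include "mmdeploy/text_recognizer.h"

int RecognizeTexts(const mmdeploy_mat_t* images, int image_count)
{
    mmdeploy_text_recognizer_t recognizer{};
    auto ec = mmdeploy_text_recognizer_create_by_path("path/to/text-recognition-model",  // placeholder
                                                      "cpu", 0, &recognizer);
    if (ec != MMDEPLOY_SUCCESS)
    {
        return ec;
    }

    mmdeploy_text_recognition_t* results{};
    ec = mmdeploy_text_recognizer_apply(recognizer, images, image_count, &results);
    if (ec == MMDEPLOY_SUCCESS)
    {
        for (int i = 0; i < image_count; ++i)
        {
            // text is the recognized string; score points to per-character
            // confidences, length entries in total (as the struct layout above suggests)
            std::printf("image %d: %s\n", i, results[i].text);
        }
        mmdeploy_text_recognizer_release_result(results, image_count);
    }
    mmdeploy_text_recognizer_destroy(recognizer);
    return ec;
}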
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp
index de71e57842..3f0ab3c305 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp
+++ b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp
@@ -20,146 +20,178 @@
 using namespace mmdeploy;

-int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id,
-                                     mmdeploy_video_recognizer_t* recognizer) {
-  mmdeploy_context_t context{};
-  auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
-  if (ec != MMDEPLOY_SUCCESS) {
+int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_video_recognizer_t* recognizer)
+{
+    mmdeploy_context_t context{};
+    auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
+    if (ec != MMDEPLOY_SUCCESS)
+    {
+        return ec;
+    }
+    ec = mmdeploy_video_recognizer_create_v2(model, context, recognizer);
+    mmdeploy_context_destroy(context);
     return ec;
-  }
-  ec = mmdeploy_video_recognizer_create_v2(model, context, recognizer);
-  mmdeploy_context_destroy(context);
-  return ec;
 }

-int mmdeploy_video_recognizer_create_by_path(const char* model_path, const char* device_name,
-                                             int device_id,
-                                             mmdeploy_video_recognizer_t* recognizer) {
-  mmdeploy_model_t model{};
+int mmdeploy_video_recognizer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_video_recognizer_t* recognizer)
+{
+    mmdeploy_model_t model{};

-  if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) {
+    if (auto ec = mmdeploy_model_create_by_path(model_path, &model))
+    {
+        return ec;
+    }
+    auto ec = mmdeploy_video_recognizer_create(model, device_name, device_id, recognizer);
+    mmdeploy_model_destroy(model);
    return ec;
-  }
-  auto ec = mmdeploy_video_recognizer_create(model, device_name, device_id, recognizer);
-  mmdeploy_model_destroy(model);
-  return ec;
 }

-int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer,
-                                    const mmdeploy_mat_t* images,
-                                    const mmdeploy_video_sample_info_t* video_info, int video_count,
-                                    mmdeploy_video_recognition_t** results, int** result_count) {
-  wrapped<mmdeploy_value_t> input;
-  if (auto ec =
-          mmdeploy_video_recognizer_create_input(images, video_info, video_count, input.ptr())) {
-    return ec;
-  }
+int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer,
+                                    const mmdeploy_mat_t* images,
+                                    const mmdeploy_video_sample_info_t* video_info,
+                                    int video_count,
+                                    mmdeploy_video_recognition_t** results,
+                                    int** result_count)
+{
+    wrapped<mmdeploy_value_t> input;
+    if (auto ec =
+            mmdeploy_video_recognizer_create_input(images, video_info, video_count, input.ptr()))
+    {
+        return ec;
+    }

-  wrapped<mmdeploy_value_t> output;
-  if (auto ec = mmdeploy_video_recognizer_apply_v2(recognizer, input, output.ptr())) {
-    return ec;
-  }
+    wrapped<mmdeploy_value_t> output;
+    if (auto ec = mmdeploy_video_recognizer_apply_v2(recognizer, input, output.ptr()))
+    {
+        return ec;
+    }

-  if (auto ec = mmdeploy_video_recognizer_get_result(output, results, result_count)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+    if (auto ec = mmdeploy_video_recognizer_get_result(output, results, result_count))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

 void mmdeploy_video_recognizer_release_result(mmdeploy_video_recognition_t* results,
-                                              int* result_count, int video_count) {
-  delete[] results;
-  delete[] result_count;
+                                              int* result_count,
+                                              int video_count)
+{
+    delete[] results;
+    delete[] result_count;
 }

-void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer) {
-  mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer);
+void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer)
+{
+    mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer);
 }

-int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                        mmdeploy_video_recognizer_t* recognizer) {
-  return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)recognizer);
+int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_video_recognizer_t* recognizer)
+{
+    return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)recognizer);
 }
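From the caller's side, the sequence these functions implement reads as follows — a sketch, not part of the patch, for a single video whose clip_len * num_clips sampled frames the caller has already prepared; the model path is a placeholder:

#include <cstdio>

#include "mmdeploy/video_recognizer.h"

// `frames` must hold info.clip_len * info.num_clips sampled frames of one video.
int RecognizeVideo(const mmdeploy_mat_t* frames, const mmdeploy_video_sample_info_t& info)
{
    mmdeploy_video_recognizer_t recognizer{};
    auto ec = mmdeploy_video_recognizer_create_by_path("path/to/video-recognition-model",  // placeholder
                                                       "cpu", 0, &recognizer);
    if (ec != MMDEPLOY_SUCCESS)
    {
        return ec;
    }

    mmdeploy_video_recognition_t* results{};
    int* result_count{};
    ec = mmdeploy_video_recognizer_apply(recognizer, frames, &info, 1, &results, &result_count);
    if (ec == MMDEPLOY_SUCCESS)
    {
        // result_count[0] recognition results were produced for the single video
        for (int i = 0; i < result_count[0]; ++i)
        {
            std::printf("label %d, score %.4f\n", results[i].label_id, results[i].score);
        }
        mmdeploy_video_recognizer_release_result(results, result_count, 1);
    }
    mmdeploy_video_recognizer_destroy(recognizer);
    return ec;
}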
-int mmdeploy_video_recognizer_create_input(const mmdeploy_mat_t* images,
+int mmdeploy_video_recognizer_create_input(const mmdeploy_mat_t* images,
                                            const mmdeploy_video_sample_info_t* video_info,
-                                           int video_count, mmdeploy_value_t* value) {
-  if (video_count && (images == nullptr || video_info == nullptr)) {
-    return MMDEPLOY_E_INVALID_ARG;
-  }
-  try {
-    auto input = std::make_unique<Value>(Value{Value::kArray});
-    auto sample = std::make_unique<Value>(Value::kArray);
-    for (int i = 0; i < video_count; ++i) {
-      int clip_len = video_info[i].clip_len;
-      int num_clips = video_info[i].num_clips;
-      int n_mat = clip_len * num_clips;
-      for (int j = 0; j < n_mat; j++) {
-        mmdeploy::Mat _mat{images[j].height,
-                           images[j].width,
-                           PixelFormat(images[j].format),
-                           DataType(images[j].type),
-                           images[j].data,
-                           images[j].device ? *(const Device*)(images[j].device) : Device{0}};
-        sample->push_back({{"ori_img", _mat}, {"clip_len", clip_len}, {"num_clips", num_clips}});
-      }
-      input->front().push_back(std::move(*sample.release()));
+                                           int video_count,
+                                           mmdeploy_value_t* value)
+{
+    if (video_count && (images == nullptr || video_info == nullptr))
+    {
+        return MMDEPLOY_E_INVALID_ARG;
+    }
+    try
+    {
+        auto input = std::make_unique<Value>(Value{Value::kArray});
+        auto sample = std::make_unique<Value>(Value::kArray);
+        for (int i = 0; i < video_count; ++i)
+        {
+            int clip_len = video_info[i].clip_len;
+            int num_clips = video_info[i].num_clips;
+            int n_mat = clip_len * num_clips;
+            for (int j = 0; j < n_mat; j++)
+            {
+                mmdeploy::Mat _mat{images[j].height,
+                                   images[j].width,
+                                   PixelFormat(images[j].format),
+                                   DataType(images[j].type),
+                                   images[j].data,
+                                   images[j].device ? *(const Device*)(images[j].device) : Device{0}};
+                sample->push_back({{"ori_img", _mat}, {"clip_len", clip_len}, {"num_clips", num_clips}});
+            }
+            input->front().push_back(std::move(*sample.release()));
+        }
+        *value = Cast(input.release());
+    }
+    catch (const std::exception& e)
+    {
+        MMDEPLOY_ERROR("unhandled exception: {}", e.what());
+    }
+    catch (...)
+    {
+        MMDEPLOY_ERROR("unknown exception caught");
    }
-    *value = Cast(input.release());
-  } catch (const std::exception& e) {
-    MMDEPLOY_ERROR("unhandled exception: {}", e.what());
-  } catch (...) {
-    MMDEPLOY_ERROR("unknown exception caught");
-  }
-  return MMDEPLOY_SUCCESS;
+    return MMDEPLOY_SUCCESS;
 }

 int mmdeploy_video_recognizer_apply_v2(mmdeploy_video_recognizer_t recognizer,
-                                       mmdeploy_value_t input, mmdeploy_value_t* output) {
-  return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output);
+                                       mmdeploy_value_t input,
+                                       mmdeploy_value_t* output)
+{
+    return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output);
 }

-int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output,
+int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output,
                                          mmdeploy_video_recognition_t** results,
-                                         int** result_count) {
-  if (!output || !results || !result_count) {
-    return MMDEPLOY_E_INVALID_ARG;
-  }
-  try {
-    Value& value = Cast(output)->front();
-
-    auto classify_outputs = from_value<std::vector<mmaction::Labels>>(value);
-
-    std::vector<int> _result_count;
-    _result_count.reserve(classify_outputs.size());
-
-    for (const auto& cls_output : classify_outputs) {
-      _result_count.push_back((int)cls_output.size());
+                                         int** result_count)
+{
+    if (!output || !results || !result_count)
+    {
+        return MMDEPLOY_E_INVALID_ARG;
    }
+    try
+    {
+        Value& value = Cast(output)->front();
+
+        auto classify_outputs = from_value<std::vector<mmaction::Labels>>(value);
+
+        std::vector<int> _result_count;
+        _result_count.reserve(classify_outputs.size());
+
+        for (const auto& cls_output : classify_outputs)
+        {
+            _result_count.push_back((int)cls_output.size());
+        }
+
+        auto total = std::accumulate(begin(_result_count), end(_result_count), 0);
+
+        std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
+        std::copy(_result_count.begin(),
_result_count.end(), result_count_data.get()); + + std::unique_ptr result_data( + new mmdeploy_video_recognition_t[total]{}); + auto result_ptr = result_data.get(); + for (const auto& cls_output : classify_outputs) + { + for (const auto& label : cls_output) + { + result_ptr->label_id = label.label_id; + result_ptr->score = label.score; + ++result_ptr; + } + } + + *result_count = result_count_data.release(); + *results = result_data.release(); + + return MMDEPLOY_SUCCESS; } - - *result_count = result_count_data.release(); - *results = result_data.release(); - - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h index e98b2bd07e..6893170e7d 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h @@ -13,124 +13,129 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_video_recognition_t { - int label_id; - float score; -} mmdeploy_video_recognition_t; - -typedef struct mmdeploy_video_sample_info_t { - int clip_len; - int num_clips; -} mmdeploy_video_sample_info_t; - -typedef struct mmdeploy_video_recognizer* mmdeploy_video_recognizer_t; - -/** - * @brief Create video recognizer's handle - * @param[in] model an instance of mmaction sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created video recognizer, which must be destroyed - * by \ref mmdeploy_video_recognizer_destroy - * @return status of creating video recognizer's handle - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, - int device_id, - mmdeploy_video_recognizer_t* recognizer); - -/** - * @brief Create a video recognizer instance - * @param[in] model_path path to video recognition model - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created video recognizer, which must be destroyed - * by \ref mmdeploy_video_recognizer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create_by_path(const char* model_path, - const char* device_name, int device_id, - mmdeploy_video_recognizer_t* recognizer); - -/** - * @brief Apply video recognizer to a batch of videos - * @param[in] recognizer video recognizer's handle created by \ref - * mmdeploy_video_recognizer_create_by_path - * @param[in] images a batch of videos - * @param[in] video_info video information of each video - * @param[in] video_count number of videos - * @param[out] results a linear buffer contains the recognized video, must be release - * by \ref mmdeploy_video_recognizer_release_result - * @param[out] result_count a linear buffer with length being \p video_count to save the number of - * recognition results of each video. 
It must be released by \ref - * mmdeploy_video_recognizer_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer, - const mmdeploy_mat_t* images, - const mmdeploy_video_sample_info_t* video_info, - int video_count, - mmdeploy_video_recognition_t** results, - int** result_count); - -/** @brief Release result buffer returned by \ref mmdeploy_video_recognizer_apply - * @param[in] results result buffer by video recognizer - * @param[in] result_count \p results size buffer - * @param[in] video_count length of \p result_count - */ -MMDEPLOY_API void mmdeploy_video_recognizer_release_result(mmdeploy_video_recognition_t* results, - int* result_count, int video_count); - -/** - * @brief destroy video recognizer - * @param[in] recognizer handle of video recognizer created by \ref - * mmdeploy_video_recognizer_create_by_path or \ref mmdeploy_video_recognizer_create - */ -MMDEPLOY_API void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer); - -/** - * @brief Same as \ref mmdeploy_video_recognizer_create, but allows to control execution context of - * tasks via context - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_video_recognizer_t* recognizer); - -/** - * @brief Pack video recognizer inputs into mmdeploy_value_t - * @param[in] images a batch of videos - * @param[in] video_info video information of each video - * @param[in] video_count number of videos in the batch - * @param[out] value created value - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create_input( - const mmdeploy_mat_t* images, const mmdeploy_video_sample_info_t* video_info, int video_count, - mmdeploy_value_t* value); - -/** - * @brief Apply video recognizer to a batch of videos - * @param[in] input packed input - * @param[out] output inference output - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_apply_v2(mmdeploy_video_recognizer_t recognizer, - mmdeploy_value_t input, - mmdeploy_value_t* output); - -/** - * @brief Apply video recognizer to a batch of videos - * @param[in] output inference output - * @param[out] results structured output - * @param[out] result_count number of each videos - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output, - mmdeploy_video_recognition_t** results, - int** result_count); + typedef struct mmdeploy_video_recognition_t + { + int label_id; + float score; + } mmdeploy_video_recognition_t; + + typedef struct mmdeploy_video_sample_info_t + { + int clip_len; + int num_clips; + } mmdeploy_video_sample_info_t; + + typedef struct mmdeploy_video_recognizer* mmdeploy_video_recognizer_t; + + /** + * @brief Create video recognizer's handle + * @param[in] model an instance of mmaction sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+     * @param[out] recognizer handle of the created video recognizer, which must be destroyed
+     *             by \ref mmdeploy_video_recognizer_destroy
+     * @return status of creating video recognizer's handle
+     */
+    MMDEPLOY_API int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_video_recognizer_t* recognizer);
+
+    /**
+     * @brief Create a video recognizer instance
+     * @param[in] model_path path to video recognition model
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] recognizer handle of the created video recognizer, which must be destroyed
+     *             by \ref mmdeploy_video_recognizer_destroy
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_video_recognizer_create_by_path(const char* model_path,
+                                                              const char* device_name,
+                                                              int device_id,
+                                                              mmdeploy_video_recognizer_t* recognizer);
+
+    /**
+     * @brief Apply video recognizer to a batch of videos
+     * @param[in] recognizer video recognizer's handle created by \ref
+     * mmdeploy_video_recognizer_create_by_path
+     * @param[in] images frames of a batch of videos
+     * @param[in] video_info video information of each video
+     * @param[in] video_count number of videos
+     * @param[out] results a linear buffer containing the recognition results, which must be released
+     * by \ref mmdeploy_video_recognizer_release_result
+     * @param[out] result_count a linear buffer with length being \p video_count that saves the number of
+     * recognition results of each video. It must be released by \ref
+     * mmdeploy_video_recognizer_release_result
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer,
+                                                     const mmdeploy_mat_t* images,
+                                                     const mmdeploy_video_sample_info_t* video_info,
+                                                     int video_count,
+                                                     mmdeploy_video_recognition_t** results,
+                                                     int** result_count);
+
+    /** @brief Release result buffer returned by \ref mmdeploy_video_recognizer_apply
+     * @param[in] results result buffer returned by video recognizer
+     * @param[in] result_count buffer holding the size of \p results per video
+     * @param[in] video_count length of \p result_count
+     */
+    MMDEPLOY_API void mmdeploy_video_recognizer_release_result(mmdeploy_video_recognition_t* results,
+                                                               int* result_count,
+                                                               int video_count);
+
+    /**
+     * @brief Destroy video recognizer
+     * @param[in] recognizer handle of video recognizer created by \ref
+     * mmdeploy_video_recognizer_create_by_path or \ref mmdeploy_video_recognizer_create
+     */
+    MMDEPLOY_API void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer);
+
+    /**
+     * @brief Same as \ref mmdeploy_video_recognizer_create, but allows controlling the execution
+     * context of tasks via \p context
+     */
+    MMDEPLOY_API int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model,
+                                                         mmdeploy_context_t context,
+                                                         mmdeploy_video_recognizer_t* recognizer);
+
+    /**
+     * @brief Pack video recognizer inputs into mmdeploy_value_t
+     * @param[in] images frames of a batch of videos
+     * @param[in] video_info video information of each video
+     * @param[in] video_count number of videos in the batch
+     * @param[out] value created value
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_video_recognizer_create_input(
+        const mmdeploy_mat_t* images,
+        const mmdeploy_video_sample_info_t* video_info,
+        int video_count,
+        mmdeploy_value_t* value);
+
+    /**
+     * @brief Apply video recognizer to the packed input
+     * @param[in] recognizer video recognizer's handle
+     * @param[in] input packed input
+     * @param[out] output inference output
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_video_recognizer_apply_v2(mmdeploy_video_recognizer_t recognizer,
+                                                        mmdeploy_value_t input,
+                                                        mmdeploy_value_t* output);
+
+    /**
+     * @brief Unpack video recognizer output from a mmdeploy_value_t
+     * @param[in] output inference output
+     * @param[out] results structured recognition results
+     * @param[out] result_count number of recognition results of each video
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output,
+                                                          mmdeploy_video_recognition_t** results,
+                                                          int** result_count);

#ifdef __cplusplus
}

diff --git a/csrc/mmdeploy/apis/cxx/CMakeLists.txt b/csrc/mmdeploy/apis/cxx/CMakeLists.txt
index 0ee897ca4d..9073665516 100644
--- a/csrc/mmdeploy/apis/cxx/CMakeLists.txt
+++ b/csrc/mmdeploy/apis/cxx/CMakeLists.txt
@@ -4,41 +4,44 @@
 cmake_minimum_required(VERSION 3.14)
 project(mmdeploy_cxx_api)

 add_library(${PROJECT_NAME} INTERFACE)
-target_include_directories(${PROJECT_NAME} INTERFACE
-    $
-    $)
+target_include_directories(
+  ${PROJECT_NAME} INTERFACE $
+                            $)
 target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17)

 set(_tasks ${MMDEPLOY_TASKS} pipeline)
-foreach (task ${_tasks})
-  target_link_libraries(mmdeploy_${task} INTERFACE ${PROJECT_NAME})
-  install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${task}.hpp
-          DESTINATION include/mmdeploy)
-endforeach ()
+foreach(task ${_tasks})
+  target_link_libraries(mmdeploy_${task} INTERFACE ${PROJECT_NAME})
+  install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/${task}.hpp
+          DESTINATION include/mmdeploy)
+endforeach()
-if (TARGET mmdeploy)
-  target_include_directories(${PROJECT_NAME} INTERFACE
-      $
-      $
-      $
-      )
-  target_include_directories(${PROJECT_NAME} INTERFACE
-      $
-      $
-      $
-      )
-  if (NOT MMDEPLOY_SPDLOG_EXTERNAL)
-    target_include_directories(${PROJECT_NAME} INTERFACE
-        $
-        $)
-  endif ()
-  target_link_libraries(mmdeploy INTERFACE ${PROJECT_NAME})
-else ()
-  target_link_libraries(${PROJECT_NAME} INTERFACE mmdeploy::core)
-endif ()
+if(TARGET mmdeploy)
+  target_include_directories(
+    ${PROJECT_NAME}
+    INTERFACE $
+              $
+              $)
+  target_include_directories(
+    ${PROJECT_NAME}
+    INTERFACE $
+              $
+              $)
+  if(NOT MMDEPLOY_SPDLOG_EXTERNAL)
+    target_include_directories(
+      ${PROJECT_NAME}
+      INTERFACE
+        $
+        $)
+  endif()
+  target_link_libraries(mmdeploy INTERFACE ${PROJECT_NAME})
+else()
+  target_link_libraries(${PROJECT_NAME} INTERFACE mmdeploy::core)
+endif()

 mmdeploy_export_impl(${PROJECT_NAME})

 install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/common.hpp DESTINATION include/mmdeploy)
-install(DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/ DESTINATION example/cpp
-        FILES_MATCHING
-        PATTERN "*.cxx"
-        PATTERN "*.h"
-        )
+install(
+  DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/
+  DESTINATION example/cpp
+  FILES_MATCHING
+  PATTERN "*.cxx"
+  PATTERN "*.h")
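Before the per-task C++ wrappers that follow, a quick sense of how this convenience API reads in application code — a sketch, not part of the patch. The model path and test image are placeholders, and mmdeploy_classification_t is assumed to expose label_id and score as in the C API:

#include <cstdio>

#include "mmdeploy/classifier.hpp"
#include "opencv2/imgcodecs/imgcodecs.hpp"

int main()
{
    cv::Mat img = cv::imread("demo.jpg");                    // placeholder test image
    mmdeploy::Model model("path/to/classification-model");   // placeholder model path
    // cv::Mat converts implicitly to mmdeploy::Mat; Context converts from Device
    mmdeploy::Classifier classifier(model, mmdeploy::Context(mmdeploy::Device("cpu")));
    for (const auto& cls : classifier.Apply(img))  // Result_ is iterable
    {
        std::printf("label: %d, score: %.4f\n", cls.label_id, cls.score);
    }
    return 0;
}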
diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp
index 1d9880fb7d..5ba395ad77 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp
@@ -6,68 +6,87 @@
 #include "mmdeploy/classifier.h"
 #include "mmdeploy/common.hpp"

-namespace mmdeploy {
-
-namespace cxx {
-
-using Classification = mmdeploy_classification_t;
-
-class Classifier : public NonMovable {
- public:
-  Classifier(const Model& model, const Context& context) {
-    auto ec = mmdeploy_classifier_create_v2(model, context, &classifier_);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-  }
-
-  ~Classifier() {
-    if (classifier_) {
-      mmdeploy_classifier_destroy(classifier_);
-      classifier_ = {};
-    }
-  }
-
-  using Result = Result_<Classification>;
-
-  std::vector<Result> Apply(Span<const Mat> images) {
-    if (images.empty()) {
-      return {};
-    }
-
-    Classification* results{};
-    int* result_count{};
-    auto ec = mmdeploy_classifier_apply(classifier_, reinterpret(images.data()),
-                                        static_cast<int>(images.size()), &results, &result_count);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-
-    std::vector<Result> rets;
-    rets.reserve(images.size());
-
-    std::shared_ptr<Classification> data(results, [result_count, count = images.size()](auto p) {
-      mmdeploy_classifier_release_result(p, result_count, count);
-    });
-
-    size_t offset = 0;
-    for (size_t i = 0; i < images.size(); ++i) {
-      offset += rets.emplace_back(offset, result_count[i], data).size();
-    }
-
-    return rets;
-  }
-
-  Result Apply(const Mat& img) { return Apply(Span{img})[0]; }
-
- private:
-  mmdeploy_classifier_t classifier_{};
-};
-
-}  // namespace cxx
-
-using cxx::Classification;
-using cxx::Classifier;
+namespace mmdeploy
+{
+
+    namespace cxx
+    {
+
+        using Classification = mmdeploy_classification_t;
+
+        class Classifier : public NonMovable
+        {
+          public:
+            Classifier(const Model& model, const Context& context)
+            {
+                auto ec = mmdeploy_classifier_create_v2(model, context, &classifier_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+
+            ~Classifier()
+            {
+                if (classifier_)
+                {
+                    mmdeploy_classifier_destroy(classifier_);
+                    classifier_ = {};
+                }
+            }
+
+            using Result = Result_<Classification>;
+
+            std::vector<Result> Apply(Span<const Mat> images)
+            {
+                if (images.empty())
+                {
+                    return {};
+                }
+
+                Classification* results{};
+                int* result_count{};
+                auto ec = mmdeploy_classifier_apply(classifier_,
+                                                    reinterpret(images.data()),
+                                                    static_cast<int>(images.size()),
+                                                    &results,
+                                                    &result_count);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+
+                std::vector<Result> rets;
+                rets.reserve(images.size());
+
+                std::shared_ptr<Classification> data(results,
+                                                     [result_count, count = images.size()](auto p)
+                                                     {
+                                                         mmdeploy_classifier_release_result(p, result_count, count);
+                                                     });
+
+                size_t offset = 0;
+                for (size_t i = 0; i < images.size(); ++i)
+                {
+                    offset += rets.emplace_back(offset, result_count[i], data).size();
+                }
+
+                return rets;
+            }
+
+            Result Apply(const Mat& img)
+            {
+                return Apply(Span{img})[0];
+            }
+
+          private:
+            mmdeploy_classifier_t classifier_{};
+        };
+
+    }  // namespace cxx
+
+    using cxx::Classification;
+    using cxx::Classifier;

 }  // namespace mmdeploy

diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp
index 610c3a8b9e..07b6b225b2 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp
@@ -16,253 +16,432 @@
 #include "mmdeploy/model.h"

 #ifndef MMDEPLOY_CXX_USE_OPENCV
-#define MMDEPLOY_CXX_USE_OPENCV 1
+  #define MMDEPLOY_CXX_USE_OPENCV 1
 #endif

 #if MMDEPLOY_CXX_USE_OPENCV
-#include "opencv2/core/core.hpp"
+  #include "opencv2/core/core.hpp"
 #endif

-namespace mmdeploy {
-
-namespace cxx {
-
-using Rect = mmdeploy_rect_t;
-
-template <typename T>
-class UniqueHandle : public NonCopyable {
- public:
-  UniqueHandle() = default;
-  explicit UniqueHandle(T handle) : handle_(handle) {}
-
-  // derived class must destroy the object and reset `handle_`
-  ~UniqueHandle() { assert(handle_ == nullptr); }
-
-  UniqueHandle(UniqueHandle&& o) noexcept : handle_(std::exchange(o.handle_, nullptr)) {}
-  UniqueHandle& operator=(UniqueHandle&& o) noexcept {
-    if (this != &o) {
-      handle_ = std::exchange(o.handle_, nullptr);
-    }
-    return *this;
-  }
-
-  explicit operator T() const noexcept { return
handle_; } - T operator->() const noexcept { return handle_; } - - protected: - T handle_{}; -}; - -class Model { - public: - explicit Model(const char* path) { - mmdeploy_model_t model{}; - auto ec = mmdeploy_model_create_by_path(path, &model); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - model_.reset(model, [](auto p) { mmdeploy_model_destroy(p); }); - } - - explicit Model(const std::string& path) : Model(path.c_str()) {} - - Model(const void* buffer, size_t size) { - mmdeploy_model_t model{}; - auto ec = mmdeploy_model_create(buffer, static_cast(size), &model); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - model_.reset(model, [](auto p) { mmdeploy_model_destroy(p); }); - } - - operator mmdeploy_model_t() const noexcept { return model_.get(); } - - private: - std::shared_ptr model_{}; -}; - -class Device { - public: - explicit Device(std::string name, int index = 0) : name_(std::move(name)), index_(index) { - mmdeploy_device_t device{}; - auto ec = mmdeploy_device_create(name_.c_str(), index, &device); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - device_.reset(device, [](auto p) { mmdeploy_device_destroy(p); }); - } - - const char* name() const noexcept { return name_.c_str(); } - int index() const noexcept { return index_; } - - operator mmdeploy_device_t() const noexcept { return device_.get(); } - - private: - std::string name_; - int index_; - std::shared_ptr device_; -}; - -class Profiler { - public: - explicit Profiler(std::string_view path) : path_(path) { - mmdeploy_profiler_t profiler{}; - auto ec = mmdeploy_profiler_create(path_.c_str(), &profiler); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - profiler_.reset(profiler, [](auto p) { mmdeploy_profiler_destroy(p); }); - }; - - operator mmdeploy_profiler_t() const noexcept { return profiler_.get(); } - - private: - std::string path_; - std::shared_ptr profiler_; -}; - -class Mat { - public: - Mat() : desc_{} {} - - Mat(int height, int width, int channels, mmdeploy_pixel_format_t format, - mmdeploy_data_type_t type, uint8_t* data, mmdeploy_device_t device = nullptr) - : desc_{data, height, width, channels, format, type, device} {} - - Mat(const mmdeploy_mat_t& desc) : desc_(desc) {} // NOLINT - - const mmdeploy_mat_t& desc() const noexcept { return desc_; } +namespace mmdeploy +{ + + namespace cxx + { + + using Rect = mmdeploy_rect_t; + + template + class UniqueHandle : public NonCopyable + { + public: + UniqueHandle() = default; + explicit UniqueHandle(T handle) + : handle_(handle) + { + } + + // derived class must destroy the object and reset `handle_` + ~UniqueHandle() + { + assert(handle_ == nullptr); + } + + UniqueHandle(UniqueHandle&& o) noexcept + : handle_(std::exchange(o.handle_, nullptr)) + { + } + + UniqueHandle& operator=(UniqueHandle&& o) noexcept + { + if (this != &o) + { + handle_ = std::exchange(o.handle_, nullptr); + } + return *this; + } + + explicit operator T() const noexcept + { + return handle_; + } + + T operator->() const noexcept + { + return handle_; + } + + protected: + T handle_{}; + }; + + class Model + { + public: + explicit Model(const char* path) + { + mmdeploy_model_t model{}; + auto ec = mmdeploy_model_create_by_path(path, &model); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + model_.reset(model, + [](auto p) + { + mmdeploy_model_destroy(p); + }); + } + + explicit Model(const std::string& path) + : Model(path.c_str()) + { + } + + Model(const void* buffer, 
size_t size) + { + mmdeploy_model_t model{}; + auto ec = mmdeploy_model_create(buffer, + static_cast(size), + &model); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + model_.reset(model, + [](auto p) + { + mmdeploy_model_destroy(p); + }); + } + + operator mmdeploy_model_t() const noexcept + { + return model_.get(); + } + + private: + std::shared_ptr model_{}; + }; + + class Device + { + public: + explicit Device(std::string name, int index = 0) + : name_(std::move(name)) + , index_(index) + { + mmdeploy_device_t device{}; + auto ec = mmdeploy_device_create(name_.c_str(), + index, + &device); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + device_.reset(device, + [](auto p) + { + mmdeploy_device_destroy(p); + }); + } + + const char* name() const noexcept + { + return name_.c_str(); + } + + int index() const noexcept + { + return index_; + } + + operator mmdeploy_device_t() const noexcept + { + return device_.get(); + } + + private: + std::string name_; + int index_; + std::shared_ptr device_; + }; + + class Profiler + { + public: + explicit Profiler(std::string_view path) + : path_(path) + { + mmdeploy_profiler_t profiler{}; + auto ec = mmdeploy_profiler_create(path_.c_str(), &profiler); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + profiler_.reset(profiler, + [](auto p) + { + mmdeploy_profiler_destroy(p); + }); + }; + + operator mmdeploy_profiler_t() const noexcept + { + return profiler_.get(); + } + + private: + std::string path_; + std::shared_ptr profiler_; + }; + + class Mat + { + public: + Mat() + : desc_{} + { + } + + Mat(int height, + int width, + int channels, + mmdeploy_pixel_format_t format, + mmdeploy_data_type_t type, + uint8_t* data, + mmdeploy_device_t device = nullptr) + : desc_{data, + height, + width, + channels, + format, + type, + device} + { + } + + Mat(const mmdeploy_mat_t& desc) + : desc_(desc) + { + } // NOLINT + + const mmdeploy_mat_t& desc() const noexcept + { + return desc_; + } #if MMDEPLOY_CXX_USE_OPENCV - Mat(const cv::Mat& mat, mmdeploy_pixel_format_t pixel_format) - : desc_{mat.data, mat.rows, mat.cols, mat.channels(), pixel_format, GetCvType(mat.depth())} { - if (pixel_format == MMDEPLOY_PIXEL_FORMAT_COUNT) { - throw_exception(eNotSupported); - } - if (desc_.type == MMDEPLOY_DATA_TYPE_COUNT) { - throw_exception(eNotSupported); - } - } - Mat(const cv::Mat& mat) : Mat(mat, GetCvFormat(mat.channels())) {} - - static mmdeploy_data_type_t GetCvType(int depth) { - switch (depth) { - case CV_8U: - return MMDEPLOY_DATA_TYPE_UINT8; - case CV_32F: - return MMDEPLOY_DATA_TYPE_FLOAT; - default: - return MMDEPLOY_DATA_TYPE_COUNT; - } - } - static mmdeploy_pixel_format_t GetCvFormat(int channels) { - switch (channels) { - case 1: - return MMDEPLOY_PIXEL_FORMAT_GRAYSCALE; - case 3: - return MMDEPLOY_PIXEL_FORMAT_BGR; - case 4: - return MMDEPLOY_PIXEL_FORMAT_BGRA; - default: - return MMDEPLOY_PIXEL_FORMAT_COUNT; - } - } + Mat(const cv::Mat& mat, mmdeploy_pixel_format_t pixel_format) + : desc_{mat.data, + mat.rows, + mat.cols, + mat.channels(), + pixel_format, + GetCvType(mat.depth())} + { + if (pixel_format == MMDEPLOY_PIXEL_FORMAT_COUNT) + { + throw_exception(eNotSupported); + } + + if (desc_.type == MMDEPLOY_DATA_TYPE_COUNT) + { + throw_exception(eNotSupported); + } + } + + Mat(const cv::Mat& mat) + : Mat(mat, GetCvFormat(mat.channels())) + { + } + + static mmdeploy_data_type_t GetCvType(int depth) + { + switch (depth) + { + case CV_8U: + return MMDEPLOY_DATA_TYPE_UINT8; + case 
CV_32F: + return MMDEPLOY_DATA_TYPE_FLOAT; + default: + return MMDEPLOY_DATA_TYPE_COUNT; + } + } + + static mmdeploy_pixel_format_t GetCvFormat(int channels) + { + switch (channels) + { + case 1: + return MMDEPLOY_PIXEL_FORMAT_GRAYSCALE; + case 3: + return MMDEPLOY_PIXEL_FORMAT_BGR; + case 4: + return MMDEPLOY_PIXEL_FORMAT_BGRA; + default: + return MMDEPLOY_PIXEL_FORMAT_COUNT; + } + } #endif - private: - mmdeploy_mat_t desc_; -}; - -template -class Result_ { - public: - using value_type = T; - using size_type = size_t; - using difference_type = ptrdiff_t; - using reference = T&; - using const_reference = const T&; - using pointer = T*; - using const_pointer = const T*; - using iterator = T*; - using const_iterator = T*; - - Result_(size_t offset, size_t size, std::shared_ptr data) - : offset_(offset), size_(size), data_(std::move(data)) {} - - T& operator[](size_t index) const noexcept { return *(data_.get() + offset_ + index); } - size_t size() const noexcept { return size_; } - T* begin() const noexcept { return data_.get() + offset_; } - T* end() const noexcept { return begin() + size_; } - - T* operator->() const noexcept { return data_.get(); } - T& operator*() const noexcept { return *data_; } - - private: - size_t offset_; - size_t size_; - std::shared_ptr data_; -}; - -inline const mmdeploy_mat_t* reinterpret(const Mat* p) { - return reinterpret_cast(p); -} - -class Scheduler { - public: - explicit Scheduler(mmdeploy_scheduler_t scheduler) { - scheduler_.reset(scheduler, [](auto p) { mmdeploy_scheduler_destroy(p); }); - } - - static Scheduler ThreadPool(int num_threads) { - return Scheduler(mmdeploy_executor_create_thread_pool(num_threads)); - } - static Scheduler Thread() { return Scheduler(mmdeploy_executor_create_thread()); } - - operator mmdeploy_scheduler_t() const noexcept { return scheduler_.get(); } - - private: - std::shared_ptr scheduler_; -}; - -class Context { - public: - Context() { - mmdeploy_context_t context{}; - mmdeploy_context_create(&context); - context_.reset(context, [](auto p) { mmdeploy_context_destroy(p); }); - } - /* implicit */ Context(const Device& device) : Context() { Add(device); } - - void Add(const std::string& name, const Scheduler& scheduler) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_SCHEDULER, name.c_str(), scheduler); - } - - void Add(const std::string& name, const Model& model) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_MODEL, name.c_str(), model); - } - - void Add(const Device& device) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_DEVICE, nullptr, device); - } - - void Add(const Profiler& profiler) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_PROFILER, nullptr, profiler); - } - - operator mmdeploy_context_t() const noexcept { return context_.get(); } - - private: - std::shared_ptr context_; -}; - -} // namespace cxx - -using cxx::Context; -using cxx::Device; -using cxx::Mat; -using cxx::Model; -using cxx::Profiler; -using cxx::Rect; -using cxx::Scheduler; + private: + mmdeploy_mat_t desc_; + }; + + template + class Result_ + { + public: + using value_type = T; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = T&; + using const_reference = const T&; + using pointer = T*; + using const_pointer = const T*; + using iterator = T*; + using const_iterator = T*; + + Result_(size_t offset, size_t size, std::shared_ptr data) + : offset_(offset) + , size_(size) + , data_(std::move(data)) + { + } + + T& operator[](size_t index) const noexcept + { + return *(data_.get() + offset_ + index); + } + + size_t size() 
const noexcept + { + return size_; + } + + T* begin() const noexcept + { + return data_.get() + offset_; + } + + T* end() const noexcept + { + return begin() + size_; + } + + T* operator->() const noexcept + { + return data_.get(); + } + + T& operator*() const noexcept + { + return *data_; + } + + private: + size_t offset_; + size_t size_; + std::shared_ptr data_; + }; + + inline const mmdeploy_mat_t* reinterpret(const Mat* p) + { + return reinterpret_cast(p); + } + + class Scheduler + { + public: + explicit Scheduler(mmdeploy_scheduler_t scheduler) + { + scheduler_.reset(scheduler, + [](auto p) + { + mmdeploy_scheduler_destroy(p); + }); + } + + static Scheduler ThreadPool(int num_threads) + { + return Scheduler(mmdeploy_executor_create_thread_pool(num_threads)); + } + + static Scheduler Thread() + { + return Scheduler(mmdeploy_executor_create_thread()); + } + + operator mmdeploy_scheduler_t() const noexcept + { + return scheduler_.get(); + } + + private: + std::shared_ptr scheduler_; + }; + + class Context + { + public: + Context() + { + mmdeploy_context_t context{}; + mmdeploy_context_create(&context); + context_.reset(context, + [](auto p) + { + mmdeploy_context_destroy(p); + }); + } + + /* implicit */ Context(const Device& device) + : Context() + { + Add(device); + } + + void Add(const std::string& name, const Scheduler& scheduler) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_SCHEDULER, name.c_str(), scheduler); + } + + void Add(const std::string& name, const Model& model) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_MODEL, name.c_str(), model); + } + + void Add(const Device& device) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_DEVICE, nullptr, device); + } + + void Add(const Profiler& profiler) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_PROFILER, nullptr, profiler); + } + + operator mmdeploy_context_t() const noexcept + { + return context_.get(); + } + + private: + std::shared_ptr context_; + }; + + } // namespace cxx + + using cxx::Context; + using cxx::Device; + using cxx::Mat; + using cxx::Model; + using cxx::Profiler; + using cxx::Rect; + using cxx::Scheduler; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp index 847505bbe7..31874fa9f9 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp @@ -6,68 +6,87 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/detector.h" -namespace mmdeploy { - -namespace cxx { - -using Detection = mmdeploy_detection_t; - -class Detector : public NonMovable { - public: - Detector(const Model& model, const Context& context) { - auto ec = mmdeploy_detector_create_v2(model, context, &detector_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~Detector() { - if (detector_) { - mmdeploy_detector_destroy(detector_); - detector_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images) { - if (images.empty()) { - return {}; - } - - Detection* results{}; - int* result_count{}; - auto ec = mmdeploy_detector_apply(detector_, reinterpret(images.data()), - static_cast(images.size()), &results, &result_count); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::shared_ptr data(results, [result_count, count = images.size()](auto p) { - mmdeploy_detector_release_result(p, result_count, count); - }); - - std::vector rets; - rets.reserve(images.size()); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += 
rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& image) { return Apply(Span{image})[0]; } - - private: - mmdeploy_detector_t detector_{}; -}; - -} // namespace cxx - -using cxx::Detection; -using cxx::Detector; +namespace mmdeploy +{ + + namespace cxx + { + + using Detection = mmdeploy_detection_t; + + class Detector : public NonMovable + { + public: + Detector(const Model& model, const Context& context) + { + auto ec = mmdeploy_detector_create_v2(model, context, &detector_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~Detector() + { + if (detector_) + { + mmdeploy_detector_destroy(detector_); + detector_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images) + { + if (images.empty()) + { + return {}; + } + + Detection* results{}; + int* result_count{}; + auto ec = mmdeploy_detector_apply(detector_, + reinterpret(images.data()), + static_cast(images.size()), + &results, + &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data(results, + [result_count, count = images.size()](auto p) + { + mmdeploy_detector_release_result(p, result_count, count); + }); + + std::vector rets; + rets.reserve(images.size()); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& image) + { + return Apply(Span{image})[0]; + } + + private: + mmdeploy_detector_t detector_{}; + }; + + } // namespace cxx + + using cxx::Detection; + using cxx::Detector; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp index e20ec6a224..9380236f8c 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp @@ -7,72 +7,91 @@ #include "mmdeploy/core/value.h" #include "mmdeploy/pipeline.h" -namespace mmdeploy { +namespace mmdeploy +{ -namespace cxx { + namespace cxx + { -class Pipeline : public NonMovable { - public: - Pipeline(const Value& config, const Context& context) { - mmdeploy_pipeline_t pipeline{}; - auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &pipeline); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - pipeline_ = pipeline; - } + class Pipeline : public NonMovable + { + public: + Pipeline(const Value& config, const Context& context) + { + mmdeploy_pipeline_t pipeline{}; + auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, + context, + &pipeline); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + pipeline_ = pipeline; + } - ~Pipeline() { - if (pipeline_) { - mmdeploy_pipeline_destroy(pipeline_); - pipeline_ = nullptr; - } - } + ~Pipeline() + { + if (pipeline_) + { + mmdeploy_pipeline_destroy(pipeline_); + pipeline_ = nullptr; + } + } - Value Apply(const Value& inputs) { - mmdeploy_value_t tmp{}; - auto ec = mmdeploy_pipeline_apply(pipeline_, (mmdeploy_value_t)&inputs, &tmp); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - Value output = std::move(*(Value*)tmp); - mmdeploy_value_destroy(tmp); - return output; - } + Value Apply(const Value& inputs) + { + mmdeploy_value_t tmp{}; + auto ec = mmdeploy_pipeline_apply(pipeline_, + (mmdeploy_value_t)&inputs, + &tmp); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + Value output = 
std::move(*(Value*)tmp); + mmdeploy_value_destroy(tmp); + return output; + } - Value Apply(Span images) { - if (images.empty()) { - return {}; - } - mmdeploy_value_t inputs{}; - auto ec = mmdeploy_common_create_input(reinterpret(images.data()), - static_cast(images.size()), &inputs); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - auto outputs = Apply(*reinterpret_cast(inputs)); - mmdeploy_value_destroy(inputs); + Value Apply(Span images) + { + if (images.empty()) + { + return {}; + } + mmdeploy_value_t inputs{}; + auto ec = mmdeploy_common_create_input(reinterpret(images.data()), + static_cast(images.size()), + &inputs); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + auto outputs = Apply(*reinterpret_cast(inputs)); + mmdeploy_value_destroy(inputs); - return outputs; - } + return outputs; + } - Value Apply(const Mat& image) { - auto outputs = Apply(Span{image}); - Value::Array rets; - rets.reserve(outputs.size()); - for (auto& output : outputs) { - rets.push_back(std::move(output[0])); - } - return rets; - } + Value Apply(const Mat& image) + { + auto outputs = Apply(Span{image}); + Value::Array rets; + rets.reserve(outputs.size()); + for (auto& output : outputs) + { + rets.push_back(std::move(output[0])); + } + return rets; + } - private: - mmdeploy_pipeline_t pipeline_{}; -}; + private: + mmdeploy_pipeline_t pipeline_{}; + }; -} // namespace cxx + } // namespace cxx -using cxx::Pipeline; + using cxx::Pipeline; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp index 7432a417fc..34ef2d2221 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp @@ -6,79 +6,91 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/pose_detector.h" -namespace mmdeploy { - -namespace cxx { - -using PoseDetection = mmdeploy_pose_detection_t; - -class PoseDetector : public NonMovable { - public: - PoseDetector(const Model& model, const Context& context) { - auto ec = mmdeploy_pose_detector_create_v2(model, context, &detector_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~PoseDetector() { - if (detector_) { - mmdeploy_pose_detector_destroy(detector_); - detector_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images, Span bboxes, - Span bbox_count) { - if (images.empty()) { - return {}; - } - - const mmdeploy_rect_t* p_bboxes{}; - const int* p_bbox_count{}; - - if (!bboxes.empty()) { - p_bboxes = bboxes.data(); - p_bbox_count = bbox_count.data(); - } - - PoseDetection* results{}; - auto ec = mmdeploy_pose_detector_apply_bbox(detector_, reinterpret(images.data()), - static_cast(images.size()), p_bboxes, - p_bbox_count, &results); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::shared_ptr data(results, [count = images.size()](auto p) { - mmdeploy_pose_detector_release_result(p, count); - }); - - std::vector rets; - rets.reserve(images.size()); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += rets.emplace_back(offset, bboxes.empty() ? 
1 : bbox_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& image, Span bboxes = {}) { - return Apply(Span{image}, bboxes, {static_cast(bboxes.size())})[0]; - } - - private: - mmdeploy_pose_detector_t detector_{}; -}; - -} // namespace cxx - -using cxx::PoseDetection; -using cxx::PoseDetector; +namespace mmdeploy +{ + + namespace cxx + { + + using PoseDetection = mmdeploy_pose_detection_t; + + class PoseDetector : public NonMovable + { + public: + PoseDetector(const Model& model, const Context& context) + { + auto ec = mmdeploy_pose_detector_create_v2(model, context, &detector_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~PoseDetector() + { + if (detector_) + { + mmdeploy_pose_detector_destroy(detector_); + detector_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images, Span bboxes, Span bbox_count) + { + if (images.empty()) + { + return {}; + } + + const mmdeploy_rect_t* p_bboxes{}; + const int* p_bbox_count{}; + + if (!bboxes.empty()) + { + p_bboxes = bboxes.data(); + p_bbox_count = bbox_count.data(); + } + + PoseDetection* results{}; + auto ec = mmdeploy_pose_detector_apply_bbox(detector_, reinterpret(images.data()), static_cast(images.size()), p_bboxes, p_bbox_count, &results); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data(results, + [count = images.size()](auto p) + { + mmdeploy_pose_detector_release_result(p, count); + }); + + std::vector rets; + rets.reserve(images.size()); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, bboxes.empty() ? 1 : bbox_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& image, Span bboxes = {}) + { + return Apply(Span{image}, bboxes, {static_cast(bboxes.size())})[0]; + } + + private: + mmdeploy_pose_detector_t detector_{}; + }; + + } // namespace cxx + + using cxx::PoseDetection; + using cxx::PoseDetector; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp index 077ec75700..e1e330ce05 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp @@ -6,145 +6,171 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/pose_tracker.h" -namespace mmdeploy { - -namespace cxx { - -class PoseTracker : public UniqueHandle { - public: - using Result = Result_; - class State; - class Params; - - public: - /** - * @brief Create pose tracker pipeline - * @param detect object detection model - * @param pose pose estimation model - * @param context execution context - */ - PoseTracker(const Model& detect, const Model& pose, const Context& context) { - auto ec = mmdeploy_pose_tracker_create(detect, pose, context, &handle_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - ~PoseTracker() { - if (handle_) { - mmdeploy_pose_tracker_destroy(handle_); - handle_ = {}; - } - } - PoseTracker(PoseTracker&&) noexcept = default; - - /** - * @brief Create a tracker state corresponds to a video stream - * @param params params for creating the tracker state - * @return created tracker state - */ - State CreateState(const Params& params); - - /** - * @brief Apply pose tracker pipeline - * @param state tracker state - * @param frame input video frame - * @param detect control the use of detector - * -1: use params.det_interval, 0: don't use detector, 1: force use detector - * @return - */ - 
Result Apply(State& state, const Mat& frame, int detect = -1); - - /** - * @brief batched version of Apply - * @param states - * @param frames - * @param detects - * @return - */ - std::vector Apply(const Span& states, const Span& frames, - const Span& detects = {}); - - public: - /** - * see \ref mmdeploy/pose_tracker.h for detail - */ - class Params : public UniqueHandle { - public: - explicit Params() { - handle_ = new mmdeploy_pose_tracker_param_t{}; - mmdeploy_pose_tracker_default_params(handle_); - } - ~Params() { - if (handle_) { - delete handle_; - handle_ = {}; - } - } - }; - - class State : public UniqueHandle { - public: - explicit State(mmdeploy_pose_tracker_t pipeline, const mmdeploy_pose_tracker_param_t* params) { - auto ec = mmdeploy_pose_tracker_create_state(pipeline, params, &handle_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - ~State() { - if (handle_) { - mmdeploy_pose_tracker_destroy_state(handle_); - handle_ = {}; - } - } - State(State&&) noexcept = default; - }; -}; - -inline PoseTracker::State PoseTracker::CreateState(const PoseTracker::Params& params) { - return State(handle_, static_cast(params)); -} - -inline std::vector PoseTracker::Apply(const Span& states, - const Span& frames, - const Span& detects) { - if (frames.empty()) { - return {}; - } - mmdeploy_pose_tracker_target_t* results{}; - int32_t* result_count{}; - - auto ec = mmdeploy_pose_tracker_apply( - handle_, reinterpret_cast(states.data()), - reinterpret(frames.data()), detects.data(), static_cast(frames.size()), &results, - &result_count); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::shared_ptr data( - results, [result_count, count = frames.size()](auto p) { - mmdeploy_pose_tracker_release_result(p, result_count, count); - }); - - std::vector rets; - rets.reserve(frames.size()); - - size_t offset = 0; - for (size_t i = 0; i < frames.size(); ++i) { - offset += rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; -} - -inline PoseTracker::Result PoseTracker::Apply(PoseTracker::State& state, const Mat& frame, - int32_t detect) { - return Apply(Span(&state, 1), Span{frame}, Span{detect})[0]; -} - -} // namespace cxx - -using cxx::PoseTracker; +namespace mmdeploy +{ + + namespace cxx + { + + class PoseTracker : public UniqueHandle + { + public: + using Result = Result_; + class State; + class Params; + + public: + /** + * @brief Create pose tracker pipeline + * @param detect object detection model + * @param pose pose estimation model + * @param context execution context + */ + PoseTracker(const Model& detect, const Model& pose, const Context& context) + { + auto ec = mmdeploy_pose_tracker_create(detect, pose, context, &handle_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + ~PoseTracker() + { + if (handle_) + { + mmdeploy_pose_tracker_destroy(handle_); + handle_ = {}; + } + } + PoseTracker(PoseTracker&&) noexcept = default; + + /** + * @brief Create a tracker state corresponds to a video stream + * @param params params for creating the tracker state + * @return created tracker state + */ + State CreateState(const Params& params); + + /** + * @brief Apply pose tracker pipeline + * @param state tracker state + * @param frame input video frame + * @param detect control the use of detector + * -1: use params.det_interval, 0: don't use detector, 1: force use detector + * @return + */ + Result Apply(State& state, const Mat& frame, int detect = -1); + + /** + * @brief batched version 
of Apply
+         * @param states
+         * @param frames
+         * @param detects
+         * @return
+         */
+        std::vector<Result> Apply(const Span<State>& states, const Span<const Mat>& frames, const Span<int32_t>& detects = {});
+
+      public:
+        /**
+         * see \ref mmdeploy/pose_tracker.h for detail
+         */
+        class Params : public UniqueHandle<mmdeploy_pose_tracker_param_t*>
+        {
+          public:
+            explicit Params()
+            {
+                handle_ = new mmdeploy_pose_tracker_param_t{};
+                mmdeploy_pose_tracker_default_params(handle_);
+            }
+            ~Params()
+            {
+                if (handle_)
+                {
+                    delete handle_;
+                    handle_ = {};
+                }
+            }
+        };
+
+        class State : public UniqueHandle<mmdeploy_pose_tracker_state_t>
+        {
+          public:
+            explicit State(mmdeploy_pose_tracker_t pipeline, const mmdeploy_pose_tracker_param_t* params)
+            {
+                auto ec = mmdeploy_pose_tracker_create_state(pipeline, params, &handle_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+            ~State()
+            {
+                if (handle_)
+                {
+                    mmdeploy_pose_tracker_destroy_state(handle_);
+                    handle_ = {};
+                }
+            }
+            State(State&&) noexcept = default;
+        };
+    };
+
+    inline PoseTracker::State PoseTracker::CreateState(const PoseTracker::Params& params)
+    {
+        return State(handle_, static_cast<mmdeploy_pose_tracker_param_t*>(params));
+    }
+
+    inline std::vector<PoseTracker::Result> PoseTracker::Apply(const Span<State>&     states,
+                                                               const Span<const Mat>& frames,
+                                                               const Span<int32_t>&   detects)
+    {
+        if (frames.empty())
+        {
+            return {};
+        }
+        mmdeploy_pose_tracker_target_t* results{};
+        int32_t*                        result_count{};
+
+        auto ec = mmdeploy_pose_tracker_apply(
+            handle_,
+            reinterpret_cast<mmdeploy_pose_tracker_state_t*>(states.data()),
+            reinterpret(frames.data()),
+            detects.data(),
+            static_cast<int32_t>(frames.size()),
+            &results,
+            &result_count);
+        if (ec != MMDEPLOY_SUCCESS)
+        {
+            throw_exception(static_cast<ErrorCode>(ec));
+        }
+
+        std::shared_ptr<mmdeploy_pose_tracker_target_t> data(
+            results,
+            [result_count, count = frames.size()](auto p)
+            {
+                mmdeploy_pose_tracker_release_result(p, result_count, count);
+            });
+
+        std::vector<Result> rets;
+        rets.reserve(frames.size());
+
+        size_t offset = 0;
+        for (size_t i = 0; i < frames.size(); ++i)
+        {
+            offset += rets.emplace_back(offset, result_count[i], data).size();
+        }
+
+        return rets;
+    }
+
+    inline PoseTracker::Result PoseTracker::Apply(PoseTracker::State& state, const Mat& frame, int32_t detect)
+    {
+        return Apply(Span(&state, 1), Span{frame}, Span{detect})[0];
+    }
+
+    }  // namespace cxx
+
+    using cxx::PoseTracker;
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp
index 671c5c2d0c..dcf9ab75af 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp
@@ -6,62 +6,77 @@
 #include "mmdeploy/common.hpp"
 #include "mmdeploy/restorer.h"

-namespace mmdeploy {
-
-namespace cxx {
-
-class Restorer : public NonMovable {
- public:
-  Restorer(const Model& model, const Context& context) {
-    auto ec = mmdeploy_restorer_create_v2(model, context, &restorer_);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-  }
-
-  ~Restorer() {
-    if (restorer_) {
-      mmdeploy_restorer_destroy(restorer_);
-      restorer_ = {};
-    }
-  }
-
-  using Result = Result_<mmdeploy_mat_t>;
-
-  std::vector<Result> Apply(Span<const Mat> images) {
-    if (images.empty()) {
-      return {};
-    }
-
-    mmdeploy_mat_t* results{};
-    auto ec = mmdeploy_restorer_apply(restorer_, reinterpret(images.data()),
-                                      static_cast<int>(images.size()), &results);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-
-    std::vector<Result> rets;
-    rets.reserve(images.size());
-
-    std::shared_ptr<mmdeploy_mat_t> data(
-        results, [count = images.size()](auto p) { mmdeploy_restorer_release_result(p, count); });
-
-    for (size_t i = 0; i < images.size(); ++i) {
-      rets.emplace_back(i, 1, data);
-    }
-
-    return rets;
-  }
-
-  Result Apply(const Mat& image) { return Apply(Span{image})[0]; }
-
- private:
-  mmdeploy_restorer_t restorer_{};
-};
-
-}  // namespace cxx
-
-using cxx::Restorer;
+namespace mmdeploy
+{
+
+    namespace cxx
+    {
+
+        class Restorer : public NonMovable
+        {
+          public:
+            Restorer(const Model& model, const Context& context)
+            {
+                auto ec = mmdeploy_restorer_create_v2(model, context, &restorer_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+
+            ~Restorer()
+            {
+                if (restorer_)
+                {
+                    mmdeploy_restorer_destroy(restorer_);
+                    restorer_ = {};
+                }
+            }
+
+            using Result = Result_<mmdeploy_mat_t>;
+
+            std::vector<Result> Apply(Span<const Mat> images)
+            {
+                if (images.empty())
+                {
+                    return {};
+                }
+
+                mmdeploy_mat_t* results{};
+                auto ec = mmdeploy_restorer_apply(restorer_, reinterpret(images.data()), static_cast<int>(images.size()), &results);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+
+                std::vector<Result> rets;
+                rets.reserve(images.size());
+
+                std::shared_ptr<mmdeploy_mat_t> data(
+                    results,
+                    [count = images.size()](auto p)
+                    { mmdeploy_restorer_release_result(p, count); });
+
+                for (size_t i = 0; i < images.size(); ++i)
+                {
+                    rets.emplace_back(i, 1, data);
+                }
+
+                return rets;
+            }
+
+            Result Apply(const Mat& image)
+            {
+                return Apply(Span{image})[0];
+            }
+
+          private:
+            mmdeploy_restorer_t restorer_{};
+        };
+
+    }  // namespace cxx
+
+    using cxx::Restorer;
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp
index fa065b0f0c..5a224f6fa5 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp
@@ -6,69 +6,81 @@
 #include "mmdeploy/common.hpp"
 #include "mmdeploy/rotated_detector.h"

-namespace mmdeploy {
-
-namespace cxx {
-
-using RotatedDetection = mmdeploy_rotated_detection_t;
-
-class RotatedDetector : public NonMovable {
- public:
-  RotatedDetector(const Model& model, const Context& context) {
-    auto ec = mmdeploy_rotated_detector_create_v2(model, context, &detector_);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-  }
-
-  ~RotatedDetector() {
-    if (detector_) {
-      mmdeploy_rotated_detector_destroy(detector_);
-      detector_ = {};
-    }
-  }
-
-  using Result = Result_<RotatedDetection>;
-
-  std::vector<Result> Apply(Span<const Mat> images) {
-    if (images.empty()) {
-      return {};
-    }
-
-    RotatedDetection* results{};
-    int* result_count{};
-    auto ec =
-        mmdeploy_rotated_detector_apply(detector_, reinterpret(images.data()),
-                                        static_cast<int>(images.size()), &results, &result_count);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-
-    std::shared_ptr<RotatedDetection> data(results, [result_count](auto p) {
-      mmdeploy_rotated_detector_release_result(p, result_count);
-    });
-
-    std::vector<Result> rets;
-    rets.reserve(images.size());
-
-    size_t offset = 0;
-    for (size_t i = 0; i < images.size(); ++i) {
-      offset += rets.emplace_back(offset, result_count[i], data).size();
-    }
-
-    return rets;
-  }
-
-  Result Apply(const Mat& image) { return Apply(Span{image})[0]; }
-
- private:
-  mmdeploy_rotated_detector_t detector_{};
-};
-
-}  // namespace cxx
-
-using cxx::RotatedDetection;
-using cxx::RotatedDetector;
+namespace mmdeploy
+{
+
+    namespace cxx
+    {
+
+        using RotatedDetection = mmdeploy_rotated_detection_t;
+
+        class RotatedDetector : public NonMovable
+        {
+          public:
+            RotatedDetector(const Model& model, const Context& context)
+            {
+                auto ec = mmdeploy_rotated_detector_create_v2(model, context, &detector_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+
+            ~RotatedDetector()
+            {
+                if (detector_)
+                {
+                    mmdeploy_rotated_detector_destroy(detector_);
+                    detector_ = {};
+                }
+            }
+
+            using Result = Result_<RotatedDetection>;
+
+            std::vector<Result> Apply(Span<const Mat> images)
+            {
+                if (images.empty())
+                {
+                    return {};
+                }
+
+                RotatedDetection* results{};
+                int*              result_count{};
+                auto              ec =
+                    mmdeploy_rotated_detector_apply(detector_, reinterpret(images.data()), static_cast<int>(images.size()), &results, &result_count);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+
+                std::shared_ptr<RotatedDetection> data(results, [result_count](auto p)
+                                                       { mmdeploy_rotated_detector_release_result(p, result_count); });
+
+                std::vector<Result> rets;
+                rets.reserve(images.size());
+
+                size_t offset = 0;
+                for (size_t i = 0; i < images.size(); ++i)
+                {
+                    offset += rets.emplace_back(offset, result_count[i], data).size();
+                }
+
+                return rets;
+            }
+
+            Result Apply(const Mat& image)
+            {
+                return Apply(Span{image})[0];
+            }
+
+          private:
+            mmdeploy_rotated_detector_t detector_{};
+        };
+
+    }  // namespace cxx
+
+    using cxx::RotatedDetection;
+    using cxx::RotatedDetector;
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp
index fe53023d1c..7ad98a91bb 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp
@@ -6,65 +6,80 @@
 #include "mmdeploy/common.hpp"
 #include "mmdeploy/segmentor.h"

-namespace mmdeploy {
-
-namespace cxx {
-
-using Segmentation = mmdeploy_segmentation_t;
-
-class Segmentor : public NonMovable {
- public:
-  Segmentor(const Model& model, const Context& context) {
-    auto ec = mmdeploy_segmentor_create_v2(model, context, &segmentor_);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-  }
-
-  ~Segmentor() {
-    if (segmentor_) {
-      mmdeploy_segmentor_destroy(segmentor_);
-      segmentor_ = {};
-    }
-  }
-
-  using Result = Result_<Segmentation>;
-
-  std::vector<Result> Apply(Span<const Mat> images) {
-    if (images.empty()) {
-      return {};
-    }
-
-    Segmentation* results{};
-    auto ec = mmdeploy_segmentor_apply(segmentor_, reinterpret(images.data()),
-                                       static_cast<int>(images.size()), &results);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-
-    std::vector<Result> rets;
-    rets.reserve(images.size());
-
-    std::shared_ptr<Segmentation> data(
-        results, [count = images.size()](auto p) { mmdeploy_segmentor_release_result(p, count); });
-
-    for (size_t i = 0; i < images.size(); ++i) {
-      rets.emplace_back(i, 1, data);
-    }
-
-    return rets;
-  }
-
-  Result Apply(const Mat& image) { return Apply(Span{image})[0]; }
-
- private:
-  mmdeploy_segmentor_t segmentor_{};
-};
-
-}  // namespace cxx
-
-using cxx::Segmentation;
-using cxx::Segmentor;
+namespace mmdeploy
+{
+
+    namespace cxx
+    {
+
+        using Segmentation = mmdeploy_segmentation_t;
+
+        class Segmentor : public NonMovable
+        {
+          public:
+            Segmentor(const Model& model, const Context& context)
+            {
+                auto ec = mmdeploy_segmentor_create_v2(model, context, &segmentor_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+
+            ~Segmentor()
+            {
+                if (segmentor_)
+                {
+                    mmdeploy_segmentor_destroy(segmentor_);
+                    segmentor_ = {};
+                }
+            }
+
+            using Result = Result_<Segmentation>;
+
+            std::vector<Result> Apply(Span<const Mat> images)
+            {
+                if (images.empty())
+                {
+                    return {};
+                }
+
+                Segmentation* results{};
+                auto ec = mmdeploy_segmentor_apply(segmentor_, reinterpret(images.data()), static_cast<int>(images.size()), &results);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+
+                std::vector<Result> rets;
+                rets.reserve(images.size());
+
+                std::shared_ptr<Segmentation> data(
+                    results,
+                    [count = images.size()](auto p)
+                    { mmdeploy_segmentor_release_result(p, count); });
+
+                for (size_t i = 0; i < images.size(); ++i)
+                {
+                    rets.emplace_back(i, 1, data);
+                }
+
+                return rets;
+            }
+
+            Result Apply(const Mat& image)
+            {
+                return Apply(Span{image})[0];
+            }
+
+          private:
+            mmdeploy_segmentor_t segmentor_{};
+        };
+
+    }  // namespace cxx
+
+    using cxx::Segmentation;
+    using cxx::Segmentor;
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp
index d848715405..56f2f02f18 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp
@@ -6,69 +6,81 @@
 #include "mmdeploy/common.hpp"
 #include "mmdeploy/text_detector.h"

-namespace mmdeploy {
-
-namespace cxx {
-
-using TextDetection = mmdeploy_text_detection_t;
-
-class TextDetector : public NonMovable {
- public:
-  TextDetector(const Model& model, const Context& context) {
-    auto ec = mmdeploy_text_detector_create_v2(model, context, &detector_);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-  }
-
-  ~TextDetector() {
-    if (detector_) {
-      mmdeploy_text_detector_destroy(detector_);
-      detector_ = {};
-    }
-  }
-
-  using Result = Result_<TextDetection>;
-
-  std::vector<Result> Apply(Span<const Mat> images) {
-    if (images.empty()) {
-      return {};
-    }
-
-    TextDetection* results{};
-    int* result_count{};
-    auto ec =
-        mmdeploy_text_detector_apply(detector_, reinterpret(images.data()),
-                                     static_cast<int>(images.size()), &results, &result_count);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-
-    std::shared_ptr<TextDetection> data(results, [result_count, count = images.size()](auto p) {
-      mmdeploy_text_detector_release_result(p, result_count, count);
-    });
-
-    std::vector<Result> rets;
-    rets.reserve(images.size());
-
-    size_t offset = 0;
-    for (size_t i = 0; i < images.size(); ++i) {
-      offset += rets.emplace_back(offset, result_count[i], data).size();
-    }
-
-    return rets;
-  }
-
-  Result Apply(const Mat& image) { return Apply(Span{image})[0]; }
-
- private:
-  mmdeploy_text_detector_t detector_{};
-};
-
-}  // namespace cxx
-
-using cxx::TextDetection;
-using cxx::TextDetector;
+namespace mmdeploy
+{
+
+    namespace cxx
+    {
+
+        using TextDetection = mmdeploy_text_detection_t;
+
+        class TextDetector : public NonMovable
+        {
+          public:
+            TextDetector(const Model& model, const Context& context)
+            {
+                auto ec = mmdeploy_text_detector_create_v2(model, context, &detector_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+
+            ~TextDetector()
+            {
+                if (detector_)
+                {
+                    mmdeploy_text_detector_destroy(detector_);
+                    detector_ = {};
+                }
+            }
+
+            using Result = Result_<TextDetection>;
+
+            std::vector<Result> Apply(Span<const Mat> images)
+            {
+                if (images.empty())
+                {
+                    return {};
+                }
+
+                TextDetection* results{};
+                int*           result_count{};
+                auto           ec =
+                    mmdeploy_text_detector_apply(detector_, reinterpret(images.data()), static_cast<int>(images.size()), &results, &result_count);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+
+                std::shared_ptr<TextDetection> data(results, [result_count, count = images.size()](auto p)
+                                                    { mmdeploy_text_detector_release_result(p, result_count, count); });
+
+                std::vector<Result> rets;
+                rets.reserve(images.size());
+
+                size_t offset = 0;
+                for (size_t i = 0; i < images.size(); ++i)
+                {
+                    offset += rets.emplace_back(offset, result_count[i], data).size();
+                }
+
+                return rets;
+            }
+
+            Result Apply(const Mat& image)
+            {
+                return Apply(Span{image})[0];
+            }
+
+          private:
+            mmdeploy_text_detector_t detector_{};
+        };
+
+    }  // namespace cxx
+
+    using cxx::TextDetection;
+    using cxx::TextDetector;
 }  // namespace mmdeploy
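The C++ wrappers reformatted above all share one calling convention: construct the wrapper once from a Model and a Context, call Apply on a single Mat or a Span of Mats, and read the output through Result views whose shared_ptr deleter releases the underlying C buffer once the last view goes away. A minimal usage sketch for TextDetector follows; the model path, image file, device name, and the cv::Mat conversion are illustrative assumptions, not something this diff introduces.

// usage sketch only: paths, device name, and OpenCV interop are placeholders
#include "mmdeploy/text_detector.hpp"

#include <opencv2/imgcodecs.hpp>

int main()
{
    using namespace mmdeploy;

    Model        model("/path/to/text-detection-model");  // hypothetical model directory
    Context      context(Device("cpu", 0));               // assumes Device-to-Context construction
    TextDetector detector(model, context);

    cv::Mat img = cv::imread("demo_text_det.jpg");  // hypothetical input image
    auto    result = detector.Apply(img);           // single-image overload shown above

    for (const auto& det : result)
    {
        // each det is an mmdeploy_text_detection_t; the buffer stays alive
        // for as long as any Result copy holds the shared_ptr
    }
}

Because a Result is only a view plus a shared_ptr, copying it or returning it from helpers is cheap and does not duplicate the detection buffer.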
diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp
index eba8ea3902..31c741e2ee 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp
@@ -9,82 +9,91 @@
 #include "mmdeploy/text_detector.hpp"
 #include "mmdeploy/text_recognizer.h"

-namespace mmdeploy {
-
-namespace cxx {
-
-using TextRecognition = mmdeploy_text_recognition_t;
-
-class TextRecognizer : public NonMovable {
- public:
-  TextRecognizer(const Model& model, const Context& context) {
-    auto ec = mmdeploy_text_recognizer_create_v2(model, context, &recognizer_);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-  }
-
-  ~TextRecognizer() {
-    if (recognizer_) {
-      mmdeploy_text_recognizer_destroy(recognizer_);
-      recognizer_ = {};
-    }
-  }
-
-  using Result = Result_<TextRecognition>;
-
-  std::vector<Result> Apply(Span<const Mat> images, Span<const TextDetection> bboxes,
-                            Span<const int> bbox_count) {
-    if (images.empty()) {
-      return {};
-    }
-
-    const TextDetection* p_bboxes{};
-    const int* p_bbox_count{};
-
-    auto n_total_bboxes = static_cast<int>(images.size());
-
-    if (!bboxes.empty()) {
-      p_bboxes = bboxes.data();
-      p_bbox_count = bbox_count.data();
-      n_total_bboxes = std::accumulate(bbox_count.begin(), bbox_count.end(), 0);
-    }
-
-    TextRecognition* results{};
-    auto ec = mmdeploy_text_recognizer_apply_bbox(recognizer_, reinterpret(images.data()),
-                                                  static_cast<int>(images.size()), p_bboxes,
-                                                  p_bbox_count, &results);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-
-    std::shared_ptr<TextRecognition> data(results, [count = n_total_bboxes](auto p) {
-      mmdeploy_text_recognizer_release_result(p, count);
-    });
-
-    std::vector<Result> rets;
-    rets.reserve(images.size());
-
-    size_t offset = 0;
-    for (size_t i = 0; i < images.size(); ++i) {
-      offset += rets.emplace_back(offset, bboxes.empty() ? 1 : bbox_count[i], data).size();
-    }
-
-    return rets;
-  }
-
-  Result Apply(const Mat& image, Span<const TextDetection> bboxes = {}) {
-    return Apply(Span{image}, bboxes, {static_cast<int>(bboxes.size())})[0];
-  }
-
- private:
-  mmdeploy_text_recognizer_t recognizer_{};
-};
-
-}  // namespace cxx
-
-using cxx::TextRecognition;
-using cxx::TextRecognizer;
+namespace mmdeploy
+{
+
+    namespace cxx
+    {
+
+        using TextRecognition = mmdeploy_text_recognition_t;
+
+        class TextRecognizer : public NonMovable
+        {
+          public:
+            TextRecognizer(const Model& model, const Context& context)
+            {
+                auto ec = mmdeploy_text_recognizer_create_v2(model, context, &recognizer_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+
+            ~TextRecognizer()
+            {
+                if (recognizer_)
+                {
+                    mmdeploy_text_recognizer_destroy(recognizer_);
+                    recognizer_ = {};
+                }
+            }
+
+            using Result = Result_<TextRecognition>;
+
+            std::vector<Result> Apply(Span<const Mat> images, Span<const TextDetection> bboxes, Span<const int> bbox_count)
+            {
+                if (images.empty())
+                {
+                    return {};
+                }
+
+                const TextDetection* p_bboxes{};
+                const int*           p_bbox_count{};
+
+                auto n_total_bboxes = static_cast<int>(images.size());
+
+                if (!bboxes.empty())
+                {
+                    p_bboxes = bboxes.data();
+                    p_bbox_count = bbox_count.data();
+                    n_total_bboxes = std::accumulate(bbox_count.begin(), bbox_count.end(), 0);
+                }
+
+                TextRecognition* results{};
+                auto ec = mmdeploy_text_recognizer_apply_bbox(recognizer_, reinterpret(images.data()), static_cast<int>(images.size()), p_bboxes, p_bbox_count, &results);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+
+                std::shared_ptr<TextRecognition> data(results, [count = n_total_bboxes](auto p)
+                                                      { mmdeploy_text_recognizer_release_result(p, count); });
+
+                std::vector<Result> rets;
+                rets.reserve(images.size());
+
+                size_t offset = 0;
+                for (size_t i = 0; i < images.size(); ++i)
+                {
+                    offset += rets.emplace_back(offset, bboxes.empty() ? 1 : bbox_count[i], data).size();
+                }
+
+                return rets;
+            }
+
+            Result Apply(const Mat& image, Span<const TextDetection> bboxes = {})
+            {
+                return Apply(Span{image}, bboxes, {static_cast<int>(bboxes.size())})[0];
+            }
+
+          private:
+            mmdeploy_text_recognizer_t recognizer_{};
+        };
+
+    }  // namespace cxx
+
+    using cxx::TextRecognition;
+    using cxx::TextRecognizer;
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp
index 583b28dd59..ed3569e242 100644
--- a/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp
+++ b/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp
@@ -6,85 +6,97 @@
 #include "mmdeploy/common.hpp"
 #include "mmdeploy/video_recognizer.h"

-namespace mmdeploy {
-
-namespace cxx {
-
-using VideoRecognition = mmdeploy_video_recognition_t;
-using VideoSampleInfo = mmdeploy_video_sample_info_t;
-
-class VideoRecognizer : public NonMovable {
- public:
-  VideoRecognizer(const Model& model, const Context& context) {
-    auto ec = mmdeploy_video_recognizer_create_v2(model, context, &recognizer_);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-  }
-
-  ~VideoRecognizer() {
-    if (recognizer_) {
-      mmdeploy_video_recognizer_destroy(recognizer_);
-      recognizer_ = {};
-    }
-  }
-
-  using Result = Result_<VideoRecognition>;
-
-  std::vector<Result> Apply(Span<Span<const Mat>> videos,
-                            Span<const VideoSampleInfo> infos) {
-    if (videos.empty()) {
-      return {};
-    }
-
-    int video_count = videos.size();
-
-    VideoRecognition* results{};
-    int* result_count{};
-    std::vector<Mat> images;
-    std::vector<VideoSampleInfo> video_info;
-    for (int i = 0; i < videos.size(); i++) {
-      for (auto& mat : videos[i]) {
-        images.push_back(mat);
-      }
-      video_info.push_back(infos[i]);
-    }
-
-    auto ec =
-        mmdeploy_video_recognizer_apply(recognizer_, reinterpret(images.data()), video_info.data(),
-                                        video_count, &results, &result_count);
-    if (ec != MMDEPLOY_SUCCESS) {
-      throw_exception(static_cast<ErrorCode>(ec));
-    }
-
-    std::vector<Result> rets;
-    rets.reserve(video_count);
-
-    std::shared_ptr<VideoRecognition> data(results, [result_count, count = video_count](auto p) {
-      mmdeploy_video_recognizer_release_result(p, result_count, count);
-    });
-
-    size_t offset = 0;
-    for (size_t i = 0; i < video_count; ++i) {
-      offset += rets.emplace_back(offset, result_count[i], data).size();
-    }
-
-    return rets;
-  }
-
-  Result Apply(const std::vector<Mat>& video, const VideoSampleInfo info) {
-    return Apply(Span{video}, Span{info})[0];
-  }
-
- private:
-  mmdeploy_video_recognizer_t recognizer_{};
-};
-
-}  // namespace cxx
-
-using cxx::VideoRecognition;
-using cxx::VideoRecognizer;
-using cxx::VideoSampleInfo;
+namespace mmdeploy
+{
+
+    namespace cxx
+    {
+
+        using VideoRecognition = mmdeploy_video_recognition_t;
+        using VideoSampleInfo = mmdeploy_video_sample_info_t;
+
+        class VideoRecognizer : public NonMovable
+        {
+          public:
+            VideoRecognizer(const Model& model, const Context& context)
+            {
+                auto ec = mmdeploy_video_recognizer_create_v2(model, context, &recognizer_);
+                if (ec != MMDEPLOY_SUCCESS)
+                {
+                    throw_exception(static_cast<ErrorCode>(ec));
+                }
+            }
+
+            ~VideoRecognizer()
+            {
+                if (recognizer_)
+                {
+                    mmdeploy_video_recognizer_destroy(recognizer_);
+                    recognizer_ = {};
+                }
+            }
+
+            using Result = Result_<VideoRecognition>;
+
+            std::vector<Result> Apply(Span<Span<const Mat>> videos,
+                                      Span<const VideoSampleInfo> infos)
+            {
+                if (videos.empty())
+                {
+                    return {};
+                }
+
+                int video_count = videos.size();
+
+                VideoRecognition*            results{};
+                int*                         result_count{};
+                std::vector<Mat>             images;
+                std::vector<VideoSampleInfo> video_info;
+                for (int i = 0; i < videos.size(); i++)
+                {
+                    for (auto& mat : videos[i])
+                    {
+                        images.push_back(mat);
+                    }
+                    video_info.push_back(infos[i]);
+                }
+
auto ec = + mmdeploy_video_recognizer_apply(recognizer_, reinterpret(images.data()), video_info.data(), video_count, &results, &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::vector rets; + rets.reserve(video_count); + + std::shared_ptr data(results, [result_count, count = video_count](auto p) + { mmdeploy_video_recognizer_release_result(p, result_count, count); }); + + size_t offset = 0; + for (size_t i = 0; i < video_count; ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + Result Apply(const std::vector& video, const VideoSampleInfo info) + { + return Apply(Span{video}, Span{info})[0]; + } + + private: + mmdeploy_video_recognizer_t recognizer_{}; + }; + + } // namespace cxx + + using cxx::VideoRecognition; + using cxx::VideoRecognizer; + using cxx::VideoSampleInfo; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/java/CMakeLists.txt b/csrc/mmdeploy/apis/java/CMakeLists.txt index 04313f1934..6ae7a8e0ad 100644 --- a/csrc/mmdeploy/apis/java/CMakeLists.txt +++ b/csrc/mmdeploy/apis/java/CMakeLists.txt @@ -1,6 +1,6 @@ -if (NOT MMDEPLOY_BUILD_SDK_JAVA_API) - return () -endif () +if(NOT MMDEPLOY_BUILD_SDK_JAVA_API) + return() +endif() project(mmdeploy_java_package) @@ -9,26 +9,27 @@ include(UseJava) add_subdirectory(native) -add_jar(${PROJECT_NAME} SOURCES - mmdeploy/DataType.java - mmdeploy/Mat.java - mmdeploy/InstanceMask.java - mmdeploy/PixelFormat.java - mmdeploy/PointF.java - mmdeploy/Rect.java - mmdeploy/Classifier.java - mmdeploy/Detector.java - mmdeploy/Segmentor.java - mmdeploy/TextDetector.java - mmdeploy/TextRecognizer.java - mmdeploy/Restorer.java - mmdeploy/PoseDetector.java - mmdeploy/Context.java - mmdeploy/Device.java - mmdeploy/Model.java - mmdeploy/Profiler.java - mmdeploy/Scheduler.java - mmdeploy/PoseTracker.java - mmdeploy/RotatedDetector.java - OUTPUT_NAME mmdeploy - OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +add_jar( + ${PROJECT_NAME} + SOURCES mmdeploy/DataType.java + mmdeploy/Mat.java + mmdeploy/InstanceMask.java + mmdeploy/PixelFormat.java + mmdeploy/PointF.java + mmdeploy/Rect.java + mmdeploy/Classifier.java + mmdeploy/Detector.java + mmdeploy/Segmentor.java + mmdeploy/TextDetector.java + mmdeploy/TextRecognizer.java + mmdeploy/Restorer.java + mmdeploy/PoseDetector.java + mmdeploy/Context.java + mmdeploy/Device.java + mmdeploy/Model.java + mmdeploy/Profiler.java + mmdeploy/Scheduler.java + mmdeploy/PoseTracker.java + mmdeploy/RotatedDetector.java + OUTPUT_NAME mmdeploy + OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/csrc/mmdeploy/apis/java/native/CMakeLists.txt b/csrc/mmdeploy/apis/java/native/CMakeLists.txt index 6324cd21a1..b1868b8567 100644 --- a/csrc/mmdeploy/apis/java/native/CMakeLists.txt +++ b/csrc/mmdeploy/apis/java/native/CMakeLists.txt @@ -1,35 +1,35 @@ # Copyright (c) OpenMMLab. All rights reserved. 
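The native library assembled by this CMake file backs the Java classes listed in add_jar above with thin JNI shims. One pattern repeats through every shim in this diff: each Get call that borrows data from the JVM (GetStringUTFChars, Get*ArrayElements) is paired with a matching Release call once the C API has consumed the data. In isolation the pattern looks like the sketch below; the binding name is hypothetical and only jni.h is assumed.

#include <jni.h>

// sketch of the Get/Release pairing used by the bindings that follow;
// Java_mmdeploy_Example_create is a made-up binding, not part of this diff
extern "C" JNIEXPORT jlong JNICALL Java_mmdeploy_Example_create(JNIEnv* env, jobject, jstring path)
{
    const char* utf = env->GetStringUTFChars(path, nullptr);  // borrow the string bytes
    jlong       handle = 0;
    // ... create the native object from 'utf' and store its address in 'handle' ...
    env->ReleaseStringUTFChars(path, utf);  // give the bytes back before returning
    return handle;  // the Java side keeps the native pointer as a long
}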
project(mmdeploy_java) -if (NOT ANDROID) - find_package(JNI REQUIRED) -else () - set(JNI_LIBRARIES) +if(NOT ANDROID) + find_package(JNI REQUIRED) +else() + set(JNI_LIBRARIES) endif() -mmdeploy_add_library(${PROJECT_NAME} SHARED EXCLUDE - mmdeploy_Classifier.cpp - mmdeploy_Detector.cpp - mmdeploy_Segmentor.cpp - mmdeploy_Restorer.cpp - mmdeploy_PoseDetector.cpp - mmdeploy_TextDetector.cpp - mmdeploy_TextRecognizer.cpp - mmdeploy_PoseTracker.cpp - mmdeploy_Context.cpp - mmdeploy_Device.cpp - mmdeploy_Model.cpp - mmdeploy_Profiler.cpp - mmdeploy_Scheduler.cpp - mmdeploy_RotatedDetector.cpp) +mmdeploy_add_library( + ${PROJECT_NAME} + SHARED + EXCLUDE + mmdeploy_Classifier.cpp + mmdeploy_Detector.cpp + mmdeploy_Segmentor.cpp + mmdeploy_Restorer.cpp + mmdeploy_PoseDetector.cpp + mmdeploy_TextDetector.cpp + mmdeploy_TextRecognizer.cpp + mmdeploy_PoseTracker.cpp + mmdeploy_Context.cpp + mmdeploy_Device.cpp + mmdeploy_Model.cpp + mmdeploy_Profiler.cpp + mmdeploy_Scheduler.cpp + mmdeploy_RotatedDetector.cpp) -target_include_directories(${PROJECT_NAME} PRIVATE - ${JNI_INCLUDE_DIRS}) +target_include_directories(${PROJECT_NAME} PRIVATE ${JNI_INCLUDE_DIRS}) mmdeploy_load_static(${PROJECT_NAME} MMDeployStaticModules) mmdeploy_load_dynamic(${PROJECT_NAME} MMDeployDynamicModules) -target_link_libraries(${PROJECT_NAME} PRIVATE - ${JNI_LIBRARIES} MMDeployLibs) -install(TARGETS ${PROJECT_NAME} - DESTINATION lib) +target_link_libraries(${PROJECT_NAME} PRIVATE ${JNI_LIBRARIES} MMDeployLibs) +install(TARGETS ${PROJECT_NAME} DESTINATION lib) diff --git a/csrc/mmdeploy/apis/java/native/common.h b/csrc/mmdeploy/apis/java/native/common.h index ba2601e5f1..045dc02a35 100644 --- a/csrc/mmdeploy/apis/java/native/common.h +++ b/csrc/mmdeploy/apis/java/native/common.h @@ -10,45 +10,48 @@ #include "mmdeploy/core/logger.h" #include "mmdeploy/core/utils/formatter.h" -template -static auto With(JNIEnv *env, jobjectArray imgs, F f) noexcept { - auto mat_clazz = env->FindClass("mmdeploy/Mat"); - auto shape_field = env->GetFieldID(mat_clazz, "shape", "[I"); - auto format_field = env->GetFieldID(mat_clazz, "format", "I"); - auto type_field = env->GetFieldID(mat_clazz, "type", "I"); - auto data_field = env->GetFieldID(mat_clazz, "data", "[B"); - auto num = env->GetArrayLength(imgs); - std::vector mats; - std::vector datum; - - mats.reserve(num); - datum.reserve(num); - - for (int i = 0; i < num; ++i) { - auto obj = env->GetObjectArrayElement(imgs, i); - auto shape_obj = env->GetObjectField(obj, shape_field); - auto shape = env->GetIntArrayElements((jintArray)shape_obj, nullptr); - auto format = env->GetIntField(obj, format_field); - auto type = env->GetIntField(obj, type_field); - auto &mat = mats.emplace_back(); - mat.height = shape[0]; - mat.width = shape[1]; - mat.channel = shape[2]; - env->ReleaseIntArrayElements((jintArray)shape_obj, shape, JNI_ABORT); - mat.format = (mmdeploy_pixel_format_t)format; - mat.type = (mmdeploy_data_type_t)type; - auto data_obj = env->GetObjectField(obj, data_field); - mat.data = (uint8_t *)env->GetByteArrayElements((jbyteArray)data_obj, nullptr); - datum.push_back((jbyteArray)data_obj); - } - - auto ret = f(mats.data(), mats.size()); // ! 
f must not throw - - for (int i = 0; i < num; ++i) { - env->ReleaseByteArrayElements(datum[i], (jbyte *)mats[i].data, JNI_ABORT); - } - - return ret; +template +static auto With(JNIEnv* env, jobjectArray imgs, F f) noexcept +{ + auto mat_clazz = env->FindClass("mmdeploy/Mat"); + auto shape_field = env->GetFieldID(mat_clazz, "shape", "[I"); + auto format_field = env->GetFieldID(mat_clazz, "format", "I"); + auto type_field = env->GetFieldID(mat_clazz, "type", "I"); + auto data_field = env->GetFieldID(mat_clazz, "data", "[B"); + auto num = env->GetArrayLength(imgs); + std::vector mats; + std::vector datum; + + mats.reserve(num); + datum.reserve(num); + + for (int i = 0; i < num; ++i) + { + auto obj = env->GetObjectArrayElement(imgs, i); + auto shape_obj = env->GetObjectField(obj, shape_field); + auto shape = env->GetIntArrayElements((jintArray)shape_obj, nullptr); + auto format = env->GetIntField(obj, format_field); + auto type = env->GetIntField(obj, type_field); + auto& mat = mats.emplace_back(); + mat.height = shape[0]; + mat.width = shape[1]; + mat.channel = shape[2]; + env->ReleaseIntArrayElements((jintArray)shape_obj, shape, JNI_ABORT); + mat.format = (mmdeploy_pixel_format_t)format; + mat.type = (mmdeploy_data_type_t)type; + auto data_obj = env->GetObjectField(obj, data_field); + mat.data = (uint8_t*)env->GetByteArrayElements((jbyteArray)data_obj, nullptr); + datum.push_back((jbyteArray)data_obj); + } + + auto ret = f(mats.data(), mats.size()); // ! f must not throw + + for (int i = 0; i < num; ++i) + { + env->ReleaseByteArrayElements(datum[i], (jbyte*)mats[i].data, JNI_ABORT); + } + + return ret; } #endif // MMDEPLOY_CSRC_APIS_JAVA_NATIVE_COMMON_H_ diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp index 2a3309361e..6664a65289 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp @@ -6,30 +6,33 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Classifier_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_classifier_t classifier{}; - auto ec = - mmdeploy_classifier_create_by_path(model_path, device_name, (int)device_id, &classifier); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create classifier, code = {}", ec); - return -1; - } - return (jlong)classifier; +jlong Java_mmdeploy_Classifier_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_classifier_t classifier{}; + auto ec = + mmdeploy_classifier_create_by_path(model_path, device_name, (int)device_id, &classifier); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create classifier, code = {}", ec); + return -1; + } + return (jlong)classifier; } -void Java_mmdeploy_Classifier_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Classifier_destroy"); - mmdeploy_classifier_destroy((mmdeploy_classifier_t)handle); +void Java_mmdeploy_Classifier_destroy(JNIEnv*, jobject, 
jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Classifier_destroy"); + mmdeploy_classifier_destroy((mmdeploy_classifier_t)handle); } -jobjectArray Java_mmdeploy_Classifier_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Classifier_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_classification_t *results{}; int *result_count{}; auto ec = mmdeploy_classifier_apply((mmdeploy_classifier_t)handle, imgs, size, &results, @@ -55,6 +58,5 @@ jobjectArray Java_mmdeploy_Classifier_apply(JNIEnv *env, jobject thiz, jlong han } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_classifier_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h index 16a06b5fba..84adf58aa3 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h @@ -3,33 +3,33 @@ /* Header for class mmdeploy_Classifier */ #ifndef _Included_mmdeploy_Classifier -#define _Included_mmdeploy_Classifier -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Classifier - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Classifier_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Classifier + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Classifier + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Classifier_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Classifier - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Classifier_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Classifier + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Classifier_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_Classifier - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Classifier/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Classifier_apply(JNIEnv *, jobject, jlong, - jobjectArray, jintArray); + /* + * Class: mmdeploy_Classifier + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Classifier/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Classifier_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp index dbd401724e..e875a66ead 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp @@ -8,36 +8,43 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Context_create(JNIEnv *env, jobject) { - mmdeploy_context_t context{}; - mmdeploy_context_create(&context); - return (jlong)context; +jlong Java_mmdeploy_Context_create(JNIEnv* env, jobject) +{ + mmdeploy_context_t context{}; + mmdeploy_context_create(&context); + return (jlong)context; } -jint Java_mmdeploy_Context_add(JNIEnv 
*env, jobject, jlong context_, jint contextType, jstring name, - jlong handle) { - auto object_name = env->GetStringUTFChars(name, nullptr); - if ((int)contextType == MMDEPLOY_TYPE_SCHEDULER) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - object_name, (mmdeploy_scheduler_t)handle); - } else if ((int)contextType == MMDEPLOY_TYPE_MODEL) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - object_name, (mmdeploy_model_t)handle); - } else if ((int)contextType == MMDEPLOY_TYPE_DEVICE) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - nullptr, (mmdeploy_device_t)handle); - } else if ((int)contextType == MMDEPLOY_TYPE_PROFILER) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - nullptr, (mmdeploy_profiler_t)handle); - } else { - MMDEPLOY_ERROR("wrong context type, got {}", (int)contextType); - return MMDEPLOY_E_NOT_SUPPORTED; - } - env->ReleaseStringUTFChars(name, object_name); - return 0; +jint Java_mmdeploy_Context_add(JNIEnv* env, jobject, jlong context_, jint contextType, jstring name, jlong handle) +{ + auto object_name = env->GetStringUTFChars(name, nullptr); + if ((int)contextType == MMDEPLOY_TYPE_SCHEDULER) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, object_name, (mmdeploy_scheduler_t)handle); + } + else if ((int)contextType == MMDEPLOY_TYPE_MODEL) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, object_name, (mmdeploy_model_t)handle); + } + else if ((int)contextType == MMDEPLOY_TYPE_DEVICE) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, nullptr, (mmdeploy_device_t)handle); + } + else if ((int)contextType == MMDEPLOY_TYPE_PROFILER) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, nullptr, (mmdeploy_profiler_t)handle); + } + else + { + MMDEPLOY_ERROR("wrong context type, got {}", (int)contextType); + return MMDEPLOY_E_NOT_SUPPORTED; + } + env->ReleaseStringUTFChars(name, object_name); + return 0; } -void Java_mmdeploy_Context_destroy(JNIEnv *, jobject, jlong context_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Context_destroy"); - mmdeploy_context_destroy((mmdeploy_context_t)context_); +void Java_mmdeploy_Context_destroy(JNIEnv*, jobject, jlong context_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Context_destroy"); + mmdeploy_context_destroy((mmdeploy_context_t)context_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h index 42df819580..00e24065c6 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h @@ -3,32 +3,33 @@ /* Header for class mmdeploy_Context */ #ifndef _Included_mmdeploy_Context -#define _Included_mmdeploy_Context -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Context - * Method: create - * Signature: ()J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Context_create(JNIEnv *, jobject); + #define _Included_mmdeploy_Context + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Context + * Method: create + * Signature: ()J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Context_create(JNIEnv*, jobject); -/* - * Class: mmdeploy_Context - * Method: add - * Signature: (JILjava/lang/String;J)I - */ -JNIEXPORT jint JNICALL Java_mmdeploy_Context_add(JNIEnv *, jobject, jlong, 
jint, jstring, jlong); + /* + * Class: mmdeploy_Context + * Method: add + * Signature: (JILjava/lang/String;J)I + */ + JNIEXPORT jint JNICALL Java_mmdeploy_Context_add(JNIEnv*, jobject, jlong, jint, jstring, jlong); -/* - * Class: mmdeploy_Context - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Context_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Context + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Context_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp index c03ff1a1ff..6e8a32dac7 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp @@ -6,29 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Detector_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_detector_t detector{}; - auto ec = mmdeploy_detector_create_by_path(model_path, device_name, (int)device_id, &detector); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create detector, code = {}", ec); - return -1; - } - return (jlong)detector; +jlong Java_mmdeploy_Detector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_detector_t detector{}; + auto ec = mmdeploy_detector_create_by_path(model_path, device_name, (int)device_id, &detector); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create detector, code = {}", ec); + return -1; + } + return (jlong)detector; } -void Java_mmdeploy_Detector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Detector_destroy"); // maybe use info? - mmdeploy_detector_destroy((mmdeploy_detector_t)handle); +void Java_mmdeploy_Detector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Detector_destroy"); // maybe use info? 
+ mmdeploy_detector_destroy((mmdeploy_detector_t)handle); } -jobjectArray Java_mmdeploy_Detector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Detector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_detection_t *results{}; int *result_count{}; auto ec = @@ -79,6 +82,5 @@ jobjectArray Java_mmdeploy_Detector_apply(JNIEnv *env, jobject thiz, jlong handl } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_detector_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h index 41e711d15a..578643efc8 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h @@ -3,33 +3,33 @@ /* Header for class mmdeploy_Detector */ #ifndef _Included_mmdeploy_Detector -#define _Included_mmdeploy_Detector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Detector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Detector_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Detector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Detector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Detector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Detector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Detector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Detector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Detector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_Detector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Detector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Detector_apply(JNIEnv *, jobject, jlong, jobjectArray, - jintArray); + /* + * Class: mmdeploy_Detector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Detector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Detector_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp index 8dbec9285b..8160210ed5 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp @@ -6,19 +6,22 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Device_create(JNIEnv *env, jobject, jstring name, jint index) { - auto device_name = env->GetStringUTFChars(name, nullptr); - mmdeploy_device_t device{}; - auto ec = mmdeploy_device_create(device_name, (int)index, &device); - env->ReleaseStringUTFChars(name, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create device, code = {}", ec); - return -1; - } - return (jlong)device; +jlong Java_mmdeploy_Device_create(JNIEnv* env, jobject, jstring name, jint index) +{ + auto device_name = env->GetStringUTFChars(name, nullptr); + 
mmdeploy_device_t device{}; + auto ec = mmdeploy_device_create(device_name, (int)index, &device); + env->ReleaseStringUTFChars(name, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create device, code = {}", ec); + return -1; + } + return (jlong)device; } -void Java_mmdeploy_Device_destroy(JNIEnv *, jobject, jlong device_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Device_destroy"); - mmdeploy_device_destroy((mmdeploy_device_t)device_); +void Java_mmdeploy_Device_destroy(JNIEnv*, jobject, jlong device_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Device_destroy"); + mmdeploy_device_destroy((mmdeploy_device_t)device_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h index 7d7ee9dee7..e751d0f781 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h @@ -3,25 +3,26 @@ /* Header for class mmdeploy_Device */ #ifndef _Included_mmdeploy_Device -#define _Included_mmdeploy_Device -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Device - * Method: create - * Signature: (Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Device_create(JNIEnv *, jobject, jstring, jint); + #define _Included_mmdeploy_Device + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Device + * Method: create + * Signature: (Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Device_create(JNIEnv*, jobject, jstring, jint); -/* - * Class: mmdeploy_Device - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Device_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Device + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Device_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp index 2bbc9a6920..821b1e988e 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp @@ -6,19 +6,22 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Model_create(JNIEnv *env, jobject, jstring path) { - auto model_path = env->GetStringUTFChars(path, nullptr); - mmdeploy_model_t model{}; - auto ec = mmdeploy_model_create_by_path(model_path, &model); - env->ReleaseStringUTFChars(path, model_path); - if (ec) { - MMDEPLOY_ERROR("failed to create model, code = {}", ec); - return -1; - } - return (jlong)model; +jlong Java_mmdeploy_Model_create(JNIEnv* env, jobject, jstring path) +{ + auto model_path = env->GetStringUTFChars(path, nullptr); + mmdeploy_model_t model{}; + auto ec = mmdeploy_model_create_by_path(model_path, &model); + env->ReleaseStringUTFChars(path, model_path); + if (ec) + { + MMDEPLOY_ERROR("failed to create model, code = {}", ec); + return -1; + } + return (jlong)model; } -void Java_mmdeploy_Model_destroy(JNIEnv *, jobject, jlong model_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Model_destroy"); - mmdeploy_model_destroy((mmdeploy_model_t)model_); +void Java_mmdeploy_Model_destroy(JNIEnv*, jobject, jlong model_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Model_destroy"); + mmdeploy_model_destroy((mmdeploy_model_t)model_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h index 11e23a1a81..9fc714c259 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h +++ 
b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h @@ -3,25 +3,26 @@ /* Header for class mmdeploy_Model */ #ifndef _Included_mmdeploy_Model -#define _Included_mmdeploy_Model -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Model - * Method: create - * Signature: (Ljava/lang/String;)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Model_create(JNIEnv *, jobject, jstring); + #define _Included_mmdeploy_Model + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Model + * Method: create + * Signature: (Ljava/lang/String;)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Model_create(JNIEnv*, jobject, jstring); -/* - * Class: mmdeploy_Model - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Model_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Model + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Model_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp index 4956555a6e..aac54574a0 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_PoseDetector_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_pose_detector_t pose_estimator{}; - auto ec = mmdeploy_pose_detector_create_by_path(model_path, device_name, (int)device_id, - &pose_estimator); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create pose estimator, code = {}", ec); - return -1; - } - return (jlong)pose_estimator; +jlong Java_mmdeploy_PoseDetector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_pose_detector_t pose_estimator{}; + auto ec = mmdeploy_pose_detector_create_by_path(model_path, device_name, (int)device_id, &pose_estimator); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create pose estimator, code = {}", ec); + return -1; + } + return (jlong)pose_estimator; } -void Java_mmdeploy_PoseDetector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_PoseDetector_destroy"); - mmdeploy_pose_detector_destroy((mmdeploy_pose_detector_t)handle); +void Java_mmdeploy_PoseDetector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_PoseDetector_destroy"); + mmdeploy_pose_detector_destroy((mmdeploy_pose_detector_t)handle); } -jobjectArray Java_mmdeploy_PoseDetector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_PoseDetector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_pose_detection_t *results{}; auto ec = 
mmdeploy_pose_detector_apply((mmdeploy_pose_detector_t)handle, imgs, size, &results); if (ec) { @@ -55,6 +57,5 @@ jobjectArray Java_mmdeploy_PoseDetector_apply(JNIEnv *env, jobject thiz, jlong h env->SetObjectArrayElement(array, i, res); } mmdeploy_pose_detector_release_result(results, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h index a50b7fd821..87c70ac0a6 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h @@ -3,34 +3,33 @@ /* Header for class mmdeploy_PoseDetector */ #ifndef _Included_mmdeploy_PoseDetector -#define _Included_mmdeploy_PoseDetector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_PoseDetector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_PoseDetector_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_PoseDetector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_PoseDetector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_PoseDetector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_PoseDetector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_PoseDetector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_PoseDetector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_PoseDetector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_PoseDetector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/PoseDetector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseDetector_apply(JNIEnv *, jobject, jlong, - jobjectArray); + /* + * Class: mmdeploy_PoseDetector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/PoseDetector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseDetector_apply(JNIEnv*, jobject, jlong, jobjectArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp index c0d1685729..61fd42eb07 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp @@ -6,143 +6,161 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_PoseTracker_create(JNIEnv *env, jobject, jlong detModel, jlong poseModel, - jlong context) { - mmdeploy_pose_tracker_t pose_tracker{}; - auto ec = mmdeploy_pose_tracker_create((mmdeploy_model_t)detModel, (mmdeploy_model_t)poseModel, - (mmdeploy_context_t)context, &pose_tracker); - if (ec) { - MMDEPLOY_ERROR("failed to create pose tracker, code = {}", ec); - return -1; - } - return (jlong)pose_tracker; +jlong Java_mmdeploy_PoseTracker_create(JNIEnv* env, jobject, jlong detModel, jlong poseModel, jlong context) +{ + mmdeploy_pose_tracker_t pose_tracker{}; + auto ec = mmdeploy_pose_tracker_create((mmdeploy_model_t)detModel, (mmdeploy_model_t)poseModel, (mmdeploy_context_t)context, &pose_tracker); + if (ec) + { + MMDEPLOY_ERROR("failed to create pose tracker, code = {}", ec); + return -1; + } + return (jlong)pose_tracker; } -void Java_mmdeploy_PoseTracker_destroy(JNIEnv *, jobject, jlong handle) { - 
MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); - mmdeploy_pose_tracker_destroy((mmdeploy_pose_tracker_t)handle); +void Java_mmdeploy_PoseTracker_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); + mmdeploy_pose_tracker_destroy((mmdeploy_pose_tracker_t)handle); } -jobject param_cpp_to_java(JNIEnv *env, mmdeploy_pose_tracker_param_t *params) { - auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); - auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); +jobject param_cpp_to_java(JNIEnv* env, mmdeploy_pose_tracker_param_t* params) +{ + auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); + auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); - jfloatArray keypointSigmas = env->NewFloatArray(params->keypoint_sigmas_size); - env->SetFloatArrayRegion(keypointSigmas, 0, params->keypoint_sigmas_size, - (jfloat *)params->keypoint_sigmas); - jfloatArray smoothParams = env->NewFloatArray(3); - env->SetFloatArrayRegion(smoothParams, 0, 3, (jfloat *)params->smooth_params); + jfloatArray keypointSigmas = env->NewFloatArray(params->keypoint_sigmas_size); + env->SetFloatArrayRegion(keypointSigmas, 0, params->keypoint_sigmas_size, (jfloat*)params->keypoint_sigmas); + jfloatArray smoothParams = env->NewFloatArray(3); + env->SetFloatArrayRegion(smoothParams, 0, 3, (jfloat*)params->smooth_params); - auto param = env->NewObject( - param_cls, param_ctor, (jint)params->det_interval, (jint)params->det_label, - (jfloat)params->det_thr, (jfloat)params->det_min_bbox_size, (jfloat)params->det_nms_thr, - (jint)params->pose_max_num_bboxes, (jfloat)params->pose_kpt_thr, - (jint)params->pose_min_keypoints, (jfloat)params->pose_bbox_scale, - (jfloat)params->pose_min_bbox_size, (jfloat)params->pose_nms_thr, keypointSigmas, - (jint)params->keypoint_sigmas_size, (jfloat)params->track_iou_thr, - (jint)params->track_max_missing, (jint)params->track_history_size, - (jfloat)params->std_weight_position, (jfloat)params->std_weight_velocity, smoothParams); - return param; + auto param = env->NewObject( + param_cls, + param_ctor, + (jint)params->det_interval, + (jint)params->det_label, + (jfloat)params->det_thr, + (jfloat)params->det_min_bbox_size, + (jfloat)params->det_nms_thr, + (jint)params->pose_max_num_bboxes, + (jfloat)params->pose_kpt_thr, + (jint)params->pose_min_keypoints, + (jfloat)params->pose_bbox_scale, + (jfloat)params->pose_min_bbox_size, + (jfloat)params->pose_nms_thr, + keypointSigmas, + (jint)params->keypoint_sigmas_size, + (jfloat)params->track_iou_thr, + (jint)params->track_max_missing, + (jint)params->track_history_size, + (jfloat)params->std_weight_position, + (jfloat)params->std_weight_velocity, + smoothParams); + return param; } -void param_java_to_cpp(JNIEnv *env, mmdeploy_pose_tracker_param_t *params, jobject customParam) { - auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); - auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); +void param_java_to_cpp(JNIEnv* env, mmdeploy_pose_tracker_param_t* params, jobject customParam) +{ + auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); + auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); - jfieldID fieldID_detInterval = env->GetFieldID(param_cls, "detInterval", "I"); - jint detInterval = env->GetIntField(customParam, fieldID_detInterval); - params->det_interval = (int)detInterval; - jfieldID fieldID_detLabel = env->GetFieldID(param_cls, "detLabel", "I"); - 
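Every field that param_java_to_cpp copies, above and below this point, follows the same three steps: look up the jfieldID, read the typed field from the Java object, and assign it to the C struct. A pair of helpers like the following (hypothetical, shown only to make the repeated shape explicit; the diff itself keeps the long-hand form) would condense each two-line block:

#include <jni.h>

// hypothetical condensation of the GetFieldID / Get<Type>Field pattern
static jint GetIntMember(JNIEnv* env, jclass cls, jobject obj, const char* name)
{
    return env->GetIntField(obj, env->GetFieldID(cls, name, "I"));
}

static jfloat GetFloatMember(JNIEnv* env, jclass cls, jobject obj, const char* name)
{
    return env->GetFloatField(obj, env->GetFieldID(cls, name, "F"));
}

// e.g. params->det_interval = GetIntMember(env, param_cls, customParam, "detInterval");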
jint detLabel = env->GetIntField(customParam, fieldID_detLabel); - params->det_label = (int)detLabel; - jfieldID fieldID_detThr = env->GetFieldID(param_cls, "detThr", "F"); - jfloat detThr = env->GetFloatField(customParam, fieldID_detThr); - params->det_thr = (float)detThr; - jfieldID fieldID_detMinBboxSize = env->GetFieldID(param_cls, "detMinBboxSize", "F"); - jfloat detMinBboxSize = env->GetFloatField(customParam, fieldID_detMinBboxSize); - params->det_min_bbox_size = (float)detMinBboxSize; - jfieldID fieldID_detNmsThr = env->GetFieldID(param_cls, "detNmsThr", "F"); - jfloat detNmsThr = env->GetFloatField(customParam, fieldID_detNmsThr); - params->det_nms_thr = (float)detNmsThr; - jfieldID fieldID_poseMaxNumBboxes = env->GetFieldID(param_cls, "poseMaxNumBboxes", "I"); - jint poseMaxNumBboxes = env->GetIntField(customParam, fieldID_poseMaxNumBboxes); - params->pose_max_num_bboxes = (int)poseMaxNumBboxes; - jfieldID fieldID_poseKptThr = env->GetFieldID(param_cls, "poseKptThr", "F"); - jfloat poseKptThr = env->GetFloatField(customParam, fieldID_poseKptThr); - params->pose_kpt_thr = (float)poseKptThr; - jfieldID fieldID_poseMinKeypoints = env->GetFieldID(param_cls, "poseMinKeypoints", "I"); - jint poseMinKeypoints = env->GetIntField(customParam, fieldID_poseMinKeypoints); - params->pose_min_keypoints = (int)poseMinKeypoints; - jfieldID fieldID_poseBboxScale = env->GetFieldID(param_cls, "poseBboxScale", "F"); - jfloat poseBboxScale = env->GetFloatField(customParam, fieldID_poseBboxScale); - params->pose_bbox_scale = (float)poseBboxScale; - jfieldID fieldID_poseMinBboxSize = env->GetFieldID(param_cls, "poseMinBboxSize", "F"); - jfloat poseMinBboxSize = env->GetFloatField(customParam, fieldID_poseMinBboxSize); - params->pose_min_bbox_size = (float)poseMinBboxSize; - jfieldID fieldID_poseNmsThr = env->GetFieldID(param_cls, "poseNmsThr", "F"); - jfloat poseNmsThr = env->GetFloatField(customParam, fieldID_poseNmsThr); - params->pose_nms_thr = (float)poseNmsThr; - jfieldID fieldID_keypointSigmas = env->GetFieldID(param_cls, "keypointSigmas", "[F"); - auto keypointSigmasObj = env->GetObjectField(customParam, fieldID_keypointSigmas); - float *keypointSigmas = - (float *)env->GetFloatArrayElements((jfloatArray)keypointSigmasObj, nullptr); - params->keypoint_sigmas = keypointSigmas; - env->ReleaseFloatArrayElements((jfloatArray)keypointSigmasObj, keypointSigmas, JNI_ABORT); - jfieldID fieldID_keypointSigmasSize = env->GetFieldID(param_cls, "keypointSigmasSize", "I"); - jint keypointSigmasSize = env->GetIntField(customParam, fieldID_keypointSigmasSize); - params->keypoint_sigmas_size = keypointSigmasSize; - jfieldID fieldID_trackIouThr = env->GetFieldID(param_cls, "trackIouThr", "F"); - jfloat trackIouThr = env->GetFloatField(customParam, fieldID_trackIouThr); - params->track_iou_thr = trackIouThr; - jfieldID fieldID_trackMaxMissing = env->GetFieldID(param_cls, "trackMaxMissing", "I"); - jint trackMaxMissing = env->GetIntField(customParam, fieldID_trackMaxMissing); - params->track_max_missing = trackMaxMissing; - jfieldID fieldID_trackHistorySize = env->GetFieldID(param_cls, "trackHistorySize", "I"); - jint trackHistorySize = env->GetIntField(customParam, fieldID_trackHistorySize); - params->track_history_size = trackHistorySize; - jfieldID fieldID_stdWeightPosition = env->GetFieldID(param_cls, "stdWeightPosition", "F"); - jfloat stdWeightPosition = env->GetFloatField(customParam, fieldID_stdWeightPosition); - params->std_weight_position = stdWeightPosition; - jfieldID fieldID_stdWeightVelocity = 
env->GetFieldID(param_cls, "stdWeightVelocity", "F"); - jfloat stdWeightVelocity = env->GetFloatField(customParam, fieldID_stdWeightVelocity); - params->std_weight_velocity = stdWeightVelocity; - jfieldID fieldID_smoothParams = env->GetFieldID(param_cls, "smoothParams", "[F"); - auto smoothParamsObj = env->GetObjectField(customParam, fieldID_smoothParams); - float *smoothParams = (float *)env->GetFloatArrayElements((jfloatArray)smoothParamsObj, nullptr); - params->smooth_params[0] = smoothParams[0]; - params->smooth_params[1] = smoothParams[1]; - params->smooth_params[2] = smoothParams[2]; - env->ReleaseFloatArrayElements((jfloatArray)smoothParamsObj, smoothParams, JNI_ABORT); + jfieldID fieldID_detInterval = env->GetFieldID(param_cls, "detInterval", "I"); + jint detInterval = env->GetIntField(customParam, fieldID_detInterval); + params->det_interval = (int)detInterval; + jfieldID fieldID_detLabel = env->GetFieldID(param_cls, "detLabel", "I"); + jint detLabel = env->GetIntField(customParam, fieldID_detLabel); + params->det_label = (int)detLabel; + jfieldID fieldID_detThr = env->GetFieldID(param_cls, "detThr", "F"); + jfloat detThr = env->GetFloatField(customParam, fieldID_detThr); + params->det_thr = (float)detThr; + jfieldID fieldID_detMinBboxSize = env->GetFieldID(param_cls, "detMinBboxSize", "F"); + jfloat detMinBboxSize = env->GetFloatField(customParam, fieldID_detMinBboxSize); + params->det_min_bbox_size = (float)detMinBboxSize; + jfieldID fieldID_detNmsThr = env->GetFieldID(param_cls, "detNmsThr", "F"); + jfloat detNmsThr = env->GetFloatField(customParam, fieldID_detNmsThr); + params->det_nms_thr = (float)detNmsThr; + jfieldID fieldID_poseMaxNumBboxes = env->GetFieldID(param_cls, "poseMaxNumBboxes", "I"); + jint poseMaxNumBboxes = env->GetIntField(customParam, fieldID_poseMaxNumBboxes); + params->pose_max_num_bboxes = (int)poseMaxNumBboxes; + jfieldID fieldID_poseKptThr = env->GetFieldID(param_cls, "poseKptThr", "F"); + jfloat poseKptThr = env->GetFloatField(customParam, fieldID_poseKptThr); + params->pose_kpt_thr = (float)poseKptThr; + jfieldID fieldID_poseMinKeypoints = env->GetFieldID(param_cls, "poseMinKeypoints", "I"); + jint poseMinKeypoints = env->GetIntField(customParam, fieldID_poseMinKeypoints); + params->pose_min_keypoints = (int)poseMinKeypoints; + jfieldID fieldID_poseBboxScale = env->GetFieldID(param_cls, "poseBboxScale", "F"); + jfloat poseBboxScale = env->GetFloatField(customParam, fieldID_poseBboxScale); + params->pose_bbox_scale = (float)poseBboxScale; + jfieldID fieldID_poseMinBboxSize = env->GetFieldID(param_cls, "poseMinBboxSize", "F"); + jfloat poseMinBboxSize = env->GetFloatField(customParam, fieldID_poseMinBboxSize); + params->pose_min_bbox_size = (float)poseMinBboxSize; + jfieldID fieldID_poseNmsThr = env->GetFieldID(param_cls, "poseNmsThr", "F"); + jfloat poseNmsThr = env->GetFloatField(customParam, fieldID_poseNmsThr); + params->pose_nms_thr = (float)poseNmsThr; + jfieldID fieldID_keypointSigmas = env->GetFieldID(param_cls, "keypointSigmas", "[F"); + auto keypointSigmasObj = env->GetObjectField(customParam, fieldID_keypointSigmas); + float* keypointSigmas = + (float*)env->GetFloatArrayElements((jfloatArray)keypointSigmasObj, nullptr); + params->keypoint_sigmas = keypointSigmas; + env->ReleaseFloatArrayElements((jfloatArray)keypointSigmasObj, keypointSigmas, JNI_ABORT); + jfieldID fieldID_keypointSigmasSize = env->GetFieldID(param_cls, "keypointSigmasSize", "I"); + jint keypointSigmasSize = env->GetIntField(customParam, fieldID_keypointSigmasSize); + 
params->keypoint_sigmas_size = keypointSigmasSize; + jfieldID fieldID_trackIouThr = env->GetFieldID(param_cls, "trackIouThr", "F"); + jfloat trackIouThr = env->GetFloatField(customParam, fieldID_trackIouThr); + params->track_iou_thr = trackIouThr; + jfieldID fieldID_trackMaxMissing = env->GetFieldID(param_cls, "trackMaxMissing", "I"); + jint trackMaxMissing = env->GetIntField(customParam, fieldID_trackMaxMissing); + params->track_max_missing = trackMaxMissing; + jfieldID fieldID_trackHistorySize = env->GetFieldID(param_cls, "trackHistorySize", "I"); + jint trackHistorySize = env->GetIntField(customParam, fieldID_trackHistorySize); + params->track_history_size = trackHistorySize; + jfieldID fieldID_stdWeightPosition = env->GetFieldID(param_cls, "stdWeightPosition", "F"); + jfloat stdWeightPosition = env->GetFloatField(customParam, fieldID_stdWeightPosition); + params->std_weight_position = stdWeightPosition; + jfieldID fieldID_stdWeightVelocity = env->GetFieldID(param_cls, "stdWeightVelocity", "F"); + jfloat stdWeightVelocity = env->GetFloatField(customParam, fieldID_stdWeightVelocity); + params->std_weight_velocity = stdWeightVelocity; + jfieldID fieldID_smoothParams = env->GetFieldID(param_cls, "smoothParams", "[F"); + auto smoothParamsObj = env->GetObjectField(customParam, fieldID_smoothParams); + float* smoothParams = (float*)env->GetFloatArrayElements((jfloatArray)smoothParamsObj, nullptr); + params->smooth_params[0] = smoothParams[0]; + params->smooth_params[1] = smoothParams[1]; + params->smooth_params[2] = smoothParams[2]; + env->ReleaseFloatArrayElements((jfloatArray)smoothParamsObj, smoothParams, JNI_ABORT); } -jobject Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv *env, jobject) { - mmdeploy_pose_tracker_param_t params{}; - mmdeploy_pose_tracker_default_params(&params); - return param_cpp_to_java(env, &params); +jobject Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv* env, jobject) +{ + mmdeploy_pose_tracker_param_t params{}; + mmdeploy_pose_tracker_default_params(&params); + return param_cpp_to_java(env, &params); } -jlong Java_mmdeploy_PoseTracker_createState(JNIEnv *env, jobject, jlong pipeline, - jobject paramsObject) { - mmdeploy_pose_tracker_state_t state{}; - mmdeploy_pose_tracker_param_t params{}; - param_java_to_cpp(env, &params, paramsObject); - auto ec = mmdeploy_pose_tracker_create_state((mmdeploy_pose_tracker_t)pipeline, &params, &state); - if (ec) { - MMDEPLOY_ERROR("failed to create pose tracker state, code = {}", ec); - return -1; - } - return (jlong)state; +jlong Java_mmdeploy_PoseTracker_createState(JNIEnv* env, jobject, jlong pipeline, jobject paramsObject) +{ + mmdeploy_pose_tracker_state_t state{}; + mmdeploy_pose_tracker_param_t params{}; + param_java_to_cpp(env, &params, paramsObject); + auto ec = mmdeploy_pose_tracker_create_state((mmdeploy_pose_tracker_t)pipeline, &params, &state); + if (ec) + { + MMDEPLOY_ERROR("failed to create pose tracker state, code = {}", ec); + return -1; + } + return (jlong)state; } -void Java_mmdeploy_PoseTracker_destroyState(JNIEnv *, jobject, jlong state) { - MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); - mmdeploy_pose_tracker_destroy_state((mmdeploy_pose_tracker_state_t)state); +void Java_mmdeploy_PoseTracker_destroyState(JNIEnv*, jobject, jlong state) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); + mmdeploy_pose_tracker_destroy_state((mmdeploy_pose_tracker_state_t)state); }
-jobjectArray Java_mmdeploy_PoseTracker_apply(JNIEnv *env, jobject thiz, jlong handle, - jlongArray states, jobjectArray frames, - jintArray detects, jintArray counts) { - return With(env, frames, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_PoseTracker_apply(JNIEnv* env, jobject thiz, jlong handle, jlongArray states, jobjectArray frames, jintArray detects, jintArray counts) +{ + return With(env, frames, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_pose_tracker_target_t *results{}; int *result_count{}; auto states_array = env->GetLongArrayElements(states, nullptr); @@ -189,6 +207,5 @@ jobjectArray Java_mmdeploy_PoseTracker_apply(JNIEnv *env, jobject thiz, jlong ha env->ReleaseLongArrayElements(states, states_array, 0); env->ReleaseIntArrayElements(detects, detects_array, 0); mmdeploy_pose_tracker_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h index 8e8d3905c8..1de79b1eaa 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h @@ -3,54 +3,54 @@ /* Header for class mmdeploy_PoseTracker */ #ifndef _Included_mmdeploy_PoseTracker -#define _Included_mmdeploy_PoseTracker -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_PoseTracker - * Method: create - * Signature: (JJJ)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_create(JNIEnv *, jobject, jlong, jlong, jlong); + #define _Included_mmdeploy_PoseTracker + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_PoseTracker + * Method: create + * Signature: (JJJ)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_create(JNIEnv*, jobject, jlong, jlong, jlong); -/* - * Class: mmdeploy_PoseTracker - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_PoseTracker + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_PoseTracker - * Method: createState - * Signature: (JLmmdeploy/PoseTracker/Params;)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_createState(JNIEnv *, jobject, jlong, jobject); + /* + * Class: mmdeploy_PoseTracker + * Method: createState + * Signature: (JLmmdeploy/PoseTracker/Params;)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_createState(JNIEnv*, jobject, jlong, jobject); -/* - * Class: mmdeploy_PoseTracker - * Method: destroyState - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroyState(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_PoseTracker + * Method: destroyState + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroyState(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_PoseTracker - * Method: setDefaultParams - * Signature: ()Lmmdeploy/PoseTracker/Params; - */ -JNIEXPORT jobject JNICALL Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv *, jobject); + /* + * Class: mmdeploy_PoseTracker + * Method: setDefaultParams + * Signature: ()Lmmdeploy/PoseTracker/Params; + */ + JNIEXPORT jobject JNICALL Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv*, jobject); -/* - * Class: mmdeploy_PoseTracker - * Method: apply - * Signature: (J[J[Lmmdeploy/Mat;[I[I)[Lmmdeploy/PoseTracker/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseTracker_apply(JNIEnv *, jobject, jlong, jlongArray, - jobjectArray, jintArray, jintArray); + /* + * Class: mmdeploy_PoseTracker
+ * Method: apply + * Signature: (J[J[Lmmdeploy/Mat;[I[I)[Lmmdeploy/PoseTracker/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseTracker_apply(JNIEnv*, jobject, jlong, jlongArray, jobjectArray, jintArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp index 2c63233c5c..2ff419ec7a 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp @@ -6,19 +6,22 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Profiler_create(JNIEnv *env, jobject, jstring path) { - auto profiler_path = env->GetStringUTFChars(path, nullptr); - mmdeploy_profiler_t profiler{}; - auto ec = mmdeploy_profiler_create(profiler_path, &profiler); - env->ReleaseStringUTFChars(path, profiler_path); - if (ec) { - MMDEPLOY_ERROR("failed to create profiler, code = {}", ec); - return -1; - } - return (jlong)profiler; +jlong Java_mmdeploy_Profiler_create(JNIEnv* env, jobject, jstring path) +{ + auto profiler_path = env->GetStringUTFChars(path, nullptr); + mmdeploy_profiler_t profiler{}; + auto ec = mmdeploy_profiler_create(profiler_path, &profiler); + env->ReleaseStringUTFChars(path, profiler_path); + if (ec) + { + MMDEPLOY_ERROR("failed to create profiler, code = {}", ec); + return -1; + } + return (jlong)profiler; } -void Java_mmdeploy_Profiler_destroy(JNIEnv *, jobject, jlong profiler_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Profiler_destroy"); - mmdeploy_profiler_destroy((mmdeploy_profiler_t)profiler_); +void Java_mmdeploy_Profiler_destroy(JNIEnv*, jobject, jlong profiler_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Profiler_destroy"); + mmdeploy_profiler_destroy((mmdeploy_profiler_t)profiler_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h index 2bcdbc42cc..9e829ad38c 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h @@ -3,25 +3,26 @@ /* Header for class mmdeploy_Profiler */ #ifndef _Included_mmdeploy_Profiler -#define _Included_mmdeploy_Profiler -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Profiler - * Method: create - * Signature: (Ljava/lang/String;)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Profiler_create(JNIEnv *, jobject, jstring); + #define _Included_mmdeploy_Profiler + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Profiler + * Method: create + * Signature: (Ljava/lang/String;)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Profiler_create(JNIEnv*, jobject, jstring); -/* - * Class: mmdeploy_Profiler - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Profiler_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Profiler + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Profiler_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp index f124d5edae..abc630afa6 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp @@ -6,29 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Restorer_create(JNIEnv *env, 
jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_restorer_t restorer{}; - auto ec = mmdeploy_restorer_create_by_path(model_path, device_name, (int)device_id, &restorer); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create restorer, code = {}", ec); - return -1; - } - return (jlong)restorer; +jlong Java_mmdeploy_Restorer_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_restorer_t restorer{}; + auto ec = mmdeploy_restorer_create_by_path(model_path, device_name, (int)device_id, &restorer); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create restorer, code = {}", ec); + return -1; + } + return (jlong)restorer; } -void Java_mmdeploy_Restorer_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Restorer_destroy"); - mmdeploy_restorer_destroy((mmdeploy_restorer_t)handle); +void Java_mmdeploy_Restorer_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Restorer_destroy"); + mmdeploy_restorer_destroy((mmdeploy_restorer_t)handle); } -jobjectArray Java_mmdeploy_Restorer_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Restorer_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_mat_t *results{}; auto ec = mmdeploy_restorer_apply((mmdeploy_restorer_t)handle, imgs, size, &results); if (ec) { @@ -68,6 +71,5 @@ jobjectArray Java_mmdeploy_Restorer_apply(JNIEnv *env, jobject thiz, jlong handl current_result++; } mmdeploy_restorer_release_result(results, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h index 78b09787fe..7a4aec079b 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h @@ -3,32 +3,33 @@ /* Header for class mmdeploy_Restorer */ #ifndef _Included_mmdeploy_Restorer -#define _Included_mmdeploy_Restorer -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Restorer - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Restorer_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Restorer + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Restorer + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Restorer_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Restorer - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Restorer_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Restorer + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Restorer_destroy(JNIEnv*, jobject, jlong); -/* - * Class: 
mmdeploy_Restorer - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Restorer/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Restorer_apply(JNIEnv *, jobject, jlong, jobjectArray); + /* + * Class: mmdeploy_Restorer + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Restorer/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Restorer_apply(JNIEnv*, jobject, jlong, jobjectArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp index 3872e7e158..9b34659aa5 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_RotatedDetector_create(JNIEnv *env, jobject, jstring modelPath, - jstring deviceName, jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_rotated_detector_t rotated_detector{}; - auto ec = mmdeploy_rotated_detector_create_by_path(model_path, device_name, (int)device_id, - &rotated_detector); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create rotated detector, code = {}", ec); - return -1; - } - return (jlong)rotated_detector; +jlong Java_mmdeploy_RotatedDetector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_rotated_detector_t rotated_detector{}; + auto ec = mmdeploy_rotated_detector_create_by_path(model_path, device_name, (int)device_id, &rotated_detector); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create rotated detector, code = {}", ec); + return -1; + } + return (jlong)rotated_detector; } -void Java_mmdeploy_RotatedDetector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_RotatedDetector_destroy"); - mmdeploy_rotated_detector_destroy((mmdeploy_rotated_detector_t)handle); +void Java_mmdeploy_RotatedDetector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_RotatedDetector_destroy"); + mmdeploy_rotated_detector_destroy((mmdeploy_rotated_detector_t)handle); } -jobjectArray Java_mmdeploy_RotatedDetector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_RotatedDetector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_rotated_detection_t *results{}; int *result_count{}; auto ec = mmdeploy_rotated_detector_apply((mmdeploy_rotated_detector_t)handle, imgs, size, @@ -56,6 +58,5 @@ jobjectArray Java_mmdeploy_RotatedDetector_apply(JNIEnv *env, jobject thiz, jlon } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_rotated_detector_release_result(results, result_count); - return array; - }); + return array; }); } diff --git 
a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h index 6de527ec40..7327b791ea 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h @@ -3,34 +3,33 @@ /* Header for class mmdeploy_RotatedDetector */ #ifndef _Included_mmdeploy_RotatedDetector -#define _Included_mmdeploy_RotatedDetector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_RotatedDetector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_RotatedDetector_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_RotatedDetector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_RotatedDetector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_RotatedDetector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_RotatedDetector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_RotatedDetector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_RotatedDetector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_RotatedDetector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_RotatedDetector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/RotatedDetector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_RotatedDetector_apply(JNIEnv *, jobject, jlong, - jobjectArray, jintArray); + /* + * Class: mmdeploy_RotatedDetector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/RotatedDetector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_RotatedDetector_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp index 2c1f1c42c0..3ab391c44d 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp @@ -7,17 +7,20 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Scheduler_createThreadPool(JNIEnv *env, jobject, jint numThreads) { - mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread_pool((int)numThreads); - return (jlong)scheduler; +jlong Java_mmdeploy_Scheduler_createThreadPool(JNIEnv* env, jobject, jint numThreads) +{ + mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread_pool((int)numThreads); + return (jlong)scheduler; } -jlong Java_mmdeploy_Scheduler_createThread(JNIEnv *env, jobject) { - mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread(); - return (jlong)scheduler; +jlong Java_mmdeploy_Scheduler_createThread(JNIEnv* env, jobject) +{ + mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread(); + return (jlong)scheduler; } -void Java_mmdeploy_Scheduler_destroy(JNIEnv *, jobject, jlong scheduler_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Scheduler_destroy"); - mmdeploy_scheduler_destroy((mmdeploy_scheduler_t)scheduler_); +void Java_mmdeploy_Scheduler_destroy(JNIEnv*, jobject, jlong scheduler_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Scheduler_destroy"); + mmdeploy_scheduler_destroy((mmdeploy_scheduler_t)scheduler_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h 
b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h index 363015cf95..8774db0fc7 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h @@ -3,32 +3,33 @@ /* Header for class mmdeploy_Scheduler */ #ifndef _Included_mmdeploy_Scheduler -#define _Included_mmdeploy_Scheduler -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Scheduler - * Method: createThreadPool - * Signature: (I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThreadPool(JNIEnv *, jclass, jint); + #define _Included_mmdeploy_Scheduler + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Scheduler + * Method: createThreadPool + * Signature: (I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThreadPool(JNIEnv*, jclass, jint); -/* - * Class: mmdeploy_Scheduler - * Method: createThread - * Signature: ()J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThread(JNIEnv *, jclass); + /* + * Class: mmdeploy_Scheduler + * Method: createThread + * Signature: ()J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThread(JNIEnv*, jclass); -/* - * Class: mmdeploy_Scheduler - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Scheduler_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Scheduler + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Scheduler_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp index 12df31a49e..8942041c8c 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp @@ -6,29 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Segmentor_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_segmentor_t segmentor{}; - auto ec = mmdeploy_segmentor_create_by_path(model_path, device_name, (int)device_id, &segmentor); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create segmentor, code = {}", ec); - return -1; - } - return (jlong)segmentor; +jlong Java_mmdeploy_Segmentor_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_segmentor_t segmentor{}; + auto ec = mmdeploy_segmentor_create_by_path(model_path, device_name, (int)device_id, &segmentor); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create segmentor, code = {}", ec); + return -1; + } + return (jlong)segmentor; } -void Java_mmdeploy_Segmentor_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Segmentor_destroy"); - mmdeploy_segmentor_destroy((mmdeploy_segmentor_t)handle); +void Java_mmdeploy_Segmentor_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Segmentor_destroy"); + mmdeploy_segmentor_destroy((mmdeploy_segmentor_t)handle); } -jobjectArray 
Java_mmdeploy_Segmentor_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Segmentor_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_segmentation_t *results{}; auto ec = mmdeploy_segmentor_apply((mmdeploy_segmentor_t)handle, imgs, size, &results); if (ec) { @@ -65,6 +68,5 @@ jobjectArray Java_mmdeploy_Segmentor_apply(JNIEnv *env, jobject thiz, jlong hand env->SetObjectArrayElement(array, i, res); } mmdeploy_segmentor_release_result(results, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h index afdf157bec..ec42c52dd5 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h @@ -3,33 +3,33 @@ /* Header for class mmdeploy_Segmentor */ #ifndef _Included_mmdeploy_Segmentor -#define _Included_mmdeploy_Segmentor -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Segmentor - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Segmentor_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Segmentor + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Segmentor + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Segmentor_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Segmentor - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Segmentor_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Segmentor + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Segmentor_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_Segmentor - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Segmentor/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Segmentor_apply(JNIEnv *, jobject, jlong, - jobjectArray); + /* + * Class: mmdeploy_Segmentor + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Segmentor/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Segmentor_apply(JNIEnv*, jobject, jlong, jobjectArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp index 943d1e625b..adc1abe5cd 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_TextDetector_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_text_detector_t text_detector{}; - auto ec = mmdeploy_text_detector_create_by_path(model_path, device_name, (int)device_id, - &text_detector); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create text_detector, code = {}", ec); - return -1; - } 
- return (jlong)text_detector; +jlong Java_mmdeploy_TextDetector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_text_detector_t text_detector{}; + auto ec = mmdeploy_text_detector_create_by_path(model_path, device_name, (int)device_id, &text_detector); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create text_detector, code = {}", ec); + return -1; + } + return (jlong)text_detector; } -void Java_mmdeploy_TextDetector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_TextDetector_destroy"); - mmdeploy_text_detector_destroy((mmdeploy_text_detector_t)handle); +void Java_mmdeploy_TextDetector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_TextDetector_destroy"); + mmdeploy_text_detector_destroy((mmdeploy_text_detector_t)handle); } -jobjectArray Java_mmdeploy_TextDetector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_TextDetector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_text_detection_t *results{}; int *result_count{}; auto ec = mmdeploy_text_detector_apply((mmdeploy_text_detector_t)handle, imgs, size, &results, @@ -61,6 +63,5 @@ jobjectArray Java_mmdeploy_TextDetector_apply(JNIEnv *env, jobject thiz, jlong h } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_text_detector_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h index dc5574f77b..6a5df47924 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h @@ -3,34 +3,33 @@ /* Header for class mmdeploy_TextDetector */ #ifndef _Included_mmdeploy_TextDetector -#define _Included_mmdeploy_TextDetector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_TextDetector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_TextDetector_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_TextDetector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_TextDetector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_TextDetector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_TextDetector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_TextDetector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_TextDetector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_TextDetector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_TextDetector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/TextDetector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextDetector_apply(JNIEnv *, jobject, jlong, - jobjectArray, jintArray); + /* + * Class: 
mmdeploy_TextDetector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/TextDetector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextDetector_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp index 06987fb623..607b7c2ee8 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_TextRecognizer_create(JNIEnv *env, jobject, jstring modelPath, - jstring deviceName, jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_text_recognizer_t text_recognizer{}; - auto ec = mmdeploy_text_recognizer_create_by_path(model_path, device_name, (int)device_id, - &text_recognizer); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create text recognizer, code = {}", ec); - return -1; - } - return (jlong)text_recognizer; +jlong Java_mmdeploy_TextRecognizer_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_text_recognizer_t text_recognizer{}; + auto ec = mmdeploy_text_recognizer_create_by_path(model_path, device_name, (int)device_id, &text_recognizer); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create text recognizer, code = {}", ec); + return -1; + } + return (jlong)text_recognizer; } -void Java_mmdeploy_TextRecognizer_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_TextRecognizer_destroy"); // maybe use info? - mmdeploy_text_recognizer_destroy((mmdeploy_text_recognizer_t)handle); +void Java_mmdeploy_TextRecognizer_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_TextRecognizer_destroy"); // maybe use info? 
+ mmdeploy_text_recognizer_destroy((mmdeploy_text_recognizer_t)handle); } -jobjectArray Java_mmdeploy_TextRecognizer_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_TextRecognizer_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_text_recognition_t *results{}; auto ec = mmdeploy_text_recognizer_apply((mmdeploy_text_recognizer_t)handle, imgs, size, &results); @@ -51,13 +53,12 @@ jobjectArray Java_mmdeploy_TextRecognizer_apply(JNIEnv *env, jobject thiz, jlong env->SetObjectArrayElement(array, i, res); } mmdeploy_text_recognizer_release_result(results, size); - return array; - }); + return array; }); } -jobjectArray Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jobjectArray bboxes, - jintArray bbox_count) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) { +jobjectArray Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jobjectArray bboxes, jintArray bbox_count) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) + { mmdeploy_text_recognition_t *recog_results{}; auto *det_results = new mmdeploy_text_detection_t[env->GetArrayLength(bboxes)]; int *det_result_count = new int[env->GetArrayLength(bbox_count)]; @@ -100,6 +101,5 @@ jobjectArray Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv *env, jobject thiz, j } mmdeploy_text_recognizer_release_result(recog_results, size); mmdeploy_text_detector_release_result(det_results, det_result_count, 1); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h index 721c17f2b6..13ed048b7e 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h @@ -3,43 +3,40 @@ /* Header for class mmdeploy_TextRecognizer */ #ifndef _Included_mmdeploy_TextRecognizer -#define _Included_mmdeploy_TextRecognizer -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_TextRecognizer - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_TextRecognizer_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_TextRecognizer + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_TextRecognizer + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_TextRecognizer_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_TextRecognizer - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_TextRecognizer_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_TextRecognizer + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_TextRecognizer_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_TextRecognizer - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/TextRecognizer/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_apply(JNIEnv *, jobject, jlong, - jobjectArray); + /* + * Class: mmdeploy_TextRecognizer + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/TextRecognizer/Result; + */ + 
JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_apply(JNIEnv*, jobject, jlong, jobjectArray); -/* - * Class: mmdeploy_TextRecognizer - * Method: applyBbox - * Signature: (J[Lmmdeploy/Mat;[Lmmdeploy/TextDetector/Result;[I)[Lmmdeploy/TextRecognizer/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv *, jobject, jlong, - jobjectArray, jobjectArray, - jintArray); + /* + * Class: mmdeploy_TextRecognizer + * Method: applyBbox + * Signature: (J[Lmmdeploy/Mat;[Lmmdeploy/TextDetector/Result;[I)[Lmmdeploy/TextRecognizer/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv*, jobject, jlong, jobjectArray, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/python/CMakeLists.txt b/csrc/mmdeploy/apis/python/CMakeLists.txt index 12e7946e31..173332d0f7 100644 --- a/csrc/mmdeploy/apis/python/CMakeLists.txt +++ b/csrc/mmdeploy/apis/python/CMakeLists.txt @@ -3,53 +3,48 @@ cmake_minimum_required(VERSION 3.14) project(mmdeploy_runtime) -set(MMDEPLOY_RUNTIME_SRCS - common.cpp - internal.cpp - pipeline.cpp) +set(MMDEPLOY_RUNTIME_SRCS common.cpp internal.cpp pipeline.cpp) set(CMAKE_CXX_STANDARD 17) -if (${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) - # standard alone project - add_subdirectory(${CMAKE_SOURCE_DIR}/../../../../third_party/pybind11 - ${CMAKE_CURRENT_BINARY_DIR}/pybind11) - find_package(MMDeploy REQUIRED) -elseif (NOT TARGET pybind11) - add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) -endif () +if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) + # standard alone project + add_subdirectory(${CMAKE_SOURCE_DIR}/../../../../third_party/pybind11 + ${CMAKE_CURRENT_BINARY_DIR}/pybind11) + find_package(MMDeploy REQUIRED) +elseif(NOT TARGET pybind11) + add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) +endif() -foreach (task_name ${MMDEPLOY_TASKS}) - list(APPEND MMDEPLOY_RUNTIME_SRCS ${task_name}.cpp) -endforeach () +foreach(task_name ${MMDEPLOY_TASKS}) + list(APPEND MMDEPLOY_RUNTIME_SRCS ${task_name}.cpp) +endforeach() pybind11_add_module(${PROJECT_NAME} ${MMDEPLOY_RUNTIME_SRCS}) # disable MMDEPLOY_CXX_USE_OPENCV in apis/cxx/mmdeploy/common.hpp target_compile_definitions(${PROJECT_NAME} PRIVATE -DMMDEPLOY_CXX_USE_OPENCV=0) -if (APPLE) - set_target_properties(${PROJECT_NAME} PROPERTIES - BUILD_RPATH "@loader_path" - INSTALL_RPATH "@loader_path") -else () - set_target_properties(${PROJECT_NAME} PROPERTIES - BUILD_RPATH "\$ORIGIN" - INSTALL_RPATH "\$ORIGIN") -endif () +if(APPLE) + set_target_properties(${PROJECT_NAME} PROPERTIES BUILD_RPATH "@loader_path" + INSTALL_RPATH "@loader_path") +else() + set_target_properties(${PROJECT_NAME} PROPERTIES BUILD_RPATH "\$ORIGIN" + INSTALL_RPATH "\$ORIGIN") +endif() # https://github.com/pybind/pybind11/issues/1604 -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") - target_compile_options(${PROJECT_NAME} PRIVATE -fsized-deallocation) -endif () - -if (MMDEPLOY_BUILD_SDK_MONOLITHIC) - target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy) -else () - mmdeploy_load_static(${PROJECT_NAME} MMDeployStaticModules) - mmdeploy_load_dynamic(${PROJECT_NAME} MMDeployDynamicModules) - target_link_libraries(${PROJECT_NAME} PRIVATE MMDeployLibs) -endif () - -target_include_directories(${PROJECT_NAME} PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/.. 
- ${CMAKE_CURRENT_SOURCE_DIR}) +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + target_compile_options(${PROJECT_NAME} PRIVATE -fsized-deallocation) +endif() + +if(MMDEPLOY_BUILD_SDK_MONOLITHIC) + target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy) +else() + mmdeploy_load_static(${PROJECT_NAME} MMDeployStaticModules) + mmdeploy_load_dynamic(${PROJECT_NAME} MMDeployDynamicModules) + target_link_libraries(${PROJECT_NAME} PRIVATE MMDeployLibs) +endif() + +target_include_directories( + ${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_SOURCE_DIR}) install(DIRECTORY ${CMAKE_SOURCE_DIR}/demo/python/ DESTINATION example/python) diff --git a/csrc/mmdeploy/apis/python/classifier.cpp b/csrc/mmdeploy/apis/python/classifier.cpp index 9916909c86..983b3357b5 100644 --- a/csrc/mmdeploy/apis/python/classifier.cpp +++ b/csrc/mmdeploy/apis/python/classifier.cpp @@ -4,64 +4,76 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyClassifier { - public: - PyClassifier(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_classifier_create_by_path(model_path, device_name, device_id, &classifier_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create classifier"); - } - } - ~PyClassifier() { - mmdeploy_classifier_destroy(classifier_); - classifier_ = {}; - } + class PyClassifier + { + public: + PyClassifier(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_classifier_create_by_path(model_path, device_name, device_id, &classifier_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create classifier"); + } + } + ~PyClassifier() + { + mmdeploy_classifier_destroy(classifier_); + classifier_ = {}; + } - std::vector<std::vector<std::pair<int, float>>> Apply(const std::vector<PyImage>& imgs) { - std::vector<mmdeploy_mat_t> mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_classification_t* results{}; - int* result_count{}; - auto status = mmdeploy_classifier_apply(classifier_, mats.data(), (int)mats.size(), &results, - &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply classifier, code: " + std::to_string(status)); - } - auto output = std::vector<std::vector<std::pair<int, float>>>{}; - output.reserve(mats.size()); - auto result_ptr = results; - for (int i = 0; i < mats.size(); ++i) { - std::vector<std::pair<int, float>> label_score; - for (int j = 0; j < result_count[i]; ++j) { - label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); - } - output.push_back(std::move(label_score)); - result_ptr += result_count[i]; - } - mmdeploy_classifier_release_result(results, result_count, (int)mats.size()); - return output; - } + std::vector<std::vector<std::pair<int, float>>> Apply(const std::vector<PyImage>& imgs) + { + std::vector<mmdeploy_mat_t> mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_classification_t* results{}; + int* result_count{}; + auto status = mmdeploy_classifier_apply(classifier_, mats.data(), (int)mats.size(), &results, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply classifier, code: " + std::to_string(status)); + } + auto output = std::vector<std::vector<std::pair<int, float>>>{}; + output.reserve(mats.size()); + auto result_ptr = results; + for (int i = 0; i < mats.size(); ++i) + { + std::vector<std::pair<int, float>> label_score; + for (int j = 0; j < result_count[i]; ++j) + { + label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); + }
+ output.push_back(std::move(label_score)); + result_ptr += result_count[i]; + } + mmdeploy_classifier_release_result(results, result_count, (int)mats.size()); + return output; + } - private: - mmdeploy_classifier_t classifier_{}; -}; + private: + mmdeploy_classifier_t classifier_{}; + }; -static PythonBindingRegisterer register_classifier{[](py::module& m) { - py::class_<PyClassifier>(m, "Classifier") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique<PyClassifier>(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyClassifier* self, const PyImage& img) { return self->Apply(std::vector<PyImage>{img})[0]; }) - .def("batch", &PyClassifier::Apply); -}}; + static PythonBindingRegisterer register_classifier{[](py::module& m) + { + py::class_<PyClassifier>(m, "Classifier") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique<PyClassifier>(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyClassifier* self, const PyImage& img) + { return self->Apply(std::vector<PyImage>{img})[0]; }) + .def("batch", &PyClassifier::Apply); + }}; } // namespace mmdeploy::python
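The registration idiom reformatted above recurs in every binding file: a file-static PythonBindingRegisterer (declared in common.h later in this patch) appends a callback to gPythonBindings(), and PYBIND11_MODULE replays all callbacks at import time. A hedged sketch of what one more binding would look like under the new style; register_example and example_version are illustrative names only and are not part of the patch:

// Hypothetical extra binding following the pattern above; assumes common.h is included.
static PythonBindingRegisterer register_example{[](py::module& m)
{
    m.def("example_version", []
          { return std::string{"0.0.0"}; });  // placeholder value, not a real mmdeploy API
}};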
are supported"); + } + mat.height = (int)info.shape[0]; + mat.width = (int)info.shape[1]; + mat.channel = channels; + mat.type = MMDEPLOY_DATA_TYPE_UINT8; + mat.data = (uint8_t*)info.ptr; + return mat; } - case ValueType::kObject: { - py::dict dict; - for (auto it = value.begin(); it != value.end(); ++it) { - dict[it.key().c_str()] = ToPyObject(*it); - } - return dict; + + py::object ToPyObject(const Value& value) + { + switch (value.type()) + { + case ValueType::kNull: + return py::none(); + case ValueType::kBool: + return py::bool_(value.get()); + case ValueType::kInt: + return py::int_(value.get()); + case ValueType::kUInt: + return py::int_(value.get()); + case ValueType::kFloat: + return py::float_(value.get()); + case ValueType::kString: + return py::str(value.get()); + case ValueType::kArray: + { + py::list list; + for (const auto& x : value) + { + list.append(ToPyObject(x)); + } + return list; + } + case ValueType::kObject: + { + py::dict dict; + for (auto it = value.begin(); it != value.end(); ++it) + { + dict[it.key().c_str()] = ToPyObject(*it); + } + return dict; + } + case ValueType::kAny: + return py::str(""); + default: + return py::str(""); + } } - case ValueType::kAny: - return py::str(""); - default: - return py::str(""); - } -} -std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type); + std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type); -Value FromPyObject(const py::object& obj) { - if (py::isinstance(obj)) { - return nullptr; - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj) || py::isinstance(obj)) { - py::list src(obj); - Value::Array dst; - dst.reserve(src.size()); - for (const auto& item : src) { - dst.push_back(FromPyObject(py::reinterpret_borrow(item))); + Value FromPyObject(const py::object& obj) + { + if (py::isinstance(obj)) + { + return nullptr; + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj) || py::isinstance(obj)) + { + py::list src(obj); + Value::Array dst; + dst.reserve(src.size()); + for (const auto& item : src) + { + dst.push_back(FromPyObject(py::reinterpret_borrow(item))); + } + return dst; + } + else if (py::isinstance(obj)) + { + py::dict src(obj); + Value::Object dst; + for (const auto& item : src) + { + dst.emplace(item.first.cast(), + FromPyObject(py::reinterpret_borrow(item.second))); + } + return dst; + } + else if (py::isinstance(obj)) + { + const auto& array = obj.cast(); + return *_to_value_internal(&array, MMDEPLOY_TYPE_MAT); + } + else if (py::isinstance(obj)) + { + const auto& model = + *reinterpret_cast(static_cast(obj.cast())); + return model; + } + else + { + std::stringstream ss; + ss << obj.get_type(); + MMDEPLOY_ERROR("unsupported Python object type: {}", ss.str()); + return nullptr; + } + return nullptr; } - return dst; - } else if (py::isinstance(obj)) { - py::dict src(obj); - Value::Object dst; - for (const auto& item : src) { - dst.emplace(item.first.cast(), - FromPyObject(py::reinterpret_borrow(item.second))); - } - return dst; - } else if (py::isinstance(obj)) { - const auto& array = obj.cast(); - return *_to_value_internal(&array, 
- return *_to_value_internal(&array, MMDEPLOY_TYPE_MAT); - } else if (py::isinstance<Model>(obj)) { - const auto& model = - *reinterpret_cast(static_cast(obj.cast())); - return model; - } else { - std::stringstream ss; - ss << obj.get_type(); - MMDEPLOY_ERROR("unsupported Python object type: {}", ss.str()); - return nullptr; - } - return nullptr; -} -std::pair<std::string, int> parse_device(const std::string& device) { - auto pos = device.find(':'); - if (pos == std::string::npos) { - return {device, 0}; // logic for index -1 is not ready on some devices - } - auto name = device.substr(0, pos); - auto index = std::stoi(device.substr(pos + 1)); - return {name, index}; -} + std::pair<std::string, int> parse_device(const std::string& device) + { + auto pos = device.find(':'); + if (pos == std::string::npos) + { + return {device, 0}; // logic for index -1 is not ready on some devices + } + auto name = device.substr(0, pos); + auto index = std::stoi(device.substr(pos + 1)); + return {name, index}; + } -static PythonBindingRegisterer register_model{[](py::module& m) { - py::class_<Model>(m, "Model") - .def(py::init([](const py::str& path) { + static PythonBindingRegisterer register_model{[](py::module& m) + { + py::class_<Model>(m, "Model") + .def(py::init([](const py::str& path) + { MMDEPLOY_DEBUG("py::init([](const py::str& path)"); - return Model(path.cast<std::string>().c_str()); - })) - .def(py::init([](const py::bytes& buffer) { + return Model(path.cast<std::string>().c_str()); })) + .def(py::init([](const py::bytes& buffer) + { MMDEPLOY_DEBUG("py::init([](const py::bytes& buffer)"); py::buffer_info info(py::buffer(buffer).request()); - return Model(info.ptr, info.size); - })); -}}; + return Model(info.ptr, info.size); })); + }}; -static PythonBindingRegisterer register_device{[](py::module& m) { - py::class_<Device>(m, "Device") - .def(py::init([](const std::string& device) { + static PythonBindingRegisterer register_device{[](py::module& m) + { + py::class_<Device>(m, "Device") + .def(py::init([](const std::string& device) + { auto [name, index] = parse_device(device); - return Device(name, index); - })) - .def(py::init([](const std::string& name, int index) { return Device(name, index); })); -}}; + return Device(name, index); })) + .def(py::init([](const std::string& name, int index) + { return Device(name, index); })); + }}; -static PythonBindingRegisterer register_context{[](py::module& m) { - py::class_<Context>(m, "Context") - .def(py::init([](const Device& device) { return Context(device); })) - .def("add", [](Context* self, const std::string& name, const Scheduler& sched) { - self->Add(name, sched); - }); -}}; + static PythonBindingRegisterer register_context{[](py::module& m) + { + py::class_<Context>(m, "Context") + .def(py::init([](const Device& device) + { return Context(device); })) + .def("add", [](Context* self, const std::string& name, const Scheduler& sched) + { self->Add(name, sched); }); + }}; -static PythonBindingRegisterer register_scheduler{[](py::module& m) { - py::class_<Scheduler>(m, "Scheduler") - .def_static("thread_pool", [](int n_workers) { return Scheduler::ThreadPool(n_workers); }) - .def_static("thread", [] { return Scheduler::Thread(); }); -}}; + static PythonBindingRegisterer register_scheduler{[](py::module& m) + { + py::class_<Scheduler>(m, "Scheduler") + .def_static("thread_pool", [](int n_workers) + { return Scheduler::ThreadPool(n_workers); }) + .def_static("thread", [] + { return Scheduler::Thread(); }); + }}; } // namespace mmdeploy::python -PYBIND11_MODULE(mmdeploy_runtime, m) { - for (const auto& f : mmdeploy::python::gPythonBindings()) { - f(m); - } +PYBIND11_MODULE(mmdeploy_runtime, m) +{ + for (const auto& f : mmdeploy::python::gPythonBindings()) + { + f(m); + } }
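Before the next file, a brief hedged sketch of how the converters and parse_device above compose, written as if inside namespace mmdeploy::python with a live interpreter (as there is inside any bound function); the variable names are illustrative and this snippet is not part of the patch:

// Round trip: Python dict -> mmdeploy Value -> Python object, per the
// FromPyObject/ToPyObject definitions above. Assumes the interpreter is up.
py::dict src;
src["device"] = py::str("cuda:0");
src["batch"] = py::int_(2);
Value v = FromPyObject(src);            // dict lands in ValueType::kObject
py::object round_trip = ToPyObject(v);  // and comes back as a py::dict
auto [name, index] = parse_device("cuda:0");  // yields {"cuda", 0} as parsed above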
mmdeploy::python::gPythonBindings()) + { + f(m); + } } diff --git a/csrc/mmdeploy/apis/python/common.h b/csrc/mmdeploy/apis/python/common.h index 5b1ca96b74..e50ed76007 100644 --- a/csrc/mmdeploy/apis/python/common.h +++ b/csrc/mmdeploy/apis/python/common.h @@ -13,24 +13,27 @@ namespace py = pybind11; -namespace mmdeploy::python { +namespace mmdeploy::python +{ -using PyImage = py::array_t; + using PyImage = py::array_t; -std::vector& gPythonBindings(); + std::vector& gPythonBindings(); -mmdeploy_mat_t GetMat(const PyImage& img); + mmdeploy_mat_t GetMat(const PyImage& img); -py::object ToPyObject(const Value& value); + py::object ToPyObject(const Value& value); -Value FromPyObject(const py::object& obj); + Value FromPyObject(const py::object& obj); -class PythonBindingRegisterer { - public: - explicit PythonBindingRegisterer(void (*register_fn)(py::module& m)) { - gPythonBindings().push_back(register_fn); - } -}; + class PythonBindingRegisterer + { + public: + explicit PythonBindingRegisterer(void (*register_fn)(py::module& m)) + { + gPythonBindings().push_back(register_fn); + } + }; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/detector.cpp b/csrc/mmdeploy/apis/python/detector.cpp index 057a92ab00..137998f6b7 100644 --- a/csrc/mmdeploy/apis/python/detector.cpp +++ b/csrc/mmdeploy/apis/python/detector.cpp @@ -4,82 +4,97 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyDetector { - public: - PyDetector(const char* model_path, const char* device_name, int device_id) { - auto status = mmdeploy_detector_create_by_path(model_path, device_name, device_id, &detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create detector"); - } - } - py::list Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_detection_t* detection{}; - int* result_count{}; - auto status = mmdeploy_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, - &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply detector, code: " + std::to_string(status)); - } - using Sptr = std::shared_ptr; - Sptr holder(detection, [result_count, n = mats.size()](auto p) { - mmdeploy_detector_release_result(p, result_count, n); - }); - auto output = py::list{}; - auto result = detection; - for (int i = 0; i < mats.size(); ++i) { - auto bboxes = py::array_t({result_count[i], 5}); - auto labels = py::array_t(result_count[i]); - auto masks = std::vector(); - masks.reserve(result_count[i]); - for (int j = 0; j < result_count[i]; ++j, ++result) { - auto bbox = bboxes.mutable_data(j); - bbox[0] = result->bbox.left; - bbox[1] = result->bbox.top; - bbox[2] = result->bbox.right; - bbox[3] = result->bbox.bottom; - bbox[4] = result->score; - labels.mutable_at(j) = result->label_id; - if (result->mask) { - masks.emplace_back(std::array{result->mask->height, result->mask->width}, // shape - reinterpret_cast(result->mask->data), // data - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); })); - } else { - masks.emplace_back(); + class PyDetector + { + public: + PyDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = mmdeploy_detector_create_by_path(model_path, device_name, device_id, &detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create detector"); + } + } + py::list 
Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_detection_t* detection{}; + int* result_count{}; + auto status = mmdeploy_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply detector, code: " + std::to_string(status)); + } + using Sptr = std::shared_ptr; + Sptr holder(detection, [result_count, n = mats.size()](auto p) + { mmdeploy_detector_release_result(p, result_count, n); }); + auto output = py::list{}; + auto result = detection; + for (int i = 0; i < mats.size(); ++i) + { + auto bboxes = py::array_t({result_count[i], 5}); + auto labels = py::array_t(result_count[i]); + auto masks = std::vector(); + masks.reserve(result_count[i]); + for (int j = 0; j < result_count[i]; ++j, ++result) + { + auto bbox = bboxes.mutable_data(j); + bbox[0] = result->bbox.left; + bbox[1] = result->bbox.top; + bbox[2] = result->bbox.right; + bbox[3] = result->bbox.bottom; + bbox[4] = result->score; + labels.mutable_at(j) = result->label_id; + if (result->mask) + { + masks.emplace_back(std::array{result->mask->height, result->mask->width}, // shape + reinterpret_cast(result->mask->data), // data + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); })); + } + else + { + masks.emplace_back(); + } + } + output.append(py::make_tuple(std::move(bboxes), std::move(labels), std::move(masks))); + } + return output; + } + ~PyDetector() + { + mmdeploy_detector_destroy(detector_); + detector_ = {}; } - } - output.append(py::make_tuple(std::move(bboxes), std::move(labels), std::move(masks))); - } - return output; - } - ~PyDetector() { - mmdeploy_detector_destroy(detector_); - detector_ = {}; - } - private: - mmdeploy_detector_t detector_{}; -}; + private: + mmdeploy_detector_t detector_{}; + }; -static PythonBindingRegisterer register_detector{[](py::module& m) { - py::class_(m, "Detector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyDetector* self, const PyImage& img) -> py::tuple { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyDetector::Apply); -}}; + static PythonBindingRegisterer register_detector{[](py::module& m) + { + py::class_(m, "Detector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyDetector* self, const PyImage& img) -> py::tuple + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyDetector::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/executor.cpp b/csrc/mmdeploy/apis/python/executor.cpp index eaa5c1144b..489985f232 100644 --- a/csrc/mmdeploy/apis/python/executor.cpp +++ b/csrc/mmdeploy/apis/python/executor.cpp @@ -8,39 +8,48 @@ #include "mmdeploy/execution/schedulers/single_thread_context.h" #include "mmdeploy/execution/schedulers/static_thread_pool.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -struct PySender { - TypeErasedSender sender_; - - explicit PySender(TypeErasedSender sender) : sender_(std::move(sender)) {} - 
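
A usage sketch for the Detector binding registered above (hypothetical model path); __call__ handles a single image, batch a list of images:

    import numpy as np
    from mmdeploy_runtime import Detector

    detector = Detector(model_path="mmdet_model_dir", device_name="cpu", device_id=0)
    img = np.zeros((480, 640, 3), dtype=np.uint8)  # stands in for a BGR frame
    bboxes, labels, masks = detector(img)          # bboxes: (N, 5) = [left, top, right, bottom, score]
    results = detector.batch([img, img])           # one (bboxes, labels, masks) tuple per image
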
- struct gil_guarded_deleter { - void operator()(py::object* p) const { - py::gil_scoped_acquire _; - delete p; - } - }; - using object_ptr = std::unique_ptr; - - py::object __await__() { - auto future = py::module::import("concurrent.futures").attr("Future")(); + struct PySender { - py::gil_scoped_release _; - StartDetached(std::move(sender_) | - Then([future = object_ptr{new py::object(future)}](const Value& value) mutable { + TypeErasedSender sender_; + + explicit PySender(TypeErasedSender sender) + : sender_(std::move(sender)) + { + } + + struct gil_guarded_deleter + { + void operator()(py::object* p) const + { + py::gil_scoped_acquire _; + delete p; + } + }; + using object_ptr = std::unique_ptr; + + py::object __await__() + { + auto future = py::module::import("concurrent.futures").attr("Future")(); + { + py::gil_scoped_release _; + StartDetached(std::move(sender_) | + Then([future = object_ptr{new py::object(future)}](const Value& value) mutable + { py::gil_scoped_acquire _; future->attr("set_result")(ToPyObject(value)); - delete future.release(); - })); - } - return py::module::import("asyncio").attr("wrap_future")(future).attr("__await__")(); - } -}; - -static PythonBindingRegisterer register_sender{[](py::module& m) { - py::class_>(m, "PySender") - .def("__await__", &PySender::__await__); -}}; + delete future.release(); })); + } + return py::module::import("asyncio").attr("wrap_future")(future).attr("__await__")(); + } + }; + + static PythonBindingRegisterer register_sender{[](py::module& m) + { + py::class_>(m, "PySender") + .def("__await__", &PySender::__await__); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/internal.cpp b/csrc/mmdeploy/apis/python/internal.cpp index 7373c1f184..8c38f5a7ce 100644 --- a/csrc/mmdeploy/apis/python/internal.cpp +++ b/csrc/mmdeploy/apis/python/internal.cpp @@ -9,49 +9,60 @@ #include "mmdeploy/core/model.h" #include "mmdeploy/core/value.h" -namespace mmdeploy { - -namespace python { - -framework::Mat _get_mat(const PyImage& img) { - auto info = img.request(); - if (info.ndim != 3) { - fprintf(stderr, "info.ndim = %d\n", (int)info.ndim); - throw std::runtime_error("continuous uint8 HWC array expected"); - } - auto channels = (int)info.shape[2]; - PixelFormat format; - if (channels == 1) { - format = PixelFormat::kGRAYSCALE; - } else if (channels == 3) { - format = PixelFormat::kBGR; - } else { - throw std::runtime_error("images of 1 or 3 channels are supported"); - } - - return { - (int)info.shape[0], // height - (int)info.shape[1], // width - format, // format - DataType::kINT8, // type - std::shared_ptr(info.ptr, [](void*) {}), // data - framework::Device(0), // device - }; -} - -std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type) { - switch (type) { - case MMDEPLOY_TYPE_MODEL: - return Value(*(const framework::Model*)object); - case MMDEPLOY_TYPE_DEVICE: - return Value(*(const framework::Device*)object); - case MMDEPLOY_TYPE_MAT: - return _get_mat(*(const py::array*)object); - default: - return std::nullopt; - } -} - -} // namespace python +namespace mmdeploy +{ + + namespace python + { + + framework::Mat _get_mat(const PyImage& img) + { + auto info = img.request(); + if (info.ndim != 3) + { + fprintf(stderr, "info.ndim = %d\n", (int)info.ndim); + throw std::runtime_error("continuous uint8 HWC array expected"); + } + auto channels = (int)info.shape[2]; + PixelFormat format; + if (channels == 1) + { + format = PixelFormat::kGRAYSCALE; + } + else if (channels == 3) + { + format = 
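
PySender.__await__ above completes a concurrent.futures.Future outside the GIL and hands it to asyncio.wrap_future, so a sender is directly awaitable from Python; a sketch, assuming some binding elsewhere hands back a PySender (none is added in this file):

    import asyncio

    async def consume(sender):
        # Suspends the coroutine until the C++ sender finishes;
        # the result has already been converted by ToPyObject.
        return await sender
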
PixelFormat::kBGR; + } + else + { + throw std::runtime_error("images of 1 or 3 channels are supported"); + } + + return { + (int)info.shape[0], // height + (int)info.shape[1], // width + format, // format + DataType::kINT8, // type + std::shared_ptr(info.ptr, [](void*) {}), // data + framework::Device(0), // device + }; + } + + std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type) + { + switch (type) + { + case MMDEPLOY_TYPE_MODEL: + return Value(*(const framework::Model*)object); + case MMDEPLOY_TYPE_DEVICE: + return Value(*(const framework::Device*)object); + case MMDEPLOY_TYPE_MAT: + return _get_mat(*(const py::array*)object); + default: + return std::nullopt; + } + } + + } // namespace python } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/python/pipeline.cpp b/csrc/mmdeploy/apis/python/pipeline.cpp index e3e6237e44..114bce2095 100644 --- a/csrc/mmdeploy/apis/python/pipeline.cpp +++ b/csrc/mmdeploy/apis/python/pipeline.cpp @@ -7,41 +7,47 @@ #include "mmdeploy/core/logger.h" #include "mmdeploy/core/utils/formatter.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -using namespace std::literals; + using namespace std::literals; -static PythonBindingRegisterer register_pipeline{[](py::module& m) { - py::class_(m, "Pipeline") - .def(py::init([](const py::object& config, const Context& context) { + static PythonBindingRegisterer register_pipeline{[](py::module& m) + { + py::class_(m, "Pipeline") + .def(py::init([](const py::object& config, const Context& context) + { auto _config = FromPyObject(config); - return std::make_unique(_config, context); - })) - .def("__call__", - [](Pipeline* pipeline, const py::args& args) { - auto inputs = FromPyObject(args); - for (auto& input : inputs) { - input = Value::Array{std::move(input)}; - } - auto outputs = pipeline->Apply(inputs); - for (auto& output : outputs) { - output = std::move(output[0]); - } - py::tuple rets(outputs.size()); - for (int i = 0; i < outputs.size(); ++i) { - rets[i] = ToPyObject(outputs[i]); - } - return rets; - }) - .def("batch", [](Pipeline* pipeline, const py::args& args) { + return std::make_unique(_config, context); })) + .def("__call__", + [](Pipeline* pipeline, const py::args& args) + { + auto inputs = FromPyObject(args); + for (auto& input : inputs) + { + input = Value::Array{std::move(input)}; + } + auto outputs = pipeline->Apply(inputs); + for (auto& output : outputs) + { + output = std::move(output[0]); + } + py::tuple rets(outputs.size()); + for (int i = 0; i < outputs.size(); ++i) + { + rets[i] = ToPyObject(outputs[i]); + } + return rets; + }) + .def("batch", [](Pipeline* pipeline, const py::args& args) + { auto inputs = FromPyObject(args); auto outputs = pipeline->Apply(inputs); py::tuple rets(outputs.size()); for (int i = 0; i < outputs.size(); ++i) { rets[i] = ToPyObject(outputs[i]); } - return rets; - }); -}}; + return rets; }); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/pose_detector.cpp b/csrc/mmdeploy/apis/python/pose_detector.cpp index f9d99eaf14..b6dc96560a 100644 --- a/csrc/mmdeploy/apis/python/pose_detector.cpp +++ b/csrc/mmdeploy/apis/python/pose_detector.cpp @@ -7,122 +7,143 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -using Rect = std::array; + using Rect = std::array; -class PyPoseDetector { - public: - PyPoseDetector(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_pose_detector_create_by_path(model_path, device_name, device_id, 
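
The Pipeline binding above accepts any config that FromPyObject can convert, typically a dict; a sketch with a hypothetical single-output config (not part of the diff):

    import numpy as np
    from mmdeploy_runtime import Context, Device, Pipeline

    context = Context(Device("cpu"))
    config = {"type": "Pipeline", "input": "img", "tasks": [], "output": "img"}  # hypothetical
    pipeline = Pipeline(config, context)
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    outputs = pipeline(img)               # tuple, one entry per pipeline output
    batched = pipeline.batch([img, img])  # batched variant, no per-item wrapping
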
&detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create pose_detector"); - } - } - py::list Apply(const std::vector& imgs, const std::vector>& bboxes) { - if (imgs.size() == 0 && bboxes.size() == 0) { - return py::list{}; - } - if (bboxes.size() != 0 && bboxes.size() != imgs.size()) { - std::ostringstream os; - os << "imgs length not equal with vboxes [" << imgs.size() << " vs " << bboxes.size() << "]"; - throw std::invalid_argument(os.str()); - } + class PyPoseDetector + { + public: + PyPoseDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_pose_detector_create_by_path(model_path, device_name, device_id, &detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create pose_detector"); + } + } + py::list Apply(const std::vector& imgs, const std::vector>& bboxes) + { + if (imgs.size() == 0 && bboxes.size() == 0) + { + return py::list{}; + } + if (bboxes.size() != 0 && bboxes.size() != imgs.size()) + { + std::ostringstream os; + os << "imgs length not equal with vboxes [" << imgs.size() << " vs " << bboxes.size() << "]"; + throw std::invalid_argument(os.str()); + } - std::vector mats; - std::vector boxes; - std::vector bbox_count; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } + std::vector mats; + std::vector boxes; + std::vector bbox_count; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } - for (auto _boxes : bboxes) { - for (auto _box : _boxes) { - mmdeploy_rect_t box = {_box[0], _box[1], _box[2], _box[3]}; - boxes.push_back(box); - } - bbox_count.push_back(_boxes.size()); - } + for (auto _boxes : bboxes) + { + for (auto _box : _boxes) + { + mmdeploy_rect_t box = {_box[0], _box[1], _box[2], _box[3]}; + boxes.push_back(box); + } + bbox_count.push_back(_boxes.size()); + } - // full image - if (bboxes.size() == 0) { - for (int i = 0; i < mats.size(); i++) { - mmdeploy_rect_t box = {0.f, 0.f, mats[i].width - 1.f, mats[i].height - 1.f}; - boxes.push_back(box); - bbox_count.push_back(1); - } - } + // full image + if (bboxes.size() == 0) + { + for (int i = 0; i < mats.size(); i++) + { + mmdeploy_rect_t box = {0.f, 0.f, mats[i].width - 1.f, mats[i].height - 1.f}; + boxes.push_back(box); + bbox_count.push_back(1); + } + } - mmdeploy_pose_detection_t* detection{}; - auto status = mmdeploy_pose_detector_apply_bbox(detector_, mats.data(), (int)mats.size(), - boxes.data(), bbox_count.data(), &detection); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply pose_detector, code: " + std::to_string(status)); - } + mmdeploy_pose_detection_t* detection{}; + auto status = mmdeploy_pose_detector_apply_bbox(detector_, mats.data(), (int)mats.size(), boxes.data(), bbox_count.data(), &detection); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply pose_detector, code: " + std::to_string(status)); + } - auto output = py::list{}; - auto result = detection; - for (int i = 0; i < mats.size(); i++) { - int n_point = result->length; - auto pred = py::array_t({bbox_count[i], n_point, 3}); - auto dst = pred.mutable_data(); - for (int j = 0; j < bbox_count[i]; j++) { - for (int k = 0; k < n_point; k++) { - dst[0] = result->point[k].x; - dst[1] = result->point[k].y; - dst[2] = result->score[k]; - dst += 3; - } - result++; - } - output.append(std::move(pred)); - } + auto output = py::list{}; + auto result = 
detection; + for (int i = 0; i < mats.size(); i++) + { + int n_point = result->length; + auto pred = py::array_t({bbox_count[i], n_point, 3}); + auto dst = pred.mutable_data(); + for (int j = 0; j < bbox_count[i]; j++) + { + for (int k = 0; k < n_point; k++) + { + dst[0] = result->point[k].x; + dst[1] = result->point[k].y; + dst[2] = result->score[k]; + dst += 3; + } + result++; + } + output.append(std::move(pred)); + } - int total = std::accumulate(bbox_count.begin(), bbox_count.end(), 0); - mmdeploy_pose_detector_release_result(detection, total); - return output; - } - ~PyPoseDetector() { - mmdeploy_pose_detector_destroy(detector_); - detector_ = {}; - } + int total = std::accumulate(bbox_count.begin(), bbox_count.end(), 0); + mmdeploy_pose_detector_release_result(detection, total); + return output; + } + ~PyPoseDetector() + { + mmdeploy_pose_detector_destroy(detector_); + detector_ = {}; + } - private: - mmdeploy_pose_detector_t detector_{}; -}; + private: + mmdeploy_pose_detector_t detector_{}; + }; -static PythonBindingRegisterer register_pose_detector{[](py::module& m) { - py::class_(m, "PoseDetector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyPoseDetector* self, const PyImage& img) -> py::array { - return self->Apply({img}, {})[0]; - }) - .def( - "__call__", - [](PyPoseDetector* self, const PyImage& img, const Rect& box) -> py::array { - std::vector> bboxes; - bboxes.push_back({box}); - return self->Apply({img}, bboxes)[0]; - }, - py::arg("img"), py::arg("box")) - .def( - "__call__", - [](PyPoseDetector* self, const PyImage& img, - const std::vector& bboxes) -> py::array { - std::vector> _bboxes; - _bboxes.push_back(bboxes); - return self->Apply({img}, _bboxes)[0]; - }, - py::arg("img"), py::arg("bboxes")) - .def("batch", &PyPoseDetector::Apply, py::arg("imgs"), - py::arg("bboxes") = std::vector>()); -}}; + static PythonBindingRegisterer register_pose_detector{[](py::module& m) + { + py::class_(m, "PoseDetector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyPoseDetector* self, const PyImage& img) -> py::array + { + return self->Apply({img}, {})[0]; + }) + .def( + "__call__", + [](PyPoseDetector* self, const PyImage& img, const Rect& box) -> py::array + { + std::vector> bboxes; + bboxes.push_back({box}); + return self->Apply({img}, bboxes)[0]; + }, + py::arg("img"), + py::arg("box")) + .def( + "__call__", + [](PyPoseDetector* self, const PyImage& img, const std::vector& bboxes) -> py::array + { + std::vector> _bboxes; + _bboxes.push_back(bboxes); + return self->Apply({img}, _bboxes)[0]; + }, + py::arg("img"), + py::arg("bboxes")) + .def("batch", &PyPoseDetector::Apply, py::arg("imgs"), py::arg("bboxes") = std::vector>()); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/pose_tracker.cpp b/csrc/mmdeploy/apis/python/pose_tracker.cpp index 035ce3cdd1..c14f2450e8 100644 --- a/csrc/mmdeploy/apis/python/pose_tracker.cpp +++ b/csrc/mmdeploy/apis/python/pose_tracker.cpp @@ -5,146 +5,200 @@ #include "common.h" #include "mmdeploy/common.hpp" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -namespace { + namespace + { -std::vector 
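
The three __call__ overloads registered above cover full-image, single-box and multi-box pose estimation; a sketch (hypothetical model path):

    import numpy as np
    from mmdeploy_runtime import PoseDetector

    pose = PoseDetector(model_path="mmpose_model_dir", device_name="cpu")
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    kpts = pose(img)                              # whole image as one bbox: (1, K, 3) = [x, y, score]
    kpts = pose(img, [50.0, 40.0, 200.0, 300.0])  # one box: [left, top, right, bottom]
    kpts = pose(img, bboxes=[[50, 40, 200, 300], [10, 10, 90, 90]])
    batch = pose.batch([img, img])                # bboxes defaults to whole images
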
Apply(mmdeploy::PoseTracker* self, - const std::vector& _states, - const std::vector& _frames, std::vector detect) { - std::vector tmp; - for (const auto& s : _states) { - tmp.push_back(static_cast(*s)); - } - mmdeploy::Span states(reinterpret_cast(tmp.data()), tmp.size()); - std::vector frames; - for (const auto& f : _frames) { - frames.emplace_back(GetMat(f)); - } - if (detect.empty()) { - detect.resize(frames.size(), -1); - } - assert(states.size() == frames.size()); - assert(states.size() == detect.size()); - auto results = self->Apply(states, frames, detect); - std::vector batch_ret; - batch_ret.reserve(frames.size()); - for (const auto& rs : results) { - py::array_t keypoints( - {static_cast(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3}); - py::array_t bboxes({static_cast(rs.size()), 4}); - py::array_t track_ids(static_cast(rs.size())); - auto kpts_ptr = keypoints.mutable_data(); - auto bbox_ptr = bboxes.mutable_data(); - auto track_id_ptr = track_ids.mutable_data(); - for (const auto& r : rs) { - for (int i = 0; i < r.keypoint_count; ++i) { - kpts_ptr[0] = r.keypoints[i].x; - kpts_ptr[1] = r.keypoints[i].y; - kpts_ptr[2] = r.scores[i]; - kpts_ptr += 3; - } - { - auto tmp_bbox = (std::array&)r.bbox; - bbox_ptr[0] = tmp_bbox[0]; - bbox_ptr[1] = tmp_bbox[1]; - bbox_ptr[2] = tmp_bbox[2]; - bbox_ptr[3] = tmp_bbox[3]; - bbox_ptr += 4; - } - *track_id_ptr++ = r.target_id; - } - batch_ret.push_back( - py::make_tuple(std::move(keypoints), std::move(bboxes), std::move(track_ids))); - } - return batch_ret; -} + std::vector Apply(mmdeploy::PoseTracker* self, + const std::vector& _states, + const std::vector& _frames, + std::vector detect) + { + std::vector tmp; + for (const auto& s : _states) + { + tmp.push_back(static_cast(*s)); + } + mmdeploy::Span states(reinterpret_cast(tmp.data()), tmp.size()); + std::vector frames; + for (const auto& f : _frames) + { + frames.emplace_back(GetMat(f)); + } + if (detect.empty()) + { + detect.resize(frames.size(), -1); + } + assert(states.size() == frames.size()); + assert(states.size() == detect.size()); + auto results = self->Apply(states, frames, detect); + std::vector batch_ret; + batch_ret.reserve(frames.size()); + for (const auto& rs : results) + { + py::array_t keypoints( + {static_cast(rs.size()), rs.size() > 0 ? 
rs[0].keypoint_count : 0, 3}); + py::array_t bboxes({static_cast(rs.size()), 4}); + py::array_t track_ids(static_cast(rs.size())); + auto kpts_ptr = keypoints.mutable_data(); + auto bbox_ptr = bboxes.mutable_data(); + auto track_id_ptr = track_ids.mutable_data(); + for (const auto& r : rs) + { + for (int i = 0; i < r.keypoint_count; ++i) + { + kpts_ptr[0] = r.keypoints[i].x; + kpts_ptr[1] = r.keypoints[i].y; + kpts_ptr[2] = r.scores[i]; + kpts_ptr += 3; + } + { + auto tmp_bbox = (std::array&)r.bbox; + bbox_ptr[0] = tmp_bbox[0]; + bbox_ptr[1] = tmp_bbox[1]; + bbox_ptr[2] = tmp_bbox[2]; + bbox_ptr[3] = tmp_bbox[3]; + bbox_ptr += 4; + } + *track_id_ptr++ = r.target_id; + } + batch_ret.push_back( + py::make_tuple(std::move(keypoints), std::move(bboxes), std::move(track_ids))); + } + return batch_ret; + } -template -void Copy(const py::handle& h, T (&a)[N]) { - auto array = h.cast>(); - assert(array.size() == N); - auto data = array.data(); - for (int i = 0; i < N; ++i) { - a[i] = data[i]; - } -} + template + void Copy(const py::handle& h, T (&a)[N]) + { + auto array = h.cast>(); + assert(array.size() == N); + auto data = array.data(); + for (int i = 0; i < N; ++i) + { + a[i] = data[i]; + } + } -void Parse(const py::dict& dict, PoseTracker::Params& params, py::array_t& sigmas) { - for (const auto& [_name, value] : dict) { - auto name = _name.cast(); - if (name == "det_interval") { - params->det_interval = value.cast(); - } else if (name == "det_label") { - params->det_label = value.cast(); - } else if (name == "det_thr") { - params->det_thr = value.cast(); - } else if (name == "det_min_bbox_size") { - params->det_min_bbox_size = value.cast(); - } else if (name == "det_nms_thr") { - params->det_nms_thr = value.cast(); - } else if (name == "pose_max_num_bboxes") { - params->pose_max_num_bboxes = value.cast(); - } else if (name == "pose_min_keypoints") { - params->pose_min_keypoints = value.cast(); - } else if (name == "pose_min_bbox_size") { - params->pose_min_bbox_size = value.cast(); - } else if (name == "pose_nms_thr") { - params->pose_nms_thr = value.cast(); - } else if (name == "track_kpt_thr") { - params->pose_kpt_thr = value.cast(); - } else if (name == "track_iou_thr") { - params->track_iou_thr = value.cast(); - } else if (name == "pose_bbox_scale") { - params->pose_bbox_scale = value.cast(); - } else if (name == "track_max_missing") { - params->track_max_missing = value.cast(); - } else if (name == "track_history_size") { - params->track_history_size = value.cast(); - } else if (name == "keypoint_sigmas") { - sigmas = value.cast>(); - params->keypoint_sigmas = const_cast(sigmas.data()); - params->keypoint_sigmas_size = sigmas.size(); - } else if (name == "std_weight_position") { - params->std_weight_position = value.cast(); - } else if (name == "std_weight_velocity") { - params->std_weight_velocity = value.cast(); - } else if (name == "smooth_params") { - Copy(value, params->smooth_params); - } else { - MMDEPLOY_ERROR("unused argument: {}", name); - } - } -} + void Parse(const py::dict& dict, PoseTracker::Params& params, py::array_t& sigmas) + { + for (const auto& [_name, value] : dict) + { + auto name = _name.cast(); + if (name == "det_interval") + { + params->det_interval = value.cast(); + } + else if (name == "det_label") + { + params->det_label = value.cast(); + } + else if (name == "det_thr") + { + params->det_thr = value.cast(); + } + else if (name == "det_min_bbox_size") + { + params->det_min_bbox_size = value.cast(); + } + else if (name == "det_nms_thr") + { + params->det_nms_thr 
= value.cast(); + } + else if (name == "pose_max_num_bboxes") + { + params->pose_max_num_bboxes = value.cast(); + } + else if (name == "pose_min_keypoints") + { + params->pose_min_keypoints = value.cast(); + } + else if (name == "pose_min_bbox_size") + { + params->pose_min_bbox_size = value.cast(); + } + else if (name == "pose_nms_thr") + { + params->pose_nms_thr = value.cast(); + } + else if (name == "track_kpt_thr") + { + params->pose_kpt_thr = value.cast(); + } + else if (name == "track_iou_thr") + { + params->track_iou_thr = value.cast(); + } + else if (name == "pose_bbox_scale") + { + params->pose_bbox_scale = value.cast(); + } + else if (name == "track_max_missing") + { + params->track_max_missing = value.cast(); + } + else if (name == "track_history_size") + { + params->track_history_size = value.cast(); + } + else if (name == "keypoint_sigmas") + { + sigmas = value.cast>(); + params->keypoint_sigmas = const_cast(sigmas.data()); + params->keypoint_sigmas_size = sigmas.size(); + } + else if (name == "std_weight_position") + { + params->std_weight_position = value.cast(); + } + else if (name == "std_weight_velocity") + { + params->std_weight_velocity = value.cast(); + } + else if (name == "smooth_params") + { + Copy(value, params->smooth_params); + } + else + { + MMDEPLOY_ERROR("unused argument: {}", name); + } + } + } -} // namespace + } // namespace -static PythonBindingRegisterer register_pose_tracker{[](py::module& m) { - py::class_(m, "PoseTracker.State"); - py::class_(m, "PoseTracker") - .def(py::init([](const char* det_model_path, const char* pose_model_path, - const char* device_name, int device_id) { - return mmdeploy::PoseTracker( - mmdeploy::Model(det_model_path), mmdeploy::Model(pose_model_path), - mmdeploy::Context(mmdeploy::Device(device_name, device_id))); - }), - py::arg("det_model"), py::arg("pose_model"), py::arg("device_name"), - py::arg("device_id") = 0) - .def( - "__call__", - [](mmdeploy::PoseTracker* self, mmdeploy::PoseTracker::State* state, const PyImage& img, - int detect) { return Apply(self, {state}, {img}, {detect})[0]; }, - py::arg("state"), py::arg("frame"), py::arg("detect") = -1) - .def("batch", &Apply, py::arg("states"), py::arg("frames"), - py::arg("detects") = std::vector{}) - .def("create_state", [](mmdeploy::PoseTracker* self, const py::kwargs& kwargs) { + static PythonBindingRegisterer register_pose_tracker{[](py::module& m) + { + py::class_(m, "PoseTracker.State"); + py::class_(m, "PoseTracker") + .def(py::init([](const char* det_model_path, const char* pose_model_path, const char* device_name, int device_id) + { return mmdeploy::PoseTracker( + mmdeploy::Model(det_model_path), + mmdeploy::Model(pose_model_path), + mmdeploy::Context(mmdeploy::Device(device_name, device_id))); }), + py::arg("det_model"), + py::arg("pose_model"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def( + "__call__", + [](mmdeploy::PoseTracker* self, mmdeploy::PoseTracker::State* state, const PyImage& img, int detect) + { return Apply(self, {state}, {img}, {detect})[0]; }, + py::arg("state"), + py::arg("frame"), + py::arg("detect") = -1) + .def("batch", &Apply, py::arg("states"), py::arg("frames"), py::arg("detects") = std::vector{}) + .def("create_state", [](mmdeploy::PoseTracker* self, const py::kwargs& kwargs) + { PoseTracker::Params params; py::array_t sigmas; if (kwargs) { Parse(kwargs, params, sigmas); } - return self->CreateState(params); - }); -}}; + return self->CreateState(params); }); + }}; } // namespace mmdeploy::python diff --git 
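
create_state forwards keyword arguments through Parse above (unrecognized keys are only logged as "unused argument"); a sketch with hypothetical model paths and placeholder sigma values:

    import numpy as np
    from mmdeploy_runtime import PoseTracker

    tracker = PoseTracker("det_model_dir", "pose_model_dir", device_name="cpu")
    state = tracker.create_state(det_interval=1, det_thr=0.5,
                                 keypoint_sigmas=np.full(17, 0.05, dtype=np.float32))
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stands in for a video frame
    keypoints, bboxes, track_ids = tracker(state, frame, detect=-1)
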
a/csrc/mmdeploy/apis/python/restorer.cpp b/csrc/mmdeploy/apis/python/restorer.cpp index 771af2a6c4..ddd4c0a8ff 100644 --- a/csrc/mmdeploy/apis/python/restorer.cpp +++ b/csrc/mmdeploy/apis/python/restorer.cpp @@ -4,63 +4,77 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyRestorer { - public: - PyRestorer(const char* model_path, const char* device_name, int device_id) { - auto status = mmdeploy_restorer_create_by_path(model_path, device_name, device_id, &restorer_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create restorer"); - } - } - ~PyRestorer() { - mmdeploy_restorer_destroy(restorer_); - restorer_ = {}; - } + class PyRestorer + { + public: + PyRestorer(const char* model_path, const char* device_name, int device_id) + { + auto status = mmdeploy_restorer_create_by_path(model_path, device_name, device_id, &restorer_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create restorer"); + } + } + ~PyRestorer() + { + mmdeploy_restorer_destroy(restorer_); + restorer_ = {}; + } - std::vector Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_mat_t* results{}; - auto status = mmdeploy_restorer_apply(restorer_, mats.data(), (int)mats.size(), &results); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply restorer, code: " + std::to_string(status)); - } - using Sptr = std::shared_ptr; - Sptr holder(results, [n = mats.size()](auto p) { mmdeploy_restorer_release_result(p, n); }); + std::vector Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_mat_t* results{}; + auto status = mmdeploy_restorer_apply(restorer_, mats.data(), (int)mats.size(), &results); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply restorer, code: " + std::to_string(status)); + } + using Sptr = std::shared_ptr; + Sptr holder(results, [n = mats.size()](auto p) + { mmdeploy_restorer_release_result(p, n); }); - std::vector rets(mats.size()); - for (int i = 0; i < mats.size(); ++i) { - rets[i] = { - {results[i].height, results[i].width, results[i].channel}, // shape - results[i].data, // data - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; - } - return rets; - } + std::vector rets(mats.size()); + for (int i = 0; i < mats.size(); ++i) + { + rets[i] = { + {results[i].height, results[i].width, results[i].channel}, // shape + results[i].data, // data + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); }) // + }; + } + return rets; + } - private: - mmdeploy_restorer_t restorer_{}; -}; + private: + mmdeploy_restorer_t restorer_{}; + }; -static PythonBindingRegisterer register_restorer{[](py::module& m) { - py::class_(m, "Restorer") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyRestorer* self, const PyImage& img) -> py::array { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyRestorer::Apply); -}}; + static PythonBindingRegisterer register_restorer{[](py::module& m) + { + py::class_(m, "Restorer") + 
.def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyRestorer* self, const PyImage& img) -> py::array + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyRestorer::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/rotated_detector.cpp b/csrc/mmdeploy/apis/python/rotated_detector.cpp index bc760b04e4..148b31fa6e 100644 --- a/csrc/mmdeploy/apis/python/rotated_detector.cpp +++ b/csrc/mmdeploy/apis/python/rotated_detector.cpp @@ -4,74 +4,87 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyRotatedDetector { - public: - PyRotatedDetector(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create rotated detector"); - } - } - py::list Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } + class PyRotatedDetector + { + public: + PyRotatedDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create rotated detector"); + } + } + py::list Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } - mmdeploy_rotated_detection_t* rbboxes{}; - int* res_count{}; - auto status = mmdeploy_rotated_detector_apply(detector_, mats.data(), (int)mats.size(), - &rbboxes, &res_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status)); - } - auto output = py::list{}; - auto result = rbboxes; - auto counts = res_count; - for (int i = 0; i < mats.size(); i++) { - auto _dets = py::array_t({*counts, 6}); - auto _labels = py::array_t({*counts}); - auto dets = _dets.mutable_data(); - auto labels = _labels.mutable_data(); - for (int j = 0; j < *counts; j++) { - for (int k = 0; k < 5; k++) { - *dets++ = result->rbbox[k]; + mmdeploy_rotated_detection_t* rbboxes{}; + int* res_count{}; + auto status = mmdeploy_rotated_detector_apply(detector_, mats.data(), (int)mats.size(), &rbboxes, &res_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status)); + } + auto output = py::list{}; + auto result = rbboxes; + auto counts = res_count; + for (int i = 0; i < mats.size(); i++) + { + auto _dets = py::array_t({*counts, 6}); + auto _labels = py::array_t({*counts}); + auto dets = _dets.mutable_data(); + auto labels = _labels.mutable_data(); + for (int j = 0; j < *counts; j++) + { + for (int k = 0; k < 5; k++) + { + *dets++ = result->rbbox[k]; + } + *dets++ = result->score; + *labels++ = result->label_id; + result++; + } + counts++; + output.append(py::make_tuple(std::move(_dets), std::move(_labels))); + } + mmdeploy_rotated_detector_release_result(rbboxes, res_count); + return output; + } + ~PyRotatedDetector() + { + mmdeploy_rotated_detector_destroy(detector_); 
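
A sketch for the Restorer binding above (hypothetical model path); the returned arrays alias the C result buffer, which the py::capsule holder keeps alive:

    import numpy as np
    from mmdeploy_runtime import Restorer

    restorer = Restorer(model_path="mmedit_model_dir", device_name="cpu")
    img = np.zeros((128, 128, 3), dtype=np.uint8)
    hires = restorer(img)            # (H, W, C) uint8 array
    results = restorer.batch([img])  # list of arrays
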
+ detector_ = {}; } - *dets++ = result->score; - *labels++ = result->label_id; - result++; - } - counts++; - output.append(py::make_tuple(std::move(_dets), std::move(_labels))); - } - mmdeploy_rotated_detector_release_result(rbboxes, res_count); - return output; - } - ~PyRotatedDetector() { - mmdeploy_rotated_detector_destroy(detector_); - detector_ = {}; - } - private: - mmdeploy_rotated_detector_t detector_{}; -}; + private: + mmdeploy_rotated_detector_t detector_{}; + }; -static PythonBindingRegisterer register_rotated_detector{[](py::module& m) { - py::class_(m, "RotatedDetector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyRotatedDetector* self, const PyImage& img) -> py::tuple { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyRotatedDetector::Apply); -}}; + static PythonBindingRegisterer register_rotated_detector{[](py::module& m) + { + py::class_(m, "RotatedDetector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyRotatedDetector* self, const PyImage& img) -> py::tuple + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyRotatedDetector::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/segmentor.cpp b/csrc/mmdeploy/apis/python/segmentor.cpp index 940972ab61..9e1db508c7 100644 --- a/csrc/mmdeploy/apis/python/segmentor.cpp +++ b/csrc/mmdeploy/apis/python/segmentor.cpp @@ -4,74 +4,91 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PySegmentor { - public: - PySegmentor(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_segmentor_create_by_path(model_path, device_name, device_id, &segmentor_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create segmentor"); - } - } - ~PySegmentor() { - mmdeploy_segmentor_destroy(segmentor_); - segmentor_ = {}; - } + class PySegmentor + { + public: + PySegmentor(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_segmentor_create_by_path(model_path, device_name, device_id, &segmentor_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create segmentor"); + } + } + ~PySegmentor() + { + mmdeploy_segmentor_destroy(segmentor_); + segmentor_ = {}; + } - std::vector Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_segmentation_t* segm{}; - auto status = mmdeploy_segmentor_apply(segmentor_, mats.data(), (int)mats.size(), &segm); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply segmentor, code: " + std::to_string(status)); - } - using Sptr = std::shared_ptr; - Sptr holder(segm, [n = mats.size()](auto p) { mmdeploy_segmentor_release_result(p, n); }); + std::vector Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_segmentation_t* segm{}; + auto status = mmdeploy_segmentor_apply(segmentor_, mats.data(), 
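
A sketch for the RotatedDetector binding above (hypothetical model path); each row holds the five rotated-box parameters followed by the score:

    import numpy as np
    from mmdeploy_runtime import RotatedDetector

    rdet = RotatedDetector(model_path="mmrotate_model_dir", device_name="cpu")
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    dets, labels = rdet(img)  # dets: (N, 6), labels: (N,)
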
(int)mats.size(), &segm); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply segmentor, code: " + std::to_string(status)); + } + using Sptr = std::shared_ptr; + Sptr holder(segm, [n = mats.size()](auto p) + { mmdeploy_segmentor_release_result(p, n); }); - std::vector rets(mats.size()); - for (size_t i = 0; i < mats.size(); ++i) { - if (segm[i].mask != nullptr) { - rets[i] = { - {segm[i].height, segm[i].width}, // shape - segm[i].mask, // mask - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; - } - if (segm[i].score != nullptr) { - rets[i] = { - {segm[i].classes, segm[i].height, segm[i].width}, // shape - segm[i].score, // score - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; - } - } - return rets; - } + std::vector rets(mats.size()); + for (size_t i = 0; i < mats.size(); ++i) + { + if (segm[i].mask != nullptr) + { + rets[i] = { + {segm[i].height, segm[i].width}, // shape + segm[i].mask, // mask + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); }) // + }; + } + if (segm[i].score != nullptr) + { + rets[i] = { + {segm[i].classes, segm[i].height, segm[i].width}, // shape + segm[i].score, // score + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); }) // + }; + } + } + return rets; + } - private: - mmdeploy_segmentor_t segmentor_{}; -}; + private: + mmdeploy_segmentor_t segmentor_{}; + }; -static PythonBindingRegisterer register_segmentor{[](py::module& m) { - py::class_(m, "Segmentor") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PySegmentor* self, const PyImage& img) -> py::array { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PySegmentor::Apply); -}}; + static PythonBindingRegisterer register_segmentor{[](py::module& m) + { + py::class_(m, "Segmentor") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PySegmentor* self, const PyImage& img) -> py::array + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PySegmentor::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/text_detector.cpp b/csrc/mmdeploy/apis/python/text_detector.cpp index 19762d08ec..1326588a1f 100644 --- a/csrc/mmdeploy/apis/python/text_detector.cpp +++ b/csrc/mmdeploy/apis/python/text_detector.cpp @@ -4,68 +4,81 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyTextDetector { - public: - PyTextDetector(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_text_detector_create_by_path(model_path, device_name, device_id, &detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create text_detector"); - } - } - std::vector> Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_text_detection_t* detection{}; - int* result_count{}; - auto status = mmdeploy_text_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, 
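
A sketch for the Segmentor binding above (hypothetical model path); the result is an integer (H, W) mask or, when the model emits per-class scores, a float (C, H, W) volume:

    import numpy as np
    from mmdeploy_runtime import Segmentor

    seg = Segmentor(model_path="mmseg_model_dir", device_name="cpu")
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    out = seg(img)               # (H, W) int mask or (C, H, W) float scores
    outs = seg.batch([img, img])
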
- &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply text_detector, code: " + std::to_string(status)); - } - auto output = std::vector>{}; - auto result = detection; - for (int i = 0; i < mats.size(); ++i) { - auto bboxes = py::array_t({result_count[i], 9}); - for (int j = 0; j < result_count[i]; ++j, ++result) { - auto data = bboxes.mutable_data(j); - for (const auto& p : result->bbox) { - *data++ = p.x; - *data++ = p.y; + class PyTextDetector + { + public: + PyTextDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_text_detector_create_by_path(model_path, device_name, device_id, &detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create text_detector"); + } + } + std::vector> Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_text_detection_t* detection{}; + int* result_count{}; + auto status = mmdeploy_text_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply text_detector, code: " + std::to_string(status)); + } + auto output = std::vector>{}; + auto result = detection; + for (int i = 0; i < mats.size(); ++i) + { + auto bboxes = py::array_t({result_count[i], 9}); + for (int j = 0; j < result_count[i]; ++j, ++result) + { + auto data = bboxes.mutable_data(j); + for (const auto& p : result->bbox) + { + *data++ = p.x; + *data++ = p.y; + } + *data++ = result->score; + } + output.push_back(std::move(bboxes)); + } + mmdeploy_text_detector_release_result(detection, result_count, (int)mats.size()); + return output; + } + ~PyTextDetector() + { + mmdeploy_text_detector_destroy(detector_); + detector_ = {}; } - *data++ = result->score; - } - output.push_back(std::move(bboxes)); - } - mmdeploy_text_detector_release_result(detection, result_count, (int)mats.size()); - return output; - } - ~PyTextDetector() { - mmdeploy_text_detector_destroy(detector_); - detector_ = {}; - } - private: - mmdeploy_text_detector_t detector_{}; -}; + private: + mmdeploy_text_detector_t detector_{}; + }; -static PythonBindingRegisterer register_text_detector{[](py::module& m) { - py::class_(m, "TextDetector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyTextDetector* self, const PyImage& img) -> py::array { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyTextDetector::Apply); -}}; + static PythonBindingRegisterer register_text_detector{[](py::module& m) + { + py::class_(m, "TextDetector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyTextDetector* self, const PyImage& img) -> py::array + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyTextDetector::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/text_recognizer.cpp b/csrc/mmdeploy/apis/python/text_recognizer.cpp index 317f55103a..1b3bc92af8 100644 --- a/csrc/mmdeploy/apis/python/text_recognizer.cpp +++ 
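
A sketch for the TextDetector binding above (hypothetical model path); each row packs the four corner points followed by the score:

    import numpy as np
    from mmdeploy_runtime import TextDetector

    tdet = TextDetector(model_path="dbnet_model_dir", device_name="cpu")
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    bboxes = tdet(img)  # (N, 9) = x1, y1, ..., x4, y4, score
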
b/csrc/mmdeploy/apis/python/text_recognizer.cpp @@ -4,79 +4,99 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyTextRecognizer { - public: - PyTextRecognizer(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_text_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create text_recognizer"); - } - } - std::vector>> Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_text_recognition_t* results{}; - auto status = - mmdeploy_text_recognizer_apply(recognizer_, mats.data(), (int)mats.size(), &results); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); - } - auto output = std::vector>>{}; - for (int i = 0; i < mats.size(); ++i) { - std::vector score(results[i].score, results[i].score + results[i].length); - output.emplace_back(results[i].text, std::move(score)); - } - mmdeploy_text_recognizer_release_result(results, (int)mats.size()); - return output; - } - std::vector>> Apply(const PyImage& img, - const std::vector& bboxes) { - if (bboxes.size() * sizeof(float) % sizeof(mmdeploy_text_detection_t)) { - throw std::invalid_argument("bboxes is not a list of 'mmdeploy_text_detection_t'"); - } - auto mat = GetMat(img); - int bbox_count = bboxes.size() * sizeof(float) / sizeof(mmdeploy_text_detection_t); - mmdeploy_text_recognition_t* results{}; - auto status = mmdeploy_text_recognizer_apply_bbox( - recognizer_, &mat, 1, (mmdeploy_text_detection_t*)bboxes.data(), &bbox_count, &results); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); - } - auto output = std::vector>>{}; - for (int i = 0; i < bbox_count; ++i) { - std::vector score(results[i].score, results[i].score + results[i].length); - output.emplace_back(results[i].text, std::move(score)); - } - mmdeploy_text_recognizer_release_result(results, bbox_count); - return output; - } - ~PyTextRecognizer() { - mmdeploy_text_recognizer_destroy(recognizer_); - recognizer_ = {}; - } + class PyTextRecognizer + { + public: + PyTextRecognizer(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_text_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create text_recognizer"); + } + } + std::vector>> Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_text_recognition_t* results{}; + auto status = + mmdeploy_text_recognizer_apply(recognizer_, mats.data(), (int)mats.size(), &results); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); + } + auto output = std::vector>>{}; + for (int i = 0; i < mats.size(); ++i) + { + std::vector score(results[i].score, results[i].score + results[i].length); + output.emplace_back(results[i].text, std::move(score)); + } + mmdeploy_text_recognizer_release_result(results, (int)mats.size()); + return output; + } + std::vector>> Apply(const PyImage& img, + const std::vector& bboxes) + { + if (bboxes.size() * 
sizeof(float) % sizeof(mmdeploy_text_detection_t)) + { + throw std::invalid_argument("bboxes is not a list of 'mmdeploy_text_detection_t'"); + } + auto mat = GetMat(img); + int bbox_count = bboxes.size() * sizeof(float) / sizeof(mmdeploy_text_detection_t); + mmdeploy_text_recognition_t* results{}; + auto status = mmdeploy_text_recognizer_apply_bbox( + recognizer_, + &mat, + 1, + (mmdeploy_text_detection_t*)bboxes.data(), + &bbox_count, + &results); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); + } + auto output = std::vector>>{}; + for (int i = 0; i < bbox_count; ++i) + { + std::vector score(results[i].score, results[i].score + results[i].length); + output.emplace_back(results[i].text, std::move(score)); + } + mmdeploy_text_recognizer_release_result(results, bbox_count); + return output; + } + ~PyTextRecognizer() + { + mmdeploy_text_recognizer_destroy(recognizer_); + recognizer_ = {}; + } - private: - mmdeploy_text_recognizer_t recognizer_{}; -}; + private: + mmdeploy_text_recognizer_t recognizer_{}; + }; -static PythonBindingRegisterer register_text_recognizer{[](py::module& m) { - py::class_(m, "TextRecognizer") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", [](PyTextRecognizer* self, - const PyImage& img) { return self->Apply(std::vector{img})[0]; }) - .def("__call__", [](PyTextRecognizer* self, const PyImage& img, - const std::vector& bboxes) { return self->Apply(img, bboxes); }) - .def("batch", py::overload_cast&>(&PyTextRecognizer::Apply)); -}}; + static PythonBindingRegisterer register_text_recognizer{[](py::module& m) + { + py::class_(m, "TextRecognizer") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", [](PyTextRecognizer* self, const PyImage& img) + { return self->Apply(std::vector{img})[0]; }) + .def("__call__", [](PyTextRecognizer* self, const PyImage& img, const std::vector& bboxes) + { return self->Apply(img, bboxes); }) + .def("batch", py::overload_cast&>(&PyTextRecognizer::Apply)); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/video_recognizer.cpp b/csrc/mmdeploy/apis/python/video_recognizer.cpp index 7c70337e51..ac2e691be3 100644 --- a/csrc/mmdeploy/apis/python/video_recognizer.cpp +++ b/csrc/mmdeploy/apis/python/video_recognizer.cpp @@ -4,85 +4,102 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyVideoRecognizer { - public: - PyVideoRecognizer(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_video_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create video_recognizer"); - } - } - std::vector>> Apply( - const std::vector>& imgs, const std::vector>& info) { - if (info.size() != imgs.size()) { - throw std::invalid_argument("the length of info is not equal with imgs"); - } - for (int i = 0; i < info.size(); i++) { - if (imgs[i].size() != info[i].first * info[i].second) { - throw std::invalid_argument("invalid info"); - } - } - int total = 0; - for (int i = 0; i < imgs.size(); 
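
The bbox overload above expects a flat float list whose byte size is a whole number of mmdeploy_text_detection_t records (nine floats each), which is exactly what flattening TextDetector output yields; a sketch (hypothetical model paths):

    import numpy as np
    from mmdeploy_runtime import TextDetector, TextRecognizer

    trec = TextRecognizer(model_path="crnn_model_dir", device_name="cpu")
    img = np.zeros((64, 640, 3), dtype=np.uint8)
    text, scores = trec(img)              # whole image treated as one text line
    tdet = TextDetector(model_path="dbnet_model_dir", device_name="cpu")
    boxes = tdet(img).flatten().tolist()  # nine floats per detected box
    results = trec(img, boxes)            # one (text, per-character scores) pair per box
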
i++) { - total += imgs[i].size(); - } - std::vector clips; - std::vector clip_info; - clips.reserve(total); - clip_info.reserve(total); - for (int i = 0; i < imgs.size(); i++) { - for (const auto& img : imgs[i]) { - auto mat = GetMat(img); - clips.push_back(mat); - } - clip_info.push_back({info[i].first, info[i].second}); - } + class PyVideoRecognizer + { + public: + PyVideoRecognizer(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_video_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create video_recognizer"); + } + } + std::vector>> Apply( + const std::vector>& imgs, + const std::vector>& info) + { + if (info.size() != imgs.size()) + { + throw std::invalid_argument("the length of info is not equal with imgs"); + } + for (int i = 0; i < info.size(); i++) + { + if (imgs[i].size() != info[i].first * info[i].second) + { + throw std::invalid_argument("invalid info"); + } + } + int total = 0; + for (int i = 0; i < imgs.size(); i++) + { + total += imgs[i].size(); + } + std::vector clips; + std::vector clip_info; + clips.reserve(total); + clip_info.reserve(total); + for (int i = 0; i < imgs.size(); i++) + { + for (const auto& img : imgs[i]) + { + auto mat = GetMat(img); + clips.push_back(mat); + } + clip_info.push_back({info[i].first, info[i].second}); + } - mmdeploy_video_recognition_t* results{}; - int* result_count{}; - auto status = mmdeploy_video_recognizer_apply(recognizer_, clips.data(), clip_info.data(), 1, - &results, &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply video_recognizer, code: " + std::to_string(status)); - } + mmdeploy_video_recognition_t* results{}; + int* result_count{}; + auto status = mmdeploy_video_recognizer_apply(recognizer_, clips.data(), clip_info.data(), 1, &results, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply video_recognizer, code: " + std::to_string(status)); + } - auto output = std::vector>>{}; - output.reserve(imgs.size()); - auto result_ptr = results; - for (int i = 0; i < imgs.size(); ++i) { - std::vector> label_score; - for (int j = 0; j < result_count[i]; ++j) { - label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); - } - output.push_back(std::move(label_score)); - result_ptr += result_count[i]; - } - mmdeploy_video_recognizer_release_result(results, result_count, (int)imgs.size()); - return output; - } + auto output = std::vector>>{}; + output.reserve(imgs.size()); + auto result_ptr = results; + for (int i = 0; i < imgs.size(); ++i) + { + std::vector> label_score; + for (int j = 0; j < result_count[i]; ++j) + { + label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); + } + output.push_back(std::move(label_score)); + result_ptr += result_count[i]; + } + mmdeploy_video_recognizer_release_result(results, result_count, (int)imgs.size()); + return output; + } - ~PyVideoRecognizer() { - mmdeploy_video_recognizer_destroy(recognizer_); - recognizer_ = {}; - } + ~PyVideoRecognizer() + { + mmdeploy_video_recognizer_destroy(recognizer_); + recognizer_ = {}; + } - private: - mmdeploy_video_recognizer_t recognizer_{}; -}; + private: + mmdeploy_video_recognizer_t recognizer_{}; + }; -static PythonBindingRegisterer register_video_recognizer{[](py::module& m) { - py::class_(m, "VideoRecognizer") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return 
-           }),
-           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
-      .def("__call__",
-           [](PyVideoRecognizer* self, const std::vector<PyImage>& imgs,
-              const std::pair<int, int>& info) { return self->Apply({imgs}, {info})[0]; })
-      .def("batch", &PyVideoRecognizer::Apply);
-}};
+    static PythonBindingRegisterer register_video_recognizer{[](py::module& m)
+        {
+            py::class_<PyVideoRecognizer>(m, "VideoRecognizer")
+                .def(py::init([](const char* model_path, const char* device_name, int device_id)
+                              { return std::make_unique<PyVideoRecognizer>(model_path, device_name, device_id); }),
+                     py::arg("model_path"),
+                     py::arg("device_name"),
+                     py::arg("device_id") = 0)
+                .def("__call__",
+                     [](PyVideoRecognizer* self, const std::vector<PyImage>& imgs, const std::pair<int, int>& info)
+                     { return self->Apply({imgs}, {info})[0]; })
+                .def("batch", &PyVideoRecognizer::Apply);
+        }};
 }  // namespace mmdeploy::python
diff --git a/csrc/mmdeploy/archive/CMakeLists.txt b/csrc/mmdeploy/archive/CMakeLists.txt
index 3f3d1f1104..68c34d3d05 100644
--- a/csrc/mmdeploy/archive/CMakeLists.txt
+++ b/csrc/mmdeploy/archive/CMakeLists.txt
@@ -6,8 +6,10 @@ add_library(${PROJECT_NAME} INTERFACE)
 target_link_libraries(${PROJECT_NAME} INTERFACE mmdeploy::core)
 add_library(mmdeploy::archive ALIAS mmdeploy_archive)
-install(DIRECTORY ${CMAKE_SOURCE_DIR}/csrc/mmdeploy/archive
-        DESTINATION include/mmdeploy
-        FILES_MATCHING PATTERN "*.h")
+install(
+  DIRECTORY ${CMAKE_SOURCE_DIR}/csrc/mmdeploy/archive
+  DESTINATION include/mmdeploy
+  FILES_MATCHING
+  PATTERN "*.h")
 install(FILES ${CMAKE_SOURCE_DIR}/third_party/json/json.hpp
         DESTINATION include/mmdeploy/third_party/json)
diff --git a/csrc/mmdeploy/archive/json_archive.h b/csrc/mmdeploy/archive/json_archive.h
index 2803ee22b2..cf03005856 100644
--- a/csrc/mmdeploy/archive/json_archive.h
+++ b/csrc/mmdeploy/archive/json_archive.h
@@ -7,207 +7,247 @@
 #include "mmdeploy/core/archive.h"
 #include "mmdeploy/core/value.h"
-namespace mmdeploy {
-
-namespace detail {
-
-template <typename T>
-nlohmann::json to_json_impl(T&& val);
-
-inline nlohmann::json value_to_json(const Value& value) {
-  switch (value.type()) {
-    case ValueType::kNull:
-      return {};
-    case ValueType::kBool:
-      return value.get<bool>();
-    case ValueType::kInt:
-      return value.get<int64_t>();
-    case ValueType::kUInt:
-      return value.get<uint64_t>();
-    case ValueType::kFloat:
-      return value.get<double>();
-    case ValueType::kString:
-      return value.get<std::string>();
-    case ValueType::kArray: {
-      nlohmann::json json = nlohmann::json::value_t::array;
-      for (const auto& x : value) {
-        json.push_back(value_to_json(x));
-      }
-      return json;
+namespace mmdeploy
+{
+
+    namespace detail
+    {
+
+        template <typename T>
+        nlohmann::json to_json_impl(T&& val);
+
+        inline nlohmann::json value_to_json(const Value& value)
+        {
+            switch (value.type())
+            {
+                case ValueType::kNull:
+                    return {};
+                case ValueType::kBool:
+                    return value.get<bool>();
+                case ValueType::kInt:
+                    return value.get<int64_t>();
+                case ValueType::kUInt:
+                    return value.get<uint64_t>();
+                case ValueType::kFloat:
+                    return value.get<double>();
+                case ValueType::kString:
+                    return value.get<std::string>();
+                case ValueType::kArray:
+                {
+                    nlohmann::json json = nlohmann::json::value_t::array;
+                    for (const auto& x : value)
+                    {
+                        json.push_back(value_to_json(x));
+                    }
+                    return json;
+                }
+                case ValueType::kObject:
+                {
+                    nlohmann::json json = nlohmann::json::value_t::object;
+                    for (auto it = value.begin(); it != value.end(); ++it)
+                    {
+                        auto key = it.key();
+                        json[key] = value_to_json(*it);
+                    }
+                    return json;
+                }
+                case ValueType::kAny:
+                    return "";
+                default:
+                    return "";
+            }
+        }
+
+    }  // namespace detail
+
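+    // Usage sketch (illustrative, not from the original sources): the entry
+    // points below convert between mmdeploy::Value and nlohmann::json in both
+    // directions, e.g.
+    //
+    //   mmdeploy::Value v{{"model", "resnet50"}, {"batch_size", 32}};
+    //   nlohmann::json j = mmdeploy::to_json(v);            // Value -> json
+    //   auto v2 = mmdeploy::from_json<mmdeploy::Value>(j);  // json -> Value
+    //
+    // The brace-init of Value above assumes its nlohmann-style constructors.
+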
template>, int> = 0> + nlohmann::json to_json(T&& val) + { + return detail::to_json_impl(std::forward(val)); } - case ValueType::kObject: { - nlohmann::json json = nlohmann::json::value_t::object; - for (auto it = value.begin(); it != value.end(); ++it) { - auto key = it.key(); - json[key] = value_to_json(*it); - } - return json; + + inline nlohmann::json to_json(const Value& value) + { + return detail::value_to_json(value); + } + + // save to JSON + class JsonOutputArchive : public OutputArchive + { + public: + explicit JsonOutputArchive(nlohmann::json& data) + : data_(data) + { + } + + void init(...) {} + + template + void named_value(const std::string& name, T&& val) + { + data_[name] = to_json(std::forward(val)); + } + + template + void item(T&& val) + { + data_.push_back(to_json(std::forward(val))); + } + + template, std::enable_if_t, std::is_same, std::is_same, std::is_same>, int> = 0> + void native(T&& val) + { + data_ = std::forward(val); + } + + private: + nlohmann::json& data_; + }; + + namespace detail + { + + template + inline nlohmann::json to_json_impl(T&& val) + { + nlohmann::json json; + JsonOutputArchive archive(json); + archive(std::forward(val)); + return json; + } + + } // namespace detail + + namespace detail + { + + inline Value json_to_value(const nlohmann::json& json) + { + using value_t = nlohmann::json::value_t; + switch (json.type()) + { + case value_t::null: + return {}; + case value_t::boolean: + return json.get(); + case value_t::number_integer: + return json.get(); + case value_t::number_unsigned: + return json.get(); + case value_t::number_float: + return json.get(); + case value_t::string: + return json.get(); + case value_t::array: + { + Value value = ValueType::kArray; + for (const auto& x : json) + { + value.push_back(json_to_value(x)); + } + return value; + } + case value_t::object: + { + Value value = ValueType::kObject; + for (const auto& proxy : json.items()) + { + value[proxy.key()] = json_to_value(proxy.value()); + } + return value; + } + default: + MMDEPLOY_ERROR("unsupported json type: {}", json.type_name()); + return {}; + } + } + + template + void from_json_impl(const nlohmann::json& json, T&& val); + + } // namespace detail + + template>, int> = 0> + void from_json(const nlohmann::json& json, T&& val) + { + detail::from_json_impl(json, std::forward(val)); } - case ValueType::kAny: - return ""; - default: - return ""; - } -} - -} // namespace detail - -template >, int> = 0> -nlohmann::json to_json(T&& val) { - return detail::to_json_impl(std::forward(val)); -} - -inline nlohmann::json to_json(const Value& value) { return detail::value_to_json(value); } - -// save to JSON -class JsonOutputArchive : public OutputArchive { - public: - explicit JsonOutputArchive(nlohmann::json& data) : data_(data) {} - - void init(...) 
{} - - template - void named_value(const std::string& name, T&& val) { - data_[name] = to_json(std::forward(val)); - } - - template - void item(T&& val) { - data_.push_back(to_json(std::forward(val))); - } - - template , - std::enable_if_t< - std::disjunction_v, std::is_same, - std::is_same, std::is_same>, - int> = 0> - void native(T&& val) { - data_ = std::forward(val); - } - - private: - nlohmann::json& data_; -}; - -namespace detail { - -template -inline nlohmann::json to_json_impl(T&& val) { - nlohmann::json json; - JsonOutputArchive archive(json); - archive(std::forward(val)); - return json; -} - -} // namespace detail - -namespace detail { - -inline Value json_to_value(const nlohmann::json& json) { - using value_t = nlohmann::json::value_t; - switch (json.type()) { - case value_t::null: - return {}; - case value_t::boolean: - return json.get(); - case value_t::number_integer: - return json.get(); - case value_t::number_unsigned: - return json.get(); - case value_t::number_float: - return json.get(); - case value_t::string: - return json.get(); - case value_t::array: { - Value value = ValueType::kArray; - for (const auto& x : json) { - value.push_back(json_to_value(x)); - } - return value; + + inline void from_json(const nlohmann::json& json, Value& val) + { + val = detail::json_to_value(json); } - case value_t::object: { - Value value = ValueType::kObject; - for (const auto& proxy : json.items()) { - value[proxy.key()] = json_to_value(proxy.value()); - } - return value; + + template + T from_json(const nlohmann::json& json); + + // load from JSON + class JsonInputArchive : public InputArchive + { + public: + explicit JsonInputArchive(const nlohmann::json& data) + : data_(data) + { + } + + template + void init(SizeType& size) + { + size = static_cast(data_.size()); + iter_ = data_.begin(); + } + + template + void named_value(std::string& name, T& val) + { + name = iter_.key(); + from_json(*iter_++, std::forward(val)); + } + + template + void named_value(const std::string& name, T&& val) + { + from_json(data_[name], std::forward(val)); + } + + template + void item(T&& val) + { + from_json(*iter_++, std::forward(val)); + } + + template + void native(T&& val) + { + data_.get_to(val); + } + + private: + const nlohmann::json& data_; + nlohmann::json::const_iterator iter_; + }; + + namespace detail + { + + template + inline void from_json_impl(const nlohmann::json& json, T&& val) + { + JsonInputArchive archive(json); + archive(std::forward(val)); + } + + } // namespace detail + + template + inline T from_json(const nlohmann::json& json) + { + T val{}; + from_json(json, val); + return val; } - default: - MMDEPLOY_ERROR("unsupported json type: {}", json.type_name()); - return {}; - } -} - -template -void from_json_impl(const nlohmann::json& json, T&& val); - -} // namespace detail - -template >, int> = 0> -void from_json(const nlohmann::json& json, T&& val) { - detail::from_json_impl(json, std::forward(val)); -} - -inline void from_json(const nlohmann::json& json, Value& val) { val = detail::json_to_value(json); } - -template -T from_json(const nlohmann::json& json); - -// load from JSON -class JsonInputArchive : public InputArchive { - public: - explicit JsonInputArchive(const nlohmann::json& data) : data_(data) {} - - template - void init(SizeType& size) { - size = static_cast(data_.size()); - iter_ = data_.begin(); - } - - template - void named_value(std::string& name, T& val) { - name = iter_.key(); - from_json(*iter_++, std::forward(val)); - } - - template - void named_value(const 
std::string& name, T&& val) { - from_json(data_[name], std::forward(val)); - } - - template - void item(T&& val) { - from_json(*iter_++, std::forward(val)); - } - - template - void native(T&& val) { - data_.get_to(val); - } - - private: - const nlohmann::json& data_; - nlohmann::json::const_iterator iter_; -}; - -namespace detail { - -template -inline void from_json_impl(const nlohmann::json& json, T&& val) { - JsonInputArchive archive(json); - archive(std::forward(val)); -} - -} // namespace detail - -template -inline T from_json(const nlohmann::json& json) { - T val{}; - from_json(json, val); - return val; -} - -void from_json(const nlohmann::json& json, Value& val); + + void from_json(const nlohmann::json& json, Value& val); } // namespace mmdeploy diff --git a/csrc/mmdeploy/archive/value_archive.h b/csrc/mmdeploy/archive/value_archive.h index 2f559c1a10..f3245f0dfc 100644 --- a/csrc/mmdeploy/archive/value_archive.h +++ b/csrc/mmdeploy/archive/value_archive.h @@ -6,131 +6,169 @@ #include "mmdeploy/core/archive.h" #include "mmdeploy/core/value.h" -namespace mmdeploy { - -template -Value to_value(T&& val); - -// save to Value -class ValueOutputArchive : public OutputArchive { - public: - explicit ValueOutputArchive(Value& data) : data_(data) {} - - template - void init(array_tag) { - data_ = ValueType::kArray; - } - - template - void init(object_tag) { - data_ = ValueType::kObject; - } - - template - void named_value(const std::string& name, T&& val) { - data_[name] = to_value(std::forward(val)); - } - - template - void item(T&& val) { - data_.push_back(to_value(std::forward(val))); - } - - template , int> = 0> - void native(T&& val) { - data_ = std::forward(val); - }; - - private: - Value& data_; -}; - -template -inline Value to_value(T&& val) { - Value value; - ValueOutputArchive archive(value); - archive(std::forward(val)); - return value; -} - -// fast path -inline Value to_value(const Value& v) { return v; } -inline Value to_value(Value&& v) { return std::move(v); } - -template -void from_value(const Value& value, T&& x); - -template -T from_value(const Value& value); - -// load from Value -class ValueInputArchive : public InputArchive { - public: - explicit ValueInputArchive(const Value& data) : data_(data) {} - - template - void init(SizeType& size) { - size = static_cast(data_.size()); - iter_ = data_.begin(); - } - - template - void named_value(std::string& name, T& val) { - name = iter_.key(); - from_value(*iter_, std::forward(val)); - ++iter_; - } - - template - void named_value(const std::string& name, T&& val) { - from_value(data_[name], std::forward(val)); - } - - template - void item(T&& val) { - from_value(*iter_, std::forward(val)); - ++iter_; - } - - template - void native(T&& val) { - data_.get_to(val); - } - - template - void value(T&& value) {} - - private: - const Value& data_; - Value::const_iterator iter_; -}; - -template -void from_value(const Value& value, T&& x) { - ValueInputArchive archive(value); - archive(std::forward(x)); -} - -// Required to avoid Value::Pointer being unwrapped by Value::get_to() -inline void from_value(const Value& value, Value& x) { x = value; } - -template -inline T from_value(const Value& value) { - T x{}; - from_value(value, x); - return x; -} - -namespace detail { - -inline void load(ValueInputArchive& archive, Value& v) { archive.native(v); } - -template , Value>::value, bool> = true> -inline void save(ValueOutputArchive& archive, T&& v) { - archive.native(std::forward(v)); -} - -} // namespace detail +namespace mmdeploy +{ + + 
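+    // Usage sketch (illustrative, not from the original sources):
+    // to_value/from_value below mirror the JSON archive but serialize into
+    // mmdeploy::Value; for a hypothetical archive-enabled type MyConfig:
+    //
+    //   MyConfig c;
+    //   mmdeploy::Value v = mmdeploy::to_value(c);
+    //   auto c2 = mmdeploy::from_value<MyConfig>(v);
+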
template + Value to_value(T&& val); + + // save to Value + class ValueOutputArchive : public OutputArchive + { + public: + explicit ValueOutputArchive(Value& data) + : data_(data) + { + } + + template + void init(array_tag) + { + data_ = ValueType::kArray; + } + + template + void init(object_tag) + { + data_ = ValueType::kObject; + } + + template + void named_value(const std::string& name, T&& val) + { + data_[name] = to_value(std::forward(val)); + } + + template + void item(T&& val) + { + data_.push_back(to_value(std::forward(val))); + } + + template, int> = 0> + void native(T&& val) + { + data_ = std::forward(val); + }; + + private: + Value& data_; + }; + + template + inline Value to_value(T&& val) + { + Value value; + ValueOutputArchive archive(value); + archive(std::forward(val)); + return value; + } + + // fast path + inline Value to_value(const Value& v) + { + return v; + } + inline Value to_value(Value&& v) + { + return std::move(v); + } + + template + void from_value(const Value& value, T&& x); + + template + T from_value(const Value& value); + + // load from Value + class ValueInputArchive : public InputArchive + { + public: + explicit ValueInputArchive(const Value& data) + : data_(data) + { + } + + template + void init(SizeType& size) + { + size = static_cast(data_.size()); + iter_ = data_.begin(); + } + + template + void named_value(std::string& name, T& val) + { + name = iter_.key(); + from_value(*iter_, std::forward(val)); + ++iter_; + } + + template + void named_value(const std::string& name, T&& val) + { + from_value(data_[name], std::forward(val)); + } + + template + void item(T&& val) + { + from_value(*iter_, std::forward(val)); + ++iter_; + } + + template + void native(T&& val) + { + data_.get_to(val); + } + + template + void value(T&& value) + { + } + + private: + const Value& data_; + Value::const_iterator iter_; + }; + + template + void from_value(const Value& value, T&& x) + { + ValueInputArchive archive(value); + archive(std::forward(x)); + } + + // Required to avoid Value::Pointer being unwrapped by Value::get_to() + inline void from_value(const Value& value, Value& x) + { + x = value; + } + + template + inline T from_value(const Value& value) + { + T x{}; + from_value(value, x); + return x; + } + + namespace detail + { + + inline void load(ValueInputArchive& archive, Value& v) + { + archive.native(v); + } + + template, Value>::value, bool> = true> + inline void save(ValueOutputArchive& archive, T&& v) + { + archive.native(std::forward(v)); + } + + } // namespace detail } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/CMakeLists.txt b/csrc/mmdeploy/backend_ops/CMakeLists.txt index 761c35a59a..4fc59bbf8c 100644 --- a/csrc/mmdeploy/backend_ops/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/CMakeLists.txt @@ -1,39 +1,39 @@ -if (NOT MSVC) - set(CMAKE_CXX_STANDARD 14) - set(CMAKE_CXX_FLAGS_RELEASE "-O3") -endif () +if(NOT MSVC) + set(CMAKE_CXX_STANDARD 14) + set(CMAKE_CXX_FLAGS_RELEASE "-O3") +endif() # build ONNXRUNTIME ops -if ("ort" IN_LIST MMDEPLOY_TARGET_BACKENDS) - if (NOT DEFINED ONNXRUNTIME_DIR) - set(ONNXRUNTIME_DIR $ENV{ONNXRUNTIME_DIR}) - endif () - if (NOT ONNXRUNTIME_DIR) - message(FATAL_ERROR " ONNXRUNTIME_DIR is not found.") - else () - message(STATUS "Build ONNXRUNTIME custom ops.") - add_subdirectory(onnxruntime) - endif () -endif () +if("ort" IN_LIST MMDEPLOY_TARGET_BACKENDS) + if(NOT DEFINED ONNXRUNTIME_DIR) + set(ONNXRUNTIME_DIR $ENV{ONNXRUNTIME_DIR}) + endif() + if(NOT ONNXRUNTIME_DIR) + message(FATAL_ERROR " ONNXRUNTIME_DIR is not 
found.") + else() + message(STATUS "Build ONNXRUNTIME custom ops.") + add_subdirectory(onnxruntime) + endif() +endif() # build TensorRT ops -if ("trt" IN_LIST MMDEPLOY_TARGET_BACKENDS) - if (NOT DEFINED TENSORRT_DIR) - set(TENSORRT_DIR $ENV{TENSORRT_DIR}) - endif () - message(STATUS "Build TensorRT custom ops.") - add_subdirectory(tensorrt) -endif () +if("trt" IN_LIST MMDEPLOY_TARGET_BACKENDS) + if(NOT DEFINED TENSORRT_DIR) + set(TENSORRT_DIR $ENV{TENSORRT_DIR}) + endif() + message(STATUS "Build TensorRT custom ops.") + add_subdirectory(tensorrt) +endif() # build ncnn ops -if ("ncnn" IN_LIST MMDEPLOY_TARGET_BACKENDS) - message(STATUS "Build ncnn custom ops") - add_subdirectory(ncnn) -endif () +if("ncnn" IN_LIST MMDEPLOY_TARGET_BACKENDS) + message(STATUS "Build ncnn custom ops") + add_subdirectory(ncnn) +endif() # build TorchScript ops -if ("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS - OR "coreml" IN_LIST MMDEPLOY_TARGET_BACKENDS) +if("torchscript" IN_LIST MMDEPLOY_TARGET_BACKENDS OR "coreml" IN_LIST + MMDEPLOY_TARGET_BACKENDS) message(STATUS "Build torchscript custom ops") add_subdirectory(torchscript) -endif () +endif() diff --git a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh index 02c57c62e6..d5b0f57bfc 100644 --- a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh +++ b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh @@ -8,25 +8,27 @@ #include #define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) #define THREADS_PER_BLOCK 512 #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) -inline int GET_BLOCKS(const int N) { - int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK); - int max_block_num = 4096; - return std::min(optimal_block_num, max_block_num); +inline int GET_BLOCKS(const int N) +{ + int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK); + int max_block_num = 4096; + return std::min(optimal_block_num, max_block_num); } -#define cudaCheckError() \ - { \ - cudaError_t e = cudaGetLastError(); \ - if (e != cudaSuccess) { \ - printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(0); \ - } \ - } +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) \ + { \ + printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(0); \ + } \ + } /** * Returns a view of the original tensor with its dimensions permuted. 
@@ -38,57 +40,81 @@ inline int GET_BLOCKS(const int N) { * @param[in] src_dim dim of src tensor * @param[in] stream cuda stream handle */ -template -void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim, - cudaStream_t stream = 0); - -template -cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha, - const scalar_t* A, int lda, const scalar_t* B, int ldb, - const scalar_t* beta, scalar_t* C, int ldc); - -template -__device__ scalar_t bilinear_interpolate(const scalar_t* input, const int height, const int width, - scalar_t y, scalar_t x) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; - - if (y <= 0) y = 0; - if (x <= 0) x = 0; - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (scalar_t)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (scalar_t)x_low; - } else { - x_high = x_low + 1; - } - - scalar_t ly = y - y_low; - scalar_t lx = x - x_low; - scalar_t hy = 1. - ly, hx = 1. - lx; - // do bilinear interpolation - scalar_t v1 = input[y_low * width + x_low]; - scalar_t v2 = input[y_low * width + x_high]; - scalar_t v3 = input[y_high * width + x_low]; - scalar_t v4 = input[y_high * width + x_high]; - scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - return val; +template +void memcpyPermute(scalar_t* dst, + const scalar_t* src, + int* src_size, + int* permute, + int src_dim, + cudaStream_t stream = 0); + +template +cublasStatus_t cublasGemmWrap(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const scalar_t* alpha, + const scalar_t* A, + int lda, + const scalar_t* B, + int ldb, + const scalar_t* beta, + scalar_t* C, + int ldc); + +template +__device__ scalar_t bilinear_interpolate(const scalar_t* input, + const int height, + const int width, + scalar_t y, + scalar_t x) +{ + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) + { + y_high = y_low = height - 1; + y = (scalar_t)y_low; + } + else + { + y_high = y_low + 1; + } + + if (x_low >= width - 1) + { + x_high = x_low = width - 1; + x = (scalar_t)x_low; + } + else + { + x_high = x_low + 1; + } + + scalar_t ly = y - y_low; + scalar_t lx = x - x_low; + scalar_t hy = 1. - ly, hx = 1. 
- lx; + // do bilinear interpolation + scalar_t v1 = input[y_low * width + x_low]; + scalar_t v2 = input[y_low * width + x_high]; + scalar_t v3 = input[y_high * width + x_low]; + scalar_t v4 = input[y_high * width + x_high]; + scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; } #endif // COMMON_CUDA_HELPER diff --git a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h index a37e243109..a65096df08 100644 --- a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h +++ b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h @@ -1,82 +1,105 @@ #include #include -template -T bilinear_interpolate_2d(const T *src, const int64_t src_h, const int64_t src_w, const T h, - const T w) { - if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) { - return 0; - } +template +T bilinear_interpolate_2d(const T* src, + const int64_t src_h, + const int64_t src_w, + const T h, + const T w) +{ + if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) + { + return 0; + } - int64_t h_low = floor(h); - int64_t w_low = floor(w); - int64_t h_high = h_low + 1; - int64_t w_high = w_low + 1; + int64_t h_low = floor(h); + int64_t w_low = floor(w); + int64_t h_high = h_low + 1; + int64_t w_high = w_low + 1; - T lh = h - h_low; - T lw = w - w_low; - T hh = 1 - lh; - T hw = 1 - lw; + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh; + T hw = 1 - lw; - T v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low]; - T v2 = 0; - if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high]; - T v3 = 0; - if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low]; - T v4 = 0; - if (h_high <= src_h - 1 && w_high <= src_w - 1) v4 = src[h_high * src_w + w_high]; + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high]; + T v3 = 0; + if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low]; + T v4 = 0; + if (h_high <= src_h - 1 && w_high <= src_w - 1) v4 = src[h_high * src_w + w_high]; - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; } // output: (channels * kernel_h * kernel_w, dst_h * dst_w) -template -void deformable_im2col_2d(const T *input, const T *offset, const T *mask, const int64_t src_h, - const int64_t src_w, const int64_t kernel_h, const int64_t kernel_w, - const int64_t pad_h, const int64_t pad_w, const int64_t stride_h, - const int64_t stride_w, const int64_t dilation_h, - const int64_t dilation_w, const int64_t channels, - const int64_t offset_groups, const int64_t dst_h, const int64_t dst_w, - const bool use_mask, T *columns) { - const int64_t workload = channels * dst_h * dst_w; - for (int64_t index = 0; index != workload; ++index) { - const int64_t ow = index % dst_w; - const int64_t oh = (index / dst_w) % dst_h; - const int64_t ic = index / (dst_w * dst_h); - const int64_t oc = ic * kernel_h * kernel_w; +template +void deformable_im2col_2d(const T* input, + const T* offset, + const T* mask, + const int64_t src_h, + const int64_t src_w, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t 
pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t channels, + const int64_t offset_groups, + const int64_t dst_h, + const int64_t dst_w, + const bool use_mask, + T* columns) +{ + const int64_t workload = channels * dst_h * dst_w; + for (int64_t index = 0; index != workload; ++index) + { + const int64_t ow = index % dst_w; + const int64_t oh = (index / dst_w) % dst_h; + const int64_t ic = index / (dst_w * dst_h); + const int64_t oc = ic * kernel_h * kernel_w; - int64_t c_per_offset_grp = channels / offset_groups; - const int64_t grp_idx = ic / c_per_offset_grp; + int64_t c_per_offset_grp = channels / offset_groups; + const int64_t grp_idx = ic / c_per_offset_grp; - auto columns_ptr = columns + (oc * (dst_h * dst_w) + oh * dst_w + ow); - auto input_ptr = input + ic * (src_h * src_w); - auto offset_ptr = offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w; - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += grp_idx * kernel_h * kernel_w * dst_h * dst_w; - } + auto columns_ptr = columns + (oc * (dst_h * dst_w) + oh * dst_w + ow); + auto input_ptr = input + ic * (src_h * src_w); + auto offset_ptr = offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w; + auto mask_ptr = mask; + if (use_mask) + { + mask_ptr += grp_idx * kernel_h * kernel_w * dst_h * dst_w; + } - for (int64_t kh = 0; kh < kernel_h; ++kh) { - for (int64_t kw = 0; kw < kernel_w; ++kw) { - const int64_t mask_idx = kh * kernel_w + kw; - const int64_t offset_idx = 2 * mask_idx; + for (int64_t kh = 0; kh < kernel_h; ++kh) + { + for (int64_t kw = 0; kw < kernel_w; ++kw) + { + const int64_t mask_idx = kh * kernel_w + kw; + const int64_t offset_idx = 2 * mask_idx; - T mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[mask_idx * (dst_h * dst_w) + oh * dst_w + ow]; - } + T mask_value = 1; + if (use_mask) + { + mask_value = mask_ptr[mask_idx * (dst_h * dst_w) + oh * dst_w + ow]; + } - const T offset_h = offset_ptr[offset_idx * (dst_h * dst_w) + oh * dst_w + ow]; - const T offset_w = offset_ptr[(offset_idx + 1) * (dst_h * dst_w) + oh * dst_w + ow]; - const T ih = (oh * stride_h - pad_h) + kh * dilation_h + offset_h; - const T iw = (ow * stride_w - pad_w) + kw * dilation_w + offset_w; - *columns_ptr = mask_value * bilinear_interpolate_2d(input_ptr, src_h, src_w, ih, iw); - columns_ptr += dst_h * dst_w; - } + const T offset_h = offset_ptr[offset_idx * (dst_h * dst_w) + oh * dst_w + ow]; + const T offset_w = offset_ptr[(offset_idx + 1) * (dst_h * dst_w) + oh * dst_w + ow]; + const T ih = (oh * stride_h - pad_h) + kh * dilation_h + offset_h; + const T iw = (ow * stride_w - pad_w) + kw * dilation_w + offset_w; + *columns_ptr = mask_value * bilinear_interpolate_2d(input_ptr, src_h, src_w, ih, iw); + columns_ptr += dst_h * dst_w; + } + } } - } } diff --git a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh index 43166e7d6b..20429a37c9 100644 --- a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh +++ b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh @@ -71,110 +71,139 @@ #include "common_cuda_helper.cuh" -template -__device__ float mdcn_im2col_bilinear(const T *input, const int data_width, const int height, - const int width, float h, float w) { - int h_low = floorf(h); - int w_low = floorf(w); - int h_high = h_low + 1; - int 
w_high = w_low + 1; - - T lh = h - h_low; - T lw = w - w_low; - T hh = 1 - lh, hw = 1 - lw; - - T v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; - T v2 = 0; - if (h_low >= 0 && w_high <= width - 1) v2 = input[h_low * data_width + w_high]; - T v3 = 0; - if (h_high <= height - 1 && w_low >= 0) v3 = input[h_high * data_width + w_low]; - T v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) v4 = input[h_high * data_width + w_high]; - - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return float(val); +template +__device__ float mdcn_im2col_bilinear(const T* input, + const int data_width, + const int height, + const int width, + float h, + float w) +{ + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) v2 = input[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) v3 = input[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) v4 = input[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return float(val); } -template <> -__device__ float mdcn_im2col_bilinear<__half>(const __half *input, const int data_width, - const int height, const int width, float h, float w) { - int h_low = floorf(h); - int w_low = floorf(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - float lh = h - h_low; - float lw = w - w_low; - float hh = 1 - lh, hw = 1 - lw; - - float v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = __half2float(input[h_low * data_width + w_low]); - float v2 = 0; - if (h_low >= 0 && w_high <= width - 1) v2 = __half2float(input[h_low * data_width + w_high]); - float v3 = 0; - if (h_high <= height - 1 && w_low >= 0) v3 = __half2float(input[h_high * data_width + w_low]); - float v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = __half2float(input[h_high * data_width + w_high]); - - float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; +template<> +__device__ float mdcn_im2col_bilinear<__half>(const __half* input, + const int data_width, + const int height, + const int width, + float h, + float w) +{ + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - h_low; + float lw = w - w_low; + float hh = 1 - lh, hw = 1 - lw; + + float v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = __half2float(input[h_low * data_width + w_low]); + float v2 = 0; + if (h_low >= 0 && w_high <= width - 1) v2 = __half2float(input[h_low * data_width + w_high]); + float v3 = 0; + if (h_high <= height - 1 && w_low >= 0) v3 = __half2float(input[h_high * data_width + w_low]); + float v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = __half2float(input[h_high * data_width + w_high]); + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; } -template -__global__ void modulated_deformable_im2col_gpu_kernel( - const int n, const T *data_im, const T *data_offset, const T *data_mask, const int height, - const int width, const int kernel_h, const int 
kernel_w, const int pad_h, const int pad_w, - const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, const int num_channels, - const int deformable_group, const int height_col, const int width_col, T *data_col) { - CUDA_1D_KERNEL_LOOP(index, n) { - // index index of output matrix - const int w_col = index % width_col; - const int h_col = (index / width_col) % height_col; - const int b_col = (index / width_col / height_col) % batch_size; - const int c_im = (index / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T *data_col_ptr = - data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; - const T *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * width_col; - - const T *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * - kernel_h * kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - float val = 0.0f; - const float h_im = h_in + i * dilation_h + (float)offset_h; - const float w_im = w_in + j * dilation_w + (float)offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) - val = mdcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); - *data_col_ptr = (T)(val * (float)mask); - data_col_ptr += batch_size * height_col * width_col; - } +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const T* data_im, + const T* data_offset, + const T* data_mask, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + T* data_col) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + 
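+        // data_offset is laid out as (batch, deformable_group,
+        // 2 * kernel_h * kernel_w, height_col, width_col): each kernel
+        // position stores an (h, w) offset pair, hence the factor of 2 in the
+        // per-group stride computed below.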
const T* data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + + const T* data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + float val = 0.0f; + const float h_im = h_in + i * dilation_h + (float)offset_h; + const float w_im = w_in + j * dilation_w + (float)offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + val = mdcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + *data_col_ptr = (T)(val * (float)mask); + data_col_ptr += batch_size * height_col * width_col; + } + } } - } } #endif // TRT_MODULATED_DEFORM_CONV_KERNEL_CUH diff --git a/csrc/mmdeploy/backend_ops/ncnn/CMakeLists.txt b/csrc/mmdeploy/backend_ops/ncnn/CMakeLists.txt index 4df9ad1233..f3e2aeb51e 100755 --- a/csrc/mmdeploy/backend_ops/ncnn/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/ncnn/CMakeLists.txt @@ -3,20 +3,21 @@ # ncnn find_package(ncnn) -if (ncnn_FOUND) - message(STATUS "ncnn library found!") -else () - message(FATAL_ERROR "Could not locate ncnn") -endif () +if(ncnn_FOUND) + message(STATUS "ncnn library found!") +else() + message(FATAL_ERROR "Could not locate ncnn") +endif() - -if (NOT ANDROID AND NOT IOS AND NOT CMAKE_CROSSCOMPILING) - add_subdirectory(ops) - add_subdirectory(onnx2ncnn) - add_subdirectory(pyncnn_ext) -else () - # In case of embedded platform, like android, or ios, we only build custom ncnn - # ops, and leave the executable converter(onnx2ncnn, pyncnn_ext) built under - # the host platforms - add_subdirectory(ops) -endif () +if(NOT ANDROID + AND NOT IOS + AND NOT CMAKE_CROSSCOMPILING) + add_subdirectory(ops) + add_subdirectory(onnx2ncnn) + add_subdirectory(pyncnn_ext) +else() + # In case of embedded platform, like android, or ios, we only build custom + # ncnn ops, and leave the executable converter(onnx2ncnn, pyncnn_ext) built + # under the host platforms + add_subdirectory(ops) +endif() diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/CMakeLists.txt b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/CMakeLists.txt index fe1687e951..deeb1e1241 100755 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/CMakeLists.txt @@ -4,22 +4,28 @@ project(onnx2ncnn) find_package(Protobuf) -if (PROTOBUF_FOUND) - if (${Protobuf_PROTOC_EXECUTABLE} STREQUAL "") - message(FATAL_ERROR "protoc not found, try `-DProtobuf_PROTOC_EXECUTABLE=/path/to/protoc`") - endif () - protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS - ${CMAKE_CURRENT_SOURCE_DIR}/onnx.proto) - add_executable(mmdeploy_onnx2ncnn onnx2ncnn.cpp fuse_pass.cpp shape_inference.cpp ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) - target_include_directories(mmdeploy_onnx2ncnn PRIVATE ${PROTOBUF_INCLUDE_DIR} - ${CMAKE_CURRENT_BINARY_DIR}) - target_link_libraries(mmdeploy_onnx2ncnn PRIVATE ${PROTOBUF_LIBRARIES}) - if (MSVC) - target_compile_options(mmdeploy_onnx2ncnn 
PUBLIC $<$:/Za>) - endif() - set(_NCNN_CONVERTER_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/backend/ncnn) - install(TARGETS mmdeploy_onnx2ncnn DESTINATION ${_NCNN_CONVERTER_DIR}) -else () +if(PROTOBUF_FOUND) + if(${Protobuf_PROTOC_EXECUTABLE} STREQUAL "") message( - FATAL_ERROR "Protobuf not found, onnx model convert tool won't be built") -endif () + FATAL_ERROR + "protoc not found, try `-DProtobuf_PROTOC_EXECUTABLE=/path/to/protoc`") + endif() + protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS + ${CMAKE_CURRENT_SOURCE_DIR}/onnx.proto) + add_executable( + mmdeploy_onnx2ncnn onnx2ncnn.cpp fuse_pass.cpp shape_inference.cpp + ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) + target_include_directories( + mmdeploy_onnx2ncnn PRIVATE ${PROTOBUF_INCLUDE_DIR} + ${CMAKE_CURRENT_BINARY_DIR}) + target_link_libraries(mmdeploy_onnx2ncnn PRIVATE ${PROTOBUF_LIBRARIES}) + if(MSVC) + target_compile_options(mmdeploy_onnx2ncnn + PUBLIC $<$:/Za>) + endif() + set(_NCNN_CONVERTER_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/backend/ncnn) + install(TARGETS mmdeploy_onnx2ncnn DESTINATION ${_NCNN_CONVERTER_DIR}) +else() + message( + FATAL_ERROR "Protobuf not found, onnx model convert tool won't be built") +endif() diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp index 4d620e4c82..274ba76bca 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp @@ -1,355 +1,402 @@ // Copyright (c) OpenMMLab. All rights reserved. #include "fuse_pass.h" -void fuse_identity(onnx::GraphProto* mutable_graph, +void fuse_identity(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - // fuse - // identity --> op - // to - // noop_reducencnn --> op - const int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; ++i) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - for (int j = 0; j < node->input_size(); ++j) { - std::string output_name = node->input(j); - onnx::NodeProto* last_node = find_node_by_output_name(mutable_graph, output_name); - if (last_node && last_node->op_type() == "Identity") { - node->set_input(j, last_node->input(0)); - node_reference[last_node->output(0)] -= 1; - node_reference[last_node->input(0)] += 1; - if (node_reference[last_node->output(0)] == 0) { - last_node->set_op_type("noop_reducedncnn"); - node_reference[last_node->input(0)] -= 1; - reduced_node_count += 1; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + // fuse + // identity --> op + // to + // noop_reducencnn --> op + const int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; ++i) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + for (int j = 0; j < node->input_size(); ++j) + { + std::string output_name = node->input(j); + onnx::NodeProto* last_node = find_node_by_output_name(mutable_graph, output_name); + if (last_node && last_node->op_type() == "Identity") + { + node->set_input(j, last_node->input(0)); + node_reference[last_node->output(0)] -= 1; + node_reference[last_node->input(0)] += 1; + if (node_reference[last_node->output(0)] == 0) + { + last_node->set_op_type("noop_reducedncnn"); + node_reference[last_node->input(0)] -= 1; + reduced_node_count += 1; + } + } } - } } - } } -void fuse_rewrite_gather(onnx::GraphProto* mutable_graph, +void fuse_rewrite_gather(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - 
std::set& blob_names, int& reduced_node_count) { - const int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; ++i) { - onnx::NodeProto* gather = mutable_graph->mutable_node(i); - if (gather->op_type() != "Gather") { - continue; - } - if (weights.find(std::string(gather->input(1))) == weights.end()) { - continue; - } - auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]); - if (indices.size() != 1) { - continue; - } - + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + const int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; ++i) { - // reconstruct node connections - node_reference[gather->input(1)] -= 1; - std::string origin_inp = gather->input(0); - gather->clear_input(); - gather->add_input(origin_inp); - } + onnx::NodeProto* gather = mutable_graph->mutable_node(i); + if (gather->op_type() != "Gather") + { + continue; + } + if (weights.find(std::string(gather->input(1))) == weights.end()) + { + continue; + } + auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]); + if (indices.size() != 1) + { + continue; + } - { - // update axis, starts and ends - int axis = get_node_attr_i(*gather, "axis", 1) - 1; + { + // reconstruct node connections + node_reference[gather->input(1)] -= 1; + std::string origin_inp = gather->input(0); + gather->clear_input(); + gather->add_input(origin_inp); + } + + { + // update axis, starts and ends + int axis = get_node_attr_i(*gather, "axis", 1) - 1; - gather->set_op_type("Crop"); - gather->clear_attribute(); + gather->set_op_type("Crop"); + gather->clear_attribute(); - int indice = indices[0]; - set_node_attr_ai(*gather, "starts", std::vector{indice}); - set_node_attr_ai(*gather, "ends", std::vector{indice + 1}); - set_node_attr_ai(*gather, "axis", std::vector{axis}); + int indice = indices[0]; + set_node_attr_ai(*gather, "starts", std::vector{indice}); + set_node_attr_ai(*gather, "ends", std::vector{indice + 1}); + set_node_attr_ai(*gather, "axis", std::vector{axis}); + } } - } } -void fuse_weight_reshape(onnx::GraphProto* mutable_graph, +void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // weight <= Reshape(weight) - if (node->op_type() == "Reshape") { - // check weight - if (weights.find(node->input(0)) == weights.end()) continue; - - weights[node->output(0)] = weights[node->input(0)]; - - // set weight shape directly - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else if (node->input_size() == 2) { - // opset 5 - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } - - weights[node->output(0)].clear_dims(); - for (int j = 0; j < shape.size(); j++) { - weights[node->output(0)].add_dims(shape[j]); - } - - // reduce - node->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= 1; - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - - reduced_node_count += 1; - i += 1; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // weight <= Reshape(weight) + if (node->op_type() == "Reshape") + { + // check weight + 
if (weights.find(node->input(0)) == weights.end()) continue; + + weights[node->output(0)] = weights[node->input(0)]; + + // set weight shape directly + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else if (node->input_size() == 2) + { + // opset 5 + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } + + weights[node->output(0)].clear_dims(); + for (int j = 0; j < shape.size(); j++) + { + weights[node->output(0)].add_dims(shape[j]); + } + + // reduce + node->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_weight_transpose(onnx::GraphProto* mutable_graph, +void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // weight <= Transpose(weight) - if (node->op_type() == "Transpose") { - // check weight - if (weights.find(node->input(0)) == weights.end()) continue; - - if (weights[node->input(0)].dims_size() != 2) continue; - - // perm = (1, 0) - std::vector perm = get_node_attr_ai(*node, "perm"); - if (perm.size() != 2) continue; - if (perm[0] != 1 || perm[1] != 0) continue; - - weights[node->output(0)] = weights[node->input(0)]; - - // permute weight - { - onnx::TensorProto& B = weights[node->output(0)]; - - const int h = B.dims(0); - const int w = B.dims(1); - - std::vector permuted_data; - permuted_data.reserve((size_t)h * w); - const float* bptr = - B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); - - for (int j = 0; j < w; j++) { - for (int k = 0; k < h; k++) { - float vb = bptr[k * w + j]; - permuted_data.push_back(vb); - } - } - - B.set_dims(0, w); - B.set_dims(1, h); - - if (B.has_raw_data()) { - B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float)); - } else { - for (int j = 0; j < (int)permuted_data.size(); j++) B.set_float_data(j, permuted_data[j]); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // weight <= Transpose(weight) + if (node->op_type() == "Transpose") + { + // check weight + if (weights.find(node->input(0)) == weights.end()) continue; + + if (weights[node->input(0)].dims_size() != 2) continue; + + // perm = (1, 0) + std::vector perm = get_node_attr_ai(*node, "perm"); + if (perm.size() != 2) continue; + if (perm[0] != 1 || perm[1] != 0) continue; + + weights[node->output(0)] = weights[node->input(0)]; + + // permute weight + { + onnx::TensorProto& B = weights[node->output(0)]; + + const int h = B.dims(0); + const int w = B.dims(1); + + std::vector permuted_data; + permuted_data.reserve((size_t)h * w); + const float* bptr = + B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); + + for (int j = 0; j < w; j++) + { + for (int k = 0; k < h; k++) + { + float vb = bptr[k * w + j]; + permuted_data.push_back(vb); + } + } + + B.set_dims(0, w); + B.set_dims(1, h); + + if (B.has_raw_data()) + { + B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float)); + } + else + { + for (int j = 0; j < (int)permuted_data.size(); j++) B.set_float_data(j, permuted_data[j]); + } + } + + // reduce + node->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + + reduced_node_count += 1; + i += 1; } - } - - // reduce - node->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= 1; - - reduced_node_count += 1; - i += 1; } - } } -void fuse_shufflechannel(onnx::GraphProto* mutable_graph, +void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // ShuffleChannel <= Reshape - Transpose - Reshape - // ShuffleChannel <= Reshape - Transpose - Constant - Reshape - if (node->op_type() == "Reshape") { - if (node_reference[node->output(0)] != 1) continue; - - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else { - // skip weight reshape - if (weights.find(node->input(1)) == weights.end()) continue; - - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } - - // 1 groups channels_per_group, height, width - // reverse style = channels_per_group, groups, height * width - if (shape.size() != 5 && shape.size() != 3) continue; - - if (shape.size() == 5 && shape[0] != 1) continue; - - if (i + 2 >= node_count) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - - if (node3->op_type() == "Constant") { - if (i + 3 >= node_count) continue; - - node3 = mutable_graph->mutable_node(i + 3); - } - - if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; - - if (node_reference[node2->output(0)] != 1) continue; - - // 0 2 1 3 4 - // reverse style = 1 0 2 - std::vector perm = get_node_attr_ai(*node2, "perm"); - if (perm.size() != 5 && perm.size() != 3) continue; - - if (perm.size() == 5 && - (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4)) - continue; - - if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)) continue; - - std::vector shape3; - if (node3->input_size() == 1) { - shape3 = get_node_attr_ai(*node3, "shape"); - } else { - // skip weight reshape - if (weights.find(node3->input(1)) == weights.end()) continue; - - shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); - } - - // 1, -1, height, width - // reverse style = group, -1, channels_per_group, height, width - if (shape3.size() != 4 && shape3.size() != 5) continue; - - if (shape3.size() == 4 && - (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2]))) - continue; - - if (shape3.size() == 5 && - (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2])) - continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; - if 
(node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - - node3->set_op_type("ShuffleChannel"); - node3->set_input(0, node->input(0)); - - onnx::AttributeProto* attr_group = node3->add_attribute(); - attr_group->set_name("group"); - attr_group->set_i(shape[1]); - - onnx::AttributeProto* attr_reverse = node3->add_attribute(); - attr_reverse->set_name("reverse"); - attr_reverse->set_i(shape.size() == 3); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // ShuffleChannel <= Reshape - Transpose - Reshape + // ShuffleChannel <= Reshape - Transpose - Constant - Reshape + if (node->op_type() == "Reshape") + { + if (node_reference[node->output(0)] != 1) continue; + + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node->input(1)) == weights.end()) continue; + + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } + + // 1 groups channels_per_group, height, width + // reverse style = channels_per_group, groups, height * width + if (shape.size() != 5 && shape.size() != 3) continue; + + if (shape.size() == 5 && shape[0] != 1) continue; + + if (i + 2 >= node_count) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + + if (node3->op_type() == "Constant") + { + if (i + 3 >= node_count) continue; + + node3 = mutable_graph->mutable_node(i + 3); + } + + if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; + + if (node_reference[node2->output(0)] != 1) continue; + + // 0 2 1 3 4 + // reverse style = 1 0 2 + std::vector perm = get_node_attr_ai(*node2, "perm"); + if (perm.size() != 5 && perm.size() != 3) continue; + + if (perm.size() == 5 && + (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4)) + continue; + + if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)) continue; + + std::vector shape3; + if (node3->input_size() == 1) + { + shape3 = get_node_attr_ai(*node3, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node3->input(1)) == weights.end()) continue; + + shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); + } + + // 1, -1, height, width + // reverse style = group, -1, channels_per_group, height, width + if (shape3.size() != 4 && shape3.size() != 5) continue; + + if (shape3.size() == 4 && + (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2]))) + continue; + + if (shape3.size() == 5 && + (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2])) + continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + + node3->set_op_type("ShuffleChannel"); + node3->set_input(0, node->input(0)); + + onnx::AttributeProto* attr_group = node3->add_attribute(); + attr_group->set_name("group"); + attr_group->set_i(shape[1]); + 
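+            // shape.size() == 3 corresponds to the reverse-style reshape
+            // (channels_per_group, groups, height * width) matched above, so
+            // the fused ShuffleChannel is tagged with reverse = 1.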
 
-void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
-                               std::map<std::string, onnx::TensorProto>& weights,
-                               std::map<std::string, int>& node_reference,
-                               std::set<std::string>& blob_names, int& reduced_node_count) {
-  int node_count = mutable_graph->node_size();
-  for (int i = 0; i < node_count; i++) {
-    onnx::NodeProto* node = mutable_graph->mutable_node(i);
-
-    // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
-    if (node->op_type() == "ShuffleChannel") {
-      // reverse = 1
-      int reverse = get_node_attr_i(*node, "reverse");
-      if (reverse != 1) continue;
-
-      if (i + 2 >= node_count) continue;
-
-      onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
-      onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
-
-      if (node2->op_type() != "Gather" || node3->op_type() != "Gather") continue;
-
-      if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) continue;
-
-      // axis = 0
-      int gather2_axis = get_node_attr_i(*node2, "axis");
-      if (gather2_axis != 0) continue;
-
-      // indices = 0
-      if (weights.find(node2->input(1)) == weights.end()) continue;
-
-      std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
-      if (gather2_indices.size() != 1 || gather2_indices[0] != 0) continue;
-
-      // axis = 0
-      int gather3_axis = get_node_attr_i(*node3, "axis");
-      if (gather3_axis != 0) continue;
-
-      // indices = 1
-      if (weights.find(node3->input(1)) == weights.end()) continue;
-
-      std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
-      if (gather3_indices.size() != 1 || gather3_indices[0] != 1) continue;
-
-      // reduce
-      node2->set_op_type("noop_reducedncnn");
-
-      node_reference[node->output(0)] -= 2;
-      node_reference[node2->input(1)] -= 1;
-      node_reference[node3->input(1)] -= 1;
-
-      node3->set_op_type("Split");
-      node3->clear_input();
-      node3->add_input(node->output(0));
-      node3->add_output(node3->output(0));
-      node3->set_output(0, node2->output(0));
-
-      node3->clear_attribute();
-      onnx::AttributeProto* attr_axis = node3->add_attribute();
-      attr_axis->set_name("axis");
-      attr_axis->set_i(1);
-
-      reduced_node_count += 1;
-      i += 1;
-    }
-  }
+void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
+                               std::map<std::string, onnx::TensorProto>& weights,
+                               std::map<std::string, int>& node_reference,
+                               std::set<std::string>& blob_names,
+                               int& reduced_node_count)
+{
+    int node_count = mutable_graph->node_size();
+    for (int i = 0; i < node_count; i++)
+    {
+        onnx::NodeProto* node = mutable_graph->mutable_node(i);
+
+        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
+        if (node->op_type() == "ShuffleChannel")
+        {
+            // reverse = 1
+            int reverse = get_node_attr_i(*node, "reverse");
+            if (reverse != 1) continue;
+
+            if (i + 2 >= node_count) continue;
+
+            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
+            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
+
+            if (node2->op_type() != "Gather" || node3->op_type() != "Gather") continue;
+
+            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) continue;
+
+            // axis = 0
+            int gather2_axis = get_node_attr_i(*node2, "axis");
+            if (gather2_axis != 0) continue;
+
+            // indices = 0
+            if (weights.find(node2->input(1)) == weights.end()) continue;
+
+            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
+            if (gather2_indices.size() != 1 || gather2_indices[0] != 0) continue;
+
+            // axis = 0
+            int gather3_axis = get_node_attr_i(*node3, "axis");
+            if (gather3_axis != 0) continue;
+
+            // indices = 1
+            if (weights.find(node3->input(1)) == weights.end()) continue;
+
+            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
+            if (gather3_indices.size() != 1 || gather3_indices[0] != 1) continue;
+
+            // reduce
+            node2->set_op_type("noop_reducedncnn");
+
+            node_reference[node->output(0)] -= 2;
+            node_reference[node2->input(1)] -= 1;
+            node_reference[node3->input(1)] -= 1;
+
+            node3->set_op_type("Split");
+            node3->clear_input();
+            node3->add_input(node->output(0));
+            node3->add_output(node3->output(0));
+            node3->set_output(0, node2->output(0));
+
+            node3->clear_attribute();
+            onnx::AttributeProto* attr_axis = node3->add_attribute();
+            attr_axis->set_name("axis");
+            attr_axis->set_i(1);
+
+            reduced_node_count += 1;
+            i += 1;
+        }
+    }
 }
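Both fusions above gate on the same bookkeeping contract: an intermediate blob may only be folded away when node_reference records exactly one consumer (two, for the shared ShuffleChannel output), and every edge a rewrite removes must have its count decremented and its blob erased. A tiny self-contained sketch of that invariant, with a hypothetical blob name, not converter code:

    #include <cassert>
    #include <map>
    #include <set>
    #include <string>

    int main()
    {
        std::map<std::string, int> node_reference = {{"shuffle_out", 1}};
        std::set<std::string> blob_names = {"shuffle_out"};

        // the pattern only fires when the intermediate blob has a single consumer
        assert(node_reference["shuffle_out"] == 1);

        node_reference["shuffle_out"] -= 1; // the rewrite consumed this edge
        blob_names.erase("shuffle_out");    // the blob is no longer materialized

        assert(node_reference["shuffle_out"] == 0);
        return 0;
    }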
 
 /**
@@ -369,2034 +416,2209 @@ void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
  * @param blob_names
  * @param reduced_node_count
  */
-void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
-                       std::map<std::string, onnx::TensorProto>& weights,
-                       std::map<std::string, int>& node_reference,
-                       std::set<std::string>& blob_names, int& reduced_node_count) {
-  std::map<std::string, std::vector<int>> shape_context;
-  const int node_count = mutable_graph->node_size();
-
-  for (int i = 0; i < node_count; i++) {
-    onnx::NodeProto* conv = mutable_graph->mutable_node(i);
-
-    if (conv->op_type() != "Conv") {
-      continue;
-    }
-
-    if (i + 4 >= node_count) {
-      continue;
-    }
-
-    onnx::NodeProto *shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr;
-
-    // match [Shape ... Slice, Concat ... Reshape] from near sequence, skip useless Constant
-    std::vector<std::tuple<std::string, onnx::NodeProto**>> candidates = {
-        {"Shape", &shape}, {"Slice", &slice}, {"Concat", &concat}, {"Reshape", &reshape}};
-
-    int MAX = std::min(10, node_count - i - 1);
-    int pos_candidate = 0;
-
-    for (int j = 0; j < MAX; ++j) {
-      auto node_ptr = mutable_graph->mutable_node(j + i + 1);
-      if (node_ptr->op_type() == "Constant") {
-        continue;
-      }
-      if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) {
-        *(std::get<1>(candidates[pos_candidate])) = node_ptr;
-        pos_candidate++;
-      }
-    }
-
-    if (pos_candidate != candidates.size()) {
-      // not match the sequence
-      continue;
-    }
+void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
+                       std::map<std::string, onnx::TensorProto>& weights,
+                       std::map<std::string, int>& node_reference,
+                       std::set<std::string>& blob_names,
+                       int& reduced_node_count)
+{
+    std::map<std::string, std::vector<int>> shape_context;
+    const int node_count = mutable_graph->node_size();
+
+    for (int i = 0; i < node_count; i++)
+    {
+        onnx::NodeProto* conv = mutable_graph->mutable_node(i);
+
+        if (conv->op_type() != "Conv")
+        {
+            continue;
+        }
+
+        if (i + 4 >= node_count)
+        {
+            continue;
+        }
+
+        onnx::NodeProto * shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr;
+
+        // match [Shape ...
Reshape] from near sequence, skip useless Constant + std::vector> candidates = { + {"Shape", &shape}, + {"Slice", &slice}, + {"Concat", &concat}, + {"Reshape", &reshape}}; + + int MAX = std::min(10, node_count - i - 1); + int pos_candidate = 0; + + for (int j = 0; j < MAX; ++j) + { + auto node_ptr = mutable_graph->mutable_node(j + i + 1); + if (node_ptr->op_type() == "Constant") + { + continue; + } + if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) + { + *(std::get<1>(candidates[pos_candidate])) = node_ptr; + pos_candidate++; + } + } - if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 || - node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 || - node_reference[reshape->output(0)] != 1) { - continue; - } + if (pos_candidate != candidates.size()) + { + // not match the sequence + continue; + } - // check the connections - if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) { - continue; - } - if (slice->input(0) != shape->output(0)) { - continue; - } - if (concat->input(0) != slice->output(0)) { - continue; - } - if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) { - continue; - } + if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 || + node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 || + node_reference[reshape->output(0)] != 1) + { + continue; + } - // add reshape attr - auto result = query_shape(mutable_graph, concat, weights, shape_context); - if (!std::get<0>(result)) { - continue; - } - set_node_attr_ai(*reshape, "shape", std::get<1>(result)); + // check the connections + if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) + { + continue; + } + if (slice->input(0) != shape->output(0)) + { + continue; + } + if (concat->input(0) != slice->output(0)) + { + continue; + } + if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) + { + continue; + } - // reconstruct graph - { - // remove reference - node_reference[reshape->input(1)] -= 1; - node_reference[concat->input(0)] -= 1; - node_reference[slice->input(0)] -= 1; - node_reference[shape->input(0)] -= 1; - - // remove tensor/blob on edge - blob_names.erase(slice->input(0)); - blob_names.erase(slice->input(1)); - blob_names.erase(slice->input(2)); - blob_names.erase(slice->input(3)); - weights.erase(slice->input(1)); - weights.erase(slice->input(2)); - weights.erase(slice->input(3)); - - blob_names.erase(concat->input(0)); - blob_names.erase(concat->input(1)); - weights.erase(concat->input(1)); - - blob_names.erase(reshape->input(0)); - - // update edge - shape->clear_input(); - reshape->clear_input(); - reshape->add_input(conv->output(0)); - - shape->set_op_type("noop_reducedncnn"); - slice->set_op_type("noop_reducedncnn"); - concat->set_op_type("noop_reducedncnn"); - - reduced_node_count += 3; + // add reshape attr + auto result = query_shape(mutable_graph, concat, weights, shape_context); + if (!std::get<0>(result)) + { + continue; + } + set_node_attr_ai(*reshape, "shape", std::get<1>(result)); + + // reconstruct graph + { + // remove reference + node_reference[reshape->input(1)] -= 1; + node_reference[concat->input(0)] -= 1; + node_reference[slice->input(0)] -= 1; + node_reference[shape->input(0)] -= 1; + + // remove tensor/blob on edge + blob_names.erase(slice->input(0)); + blob_names.erase(slice->input(1)); + blob_names.erase(slice->input(2)); + blob_names.erase(slice->input(3)); 
+ weights.erase(slice->input(1)); + weights.erase(slice->input(2)); + weights.erase(slice->input(3)); + + blob_names.erase(concat->input(0)); + blob_names.erase(concat->input(1)); + weights.erase(concat->input(1)); + + blob_names.erase(reshape->input(0)); + + // update edge + shape->clear_input(); + reshape->clear_input(); + reshape->add_input(conv->output(0)); + + shape->set_op_type("noop_reducedncnn"); + slice->set_op_type("noop_reducedncnn"); + concat->set_op_type("noop_reducedncnn"); + + reduced_node_count += 3; + } + i += 3; } - i += 3; - } } -void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, +void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // Add/Sub/Mul/Div/Min/Max/Pow - if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || - node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || - node->op_type() == "Pow") { - if (weights.find(node->input(1)) == weights.end()) continue; + // Add/Sub/Mul/Div/Min/Max/Pow + if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || + node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || + node->op_type() == "Pow") + { + if (weights.find(node->input(1)) == weights.end()) continue; - const onnx::TensorProto& scalar_b = weights[node->input(1)]; - if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue; + const onnx::TensorProto& scalar_b = weights[node->input(1)]; + if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue; - float b = get_node_attr_from_input(scalar_b); + float b = get_node_attr_from_input(scalar_b); - node_reference[node->input(1)] -= 1; + node_reference[node->input(1)] -= 1; - std::string input = node->input(0); + std::string input = node->input(0); - node->clear_input(); - node->add_input(input); + node->clear_input(); + node->add_input(input); - onnx::AttributeProto* attr_with_scalar = node->add_attribute(); - attr_with_scalar->set_name("with_scalar"); - attr_with_scalar->set_i(1); + onnx::AttributeProto* attr_with_scalar = node->add_attribute(); + attr_with_scalar->set_name("with_scalar"); + attr_with_scalar->set_i(1); - onnx::AttributeProto* attr_b = node->add_attribute(); - attr_b->set_name("b"); - attr_b->set_f(b); + onnx::AttributeProto* attr_b = node->add_attribute(); + attr_b->set_name("b"); + attr_b->set_f(b); + } } - } } -void fuse_hardswish(onnx::GraphProto* mutable_graph, +void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6) - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6)) - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6) - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6)) - // out = x * F.relu6(x + 
3, inplace=True) / 6
-    if (node->op_type() == "Add") {
-      if (node_reference[node->output(0)] != 1) continue;
-
-      if (i + 3 >= node_count) continue;
-
-      if (weights.find(node->input(1)) == weights.end()) continue;
-
-      const onnx::TensorProto& add_three = weights[node->input(1)];
-      if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
-
-      float constant_add_three = get_node_attr_from_input<float>(add_three);
-      if (constant_add_three != 3.f) continue;
-
-      onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
-      onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
-      onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
-
-      if (node4->op_type() == "Constant") {
-        if (i + 4 >= node_count) continue;
-
-        node4 = mutable_graph->mutable_node(i + 4);
-      }
-
-      if (node2->op_type() != "Clip" || node3->op_type() != "Mul" ||
-          (node4->op_type() != "Div" && node4->op_type() != "Mul"))
-        continue;
-
-      if (node_reference[node2->output(0)] != 1) continue;
-
-      float relu6_min;
-      float relu6_max;
-      if (node2->input_size() == 1) {
-        relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
-        relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
-      } else {
-        const onnx::TensorProto& min_tp = weights[node2->input(1)];
-        const onnx::TensorProto& max_tp = weights[node2->input(2)];
-
-        relu6_min = get_node_attr_from_input<float>(min_tp);
-        relu6_max = get_node_attr_from_input<float>(max_tp);
-      }
-      if (relu6_min != 0.f || relu6_max != 6.f) continue;
-
-      if (node_reference[node3->output(0)] != 1) continue;
-
-      if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0)) continue;
-
-      if (weights.find(node4->input(1)) == weights.end()) continue;
-
-      const onnx::TensorProto& div_six = weights[node4->input(1)];
-      if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
-
-      float constant_div_six = get_node_attr_from_input<float>(div_six);
-      if (node4->op_type() == "Div" && constant_div_six != 6.f) continue;
-      if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
-
-      // reduce
-      node->set_op_type("noop_reducedncnn");
-      node2->set_op_type("noop_reducedncnn");
-      node3->set_op_type("noop_reducedncnn");
-
-      node_reference[node->input(0)] -= 1;
-      node_reference[node->input(1)] -= 1;
-      node_reference[node->output(0)] -= 1;
-      if (node2->input_size() == 3) {
-        node_reference[node2->input(1)] -= 1;
-        node_reference[node2->input(2)] -= 1;
-      }
-      node_reference[node2->output(0)] -= 1;
-      node_reference[node3->output(0)] -= 1;
-      node_reference[node4->input(1)] -= 1;
-
-      blob_names.erase(node->output(0));
-      blob_names.erase(node2->output(0));
-      blob_names.erase(node3->output(0));
-
-      node4->set_op_type("HardSwish");
-      node4->clear_input();
-      node4->add_input(node->input(0));
-
-      onnx::AttributeProto* attr_alpha = node4->add_attribute();
-      attr_alpha->set_name("alpha");
-      attr_alpha->set_f(1.f / 6.f);
-
-      onnx::AttributeProto* attr_beta = node4->add_attribute();
-      attr_beta->set_name("beta");
-      attr_beta->set_f(3.f / 6.f);
-
-      reduced_node_count += 3;
-      i += 3;
-    }
-  }
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count)
+{
+    int node_count = mutable_graph->node_size();
+    for (int i = 0; i < node_count; i++)
+    {
+        onnx::NodeProto* node = mutable_graph->mutable_node(i);
+
+        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
+        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
+        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
+        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
+        // out = x * F.relu6(x + 3, inplace=True) / 6
+        if (node->op_type() == "Add")
+        {
+            if (node_reference[node->output(0)] != 1) continue;
+
+            if (i + 3 >= node_count) continue;
+
+            if (weights.find(node->input(1)) == weights.end()) continue;
+
+            const onnx::TensorProto& add_three = weights[node->input(1)];
+            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
+
+            float constant_add_three = get_node_attr_from_input<float>(add_three);
+            if (constant_add_three != 3.f) continue;
+
+            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
+            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
+            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
+
+            if (node4->op_type() == "Constant")
+            {
+                if (i + 4 >= node_count) continue;
+
+                node4 = mutable_graph->mutable_node(i + 4);
+            }
+
+            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" ||
+                (node4->op_type() != "Div" && node4->op_type() != "Mul"))
+                continue;
+
+            if (node_reference[node2->output(0)] != 1) continue;
+
+            float relu6_min;
+            float relu6_max;
+            if (node2->input_size() == 1)
+            {
+                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
+                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
+            }
+            else
+            {
+                const onnx::TensorProto& min_tp = weights[node2->input(1)];
+                const onnx::TensorProto& max_tp = weights[node2->input(2)];
+
+                relu6_min = get_node_attr_from_input<float>(min_tp);
+                relu6_max = get_node_attr_from_input<float>(max_tp);
+            }
+            if (relu6_min != 0.f || relu6_max != 6.f) continue;
+
+            if (node_reference[node3->output(0)] != 1) continue;
+
+            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0)) continue;
+
+            if (weights.find(node4->input(1)) == weights.end()) continue;
+
+            const onnx::TensorProto& div_six = weights[node4->input(1)];
+            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
+
+            float constant_div_six = get_node_attr_from_input<float>(div_six);
+            if (node4->op_type() == "Div" && constant_div_six != 6.f) continue;
+            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
+
+            // reduce
+            node->set_op_type("noop_reducedncnn");
+            node2->set_op_type("noop_reducedncnn");
+            node3->set_op_type("noop_reducedncnn");
+
+            node_reference[node->input(0)] -= 1;
+            node_reference[node->input(1)] -= 1;
+            node_reference[node->output(0)] -= 1;
+            if (node2->input_size() == 3)
+            {
+                node_reference[node2->input(1)] -= 1;
+                node_reference[node2->input(2)] -= 1;
+            }
+            node_reference[node2->output(0)] -= 1;
+            node_reference[node3->output(0)] -= 1;
+            node_reference[node4->input(1)] -= 1;
+
+            blob_names.erase(node->output(0));
+            blob_names.erase(node2->output(0));
+            blob_names.erase(node3->output(0));
+
+            node4->set_op_type("HardSwish");
+            node4->clear_input();
+            node4->add_input(node->input(0));
+
+            onnx::AttributeProto* attr_alpha = node4->add_attribute();
+            attr_alpha->set_name("alpha");
+            attr_alpha->set_f(1.f / 6.f);
+
+            onnx::AttributeProto* attr_beta = node4->add_attribute();
+            attr_beta->set_name("beta");
+            attr_beta->set_f(3.f / 6.f);
+
+            reduced_node_count += 3;
+            i += 3;
+        }
+    }
 
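The alpha and beta the fused node receives above (1/6 and 3/6) come from rewriting x * relu6(x + 3) / 6 as x * clamp(alpha * x + beta, 0, 1). A quick numeric check of that identity, illustrative only and not converter code:

    #include <algorithm>
    #include <cassert>
    #include <cmath>

    static float relu6(float v) { return std::min(std::max(v, 0.f), 6.f); }

    // HardSwish as parameterized by the fusion: y = x * clamp(alpha * x + beta, 0, 1)
    static float hardswish(float x, float alpha, float beta)
    {
        return x * std::min(std::max(alpha * x + beta, 0.f), 1.f);
    }

    int main()
    {
        for (float x = -8.f; x <= 8.f; x += 0.25f)
        {
            float subgraph = x * relu6(x + 3.f) / 6.f;
            float fused = hardswish(x, 1.f / 6.f, 3.f / 6.f);
            assert(std::fabs(subgraph - fused) < 1e-6f);
        }
        return 0;
    }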
-  for (int i = 0; i < node_count; i++) {
-    onnx::NodeProto* node = mutable_graph->mutable_node(i);
-
-    // HardSwish <= HardSigmoid - Mul
-    // out = x * hsigmoid(x)
-    if (node->op_type() == "HardSigmoid") {
-      if (node_reference[node->output(0)] != 1) continue;
-
-      float alpha = get_node_attr_f(*node, "alpha", 0.2f);
-      float beta = get_node_attr_f(*node, "beta", 0.5f);
-
-      if (i + 1 >= node_count) continue;
-
-      onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
-
-      if (node2->op_type() != "Mul") continue;
-
-      if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue;
-
-      // reduce
-      node->set_op_type("noop_reducedncnn");
-
-      node_reference[node->input(0)] -= 1;
-      node_reference[node->output(0)] -= 1;
-
-      blob_names.erase(node->output(0));
-
-      node2->set_op_type("HardSwish");
-      node2->clear_input();
-      node2->add_input(node->input(0));
-
-      onnx::AttributeProto* attr_alpha = node2->add_attribute();
-      attr_alpha->set_name("alpha");
-      attr_alpha->set_f(alpha);
-
-      onnx::AttributeProto* attr_beta = node2->add_attribute();
-      attr_beta->set_name("beta");
-      attr_beta->set_f(beta);
-
-      reduced_node_count += 1;
-      i += 1;
-    }
-  }
+    for (int i = 0; i < node_count; i++)
+    {
+        onnx::NodeProto* node = mutable_graph->mutable_node(i);
+
+        // HardSwish <= HardSigmoid - Mul
+        // out = x * hsigmoid(x)
+        if (node->op_type() == "HardSigmoid")
+        {
+            if (node_reference[node->output(0)] != 1) continue;
+
+            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
+            float beta = get_node_attr_f(*node, "beta", 0.5f);
+
+            if (i + 1 >= node_count) continue;
+
+            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
+
+            if (node2->op_type() != "Mul") continue;
+
+            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue;
+
+            // reduce
+            node->set_op_type("noop_reducedncnn");
+
+            node_reference[node->input(0)] -= 1;
+            node_reference[node->output(0)] -= 1;
+
+            blob_names.erase(node->output(0));
+
+            node2->set_op_type("HardSwish");
+            node2->clear_input();
+            node2->add_input(node->input(0));
+
+            onnx::AttributeProto* attr_alpha = node2->add_attribute();
+            attr_alpha->set_name("alpha");
+            attr_alpha->set_f(alpha);
+
+            onnx::AttributeProto* attr_beta = node2->add_attribute();
+            attr_beta->set_name("beta");
+            attr_beta->set_f(beta);
+
+            reduced_node_count += 1;
+            i += 1;
+        }
+    }
 }
 
-void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
-                      std::map<std::string, onnx::TensorProto>& weights,
-                      std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                      int& reduced_node_count) {
-  int node_count = mutable_graph->node_size();
-  for (int i = 0; i < node_count; i++) {
-    onnx::NodeProto* node = mutable_graph->mutable_node(i);
-
-    // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
-    // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
-    // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
-    // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
-    // out = F.relu6(x + 3, inplace=True) / 6
-    if (node->op_type() == "Add") {
-      if (node_reference[node->output(0)] != 1) continue;
-
-      if (i + 2 >= node_count) continue;
-
-      if (weights.find(node->input(1)) == weights.end()) continue;
-
-      const onnx::TensorProto& add_three = weights[node->input(1)];
-      if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
-
-      float constant_add_three = get_node_attr_from_input<float>(add_three);
-      if (constant_add_three != 3.f) continue;
-
-      onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
-      onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
-
-      if (node3->op_type() == "Constant") {
-        if (i + 3 >= node_count) continue;
-
-        node3 = mutable_graph->mutable_node(i + 3);
-      }
-
-      if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
-        continue;
-
-      if (node_reference[node2->output(0)] != 1) continue;
-
-      float relu6_min;
-      float relu6_max;
-      if (node2->input_size() == 1) {
-        relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
-        relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
-      } else {
-        const
onnx::TensorProto& min_tp = weights[node2->input(1)]; - const onnx::TensorProto& max_tp = weights[node2->input(2)]; - - relu6_min = get_node_attr_from_input(min_tp); - relu6_max = get_node_attr_from_input(max_tp); - } - if (relu6_min != 0.f || relu6_max != 6.f) continue; - - if (weights.find(node3->input(1)) == weights.end()) continue; - - const onnx::TensorProto& div_six = weights[node3->input(1)]; - if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; - - float constant_div_six = get_node_attr_from_input(div_six); - if (node3->op_type() == "Div" && constant_div_six != 6.f) continue; - if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - - node_reference[node->input(1)] -= 1; - node_reference[node->output(0)] -= 1; - if (node2->input_size() == 3) { - node_reference[node2->input(1)] -= 1; - node_reference[node2->input(2)] -= 1; - } - node_reference[node2->output(0)] -= 1; - node_reference[node3->input(1)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - - node3->set_op_type("HardSigmoid"); - node3->clear_input(); - node3->add_input(node->input(0)); - - onnx::AttributeProto* attr_alpha = node3->add_attribute(); - attr_alpha->set_name("alpha"); - attr_alpha->set_f(1.f / 6.f); - - onnx::AttributeProto* attr_beta = node3->add_attribute(); - attr_beta->set_name("beta"); - attr_beta->set_f(3.f / 6.f); - - reduced_node_count += 2; - i += 2; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6) + // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6)) + // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6) + // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6)) + // out = F.relu6(x + 3, inplace=True) / 6 + if (node->op_type() == "Add") + { + if (node_reference[node->output(0)] != 1) continue; + + if (i + 2 >= node_count) continue; + + if (weights.find(node->input(1)) == weights.end()) continue; + + const onnx::TensorProto& add_three = weights[node->input(1)]; + if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; + + float constant_add_three = get_node_attr_from_input(add_three); + if (constant_add_three != 3.f) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + + if (node3->op_type() == "Constant") + { + if (i + 3 >= node_count) continue; + + node3 = mutable_graph->mutable_node(i + 3); + } + + if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul")) + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + float relu6_min; + float relu6_max; + if (node2->input_size() == 1) + { + relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX); + relu6_max = get_node_attr_f(*node2, "max", FLT_MAX); + } + else + { + const onnx::TensorProto& min_tp = weights[node2->input(1)]; + const onnx::TensorProto& max_tp = weights[node2->input(2)]; + + relu6_min = get_node_attr_from_input(min_tp); + relu6_max = get_node_attr_from_input(max_tp); + } + if (relu6_min != 0.f || relu6_max != 6.f) continue; + + if (weights.find(node3->input(1)) == weights.end()) continue; + + const onnx::TensorProto& div_six = weights[node3->input(1)]; + if 
(div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; + + float constant_div_six = get_node_attr_from_input(div_six); + if (node3->op_type() == "Div" && constant_div_six != 6.f) continue; + if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + + node_reference[node->input(1)] -= 1; + node_reference[node->output(0)] -= 1; + if (node2->input_size() == 3) + { + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + } + node_reference[node2->output(0)] -= 1; + node_reference[node3->input(1)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + + node3->set_op_type("HardSigmoid"); + node3->clear_input(); + node3->add_input(node->input(0)); + + onnx::AttributeProto* attr_alpha = node3->add_attribute(); + attr_alpha->set_name("alpha"); + attr_alpha->set_f(1.f / 6.f); + + onnx::AttributeProto* attr_beta = node3->add_attribute(); + attr_beta->set_name("beta"); + attr_beta->set_f(3.f / 6.f); + + reduced_node_count += 2; + i += 2; + } } - } } -void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); +void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // Swish <= Sigmoid - Mul - // x * torch.sigmoid(x) - if (node->op_type() == "Sigmoid") { - if (node_reference[node->output(0)] != 1) continue; + // Swish <= Sigmoid - Mul + // x * torch.sigmoid(x) + if (node->op_type() == "Sigmoid") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 1 >= node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "Mul") continue; + if (node2->op_type() != "Mul") continue; - if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; + if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); - node_reference[node->input(0)] -= 1; - node_reference[node->output(0)] -= 1; + node_reference[node->input(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node2->set_op_type("Swish"); - node2->clear_input(); - node2->add_input(node->input(0)); + node2->set_op_type("Swish"); + node2->clear_input(); + node2->add_input(node->input(0)); - reduced_node_count += 1; - i += 1; + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, +void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& 
reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze - if (node->op_type() == "Unsqueeze") { - if (node_reference[node->output(0)] != 1) continue; + // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze + if (node->op_type() == "Unsqueeze") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 2 >= node_count) continue; + if (i + 2 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze") continue; + if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze") continue; - if (node_reference[node2->output(0)] != 1) continue; + if (node_reference[node2->output(0)] != 1) continue; - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); - node2->set_input(0, node->input(0)); - node2->set_output(0, node3->output(0)); + node2->set_input(0, node->input(0)); + node2->set_output(0, node3->output(0)); - reduced_node_count += 2; - i += 2; + reduced_node_count += 2; + i += 2; + } } - } } -void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, +void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // PReLU <= Unsqueeze - PReLU - if (node->op_type() == "Unsqueeze") { - // check weight - if (weights.find(node->input(0)) == weights.end()) continue; + // PReLU <= Unsqueeze - PReLU + if (node->op_type() == "Unsqueeze") + { + // check weight + if (weights.find(node->input(0)) == weights.end()) continue; - onnx::TensorProto& B = weights[node->input(0)]; - if (B.dims_size() != 1) continue; + onnx::TensorProto& B = weights[node->input(0)]; + if (B.dims_size() != 1) continue; - if (node_reference[node->output(0)] != 1) continue; + if (node_reference[node->output(0)] != 1) continue; - // axes = (1, 2) - std::vector axes = get_node_attr_ai(*node, "axes"); - if (axes.size() != 2) continue; - if (axes[0] != 1 || axes[1] != 2) continue; + // axes = (1, 2) + std::vector axes = get_node_attr_ai(*node, "axes"); + if (axes.size() != 2) continue; + if (axes[0] != 1 || axes[1] != 2) continue; - if (i + 1 >= 
node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "PRelu") continue; + if (node2->op_type() != "PRelu") continue; - if (node2->input(1) != node->output(0)) continue; + if (node2->input(1) != node->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node2->set_input(1, node->input(0)); + node2->set_input(1, node->input(0)); - reduced_node_count += 1; - i += 1; + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_normalize(onnx::GraphProto* mutable_graph, +void fuse_normalize(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // Normalize <= X - ReduceL2 - Clip - Expand - Div - // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div - if (node->op_type() == "ReduceL2") { - if (node_reference[node->output(0)] != 1) continue; - - // axes = (1) - std::vector axes = get_node_attr_ai(*node, "axes"); - if (axes.size() != 1) continue; - if (axes[0] != 1) continue; - - if (i + 3 >= node_count) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - - bool has_shape_node = node3->op_type() == "Shape"; - onnx::NodeProto* node_shape = 0; - if (has_shape_node) { - if (i + 4 >= node_count) continue; - - node_shape = node3; - node3 = mutable_graph->mutable_node(i + 3); - node4 = mutable_graph->mutable_node(i + 4); - } - - if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div") - continue; - - if (node_reference[node2->output(0)] != 1) continue; - - if (node_reference[node3->output(0)] != 1) continue; - - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || - node4->input(0) != node->input(0) || node4->input(1) != node3->output(0)) - continue; - - if (has_shape_node) { - if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0)) - continue; - } - - // +eps - float clip_min; - if (node2->input_size() == 1) { - clip_min = get_node_attr_f(*node2, "min", -FLT_MAX); - } else { - const onnx::TensorProto& min_tp = weights[node2->input(1)]; - - clip_min = get_node_attr_from_input(min_tp); - } - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - if (has_shape_node) { - node_shape->set_op_type("noop_reducedncnn"); - } - node3->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= has_shape_node ? 
2 : 1; - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; - if (has_shape_node) { - node_reference[node_shape->output(0)] -= 1; - } - node_reference[node3->output(0)] -= 1; - if (node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - if (has_shape_node) { - blob_names.erase(node_shape->output(0)); - } - blob_names.erase(node3->output(0)); - - node4->set_op_type("Normalize"); - node4->clear_input(); - node4->add_input(node->input(0)); - - onnx::AttributeProto* attr_alpha = node4->add_attribute(); - attr_alpha->set_name("eps"); - attr_alpha->set_f(clip_min); - - reduced_node_count += has_shape_node ? 4 : 3; - i += has_shape_node ? 4 : 3; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // Normalize <= X - ReduceL2 - Clip - Expand - Div + // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div + if (node->op_type() == "ReduceL2") + { + if (node_reference[node->output(0)] != 1) continue; + + // axes = (1) + std::vector axes = get_node_attr_ai(*node, "axes"); + if (axes.size() != 1) continue; + if (axes[0] != 1) continue; + + if (i + 3 >= node_count) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + + bool has_shape_node = node3->op_type() == "Shape"; + onnx::NodeProto* node_shape = 0; + if (has_shape_node) + { + if (i + 4 >= node_count) continue; + + node_shape = node3; + node3 = mutable_graph->mutable_node(i + 3); + node4 = mutable_graph->mutable_node(i + 4); + } + + if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div") + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + if (node_reference[node3->output(0)] != 1) continue; + + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || + node4->input(0) != node->input(0) || node4->input(1) != node3->output(0)) + continue; + + if (has_shape_node) + { + if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0)) + continue; + } + + // +eps + float clip_min; + if (node2->input_size() == 1) + { + clip_min = get_node_attr_f(*node2, "min", -FLT_MAX); + } + else + { + const onnx::TensorProto& min_tp = weights[node2->input(1)]; + + clip_min = get_node_attr_from_input(min_tp); + } + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + if (has_shape_node) + { + node_shape->set_op_type("noop_reducedncnn"); + } + node3->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= has_shape_node ? 
2 : 1; + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (has_shape_node) + { + node_reference[node_shape->output(0)] -= 1; + } + node_reference[node3->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + if (has_shape_node) + { + blob_names.erase(node_shape->output(0)); + } + blob_names.erase(node3->output(0)); + + node4->set_op_type("Normalize"); + node4->clear_input(); + node4->add_input(node->input(0)); + + onnx::AttributeProto* attr_alpha = node4->add_attribute(); + attr_alpha->set_name("eps"); + attr_alpha->set_f(clip_min); + + reduced_node_count += has_shape_node ? 4 : 3; + i += has_shape_node ? 4 : 3; + } } - } } -void fuse_groupnorm(onnx::GraphProto* mutable_graph, +void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add - if (node->op_type() == "Reshape") { - if (node_reference[node->output(0)] != 1) continue; - - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else { - // skip weight reshape - if (weights.find(node->input(1)) == weights.end()) continue; - - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } - - // 0, group, -1 - if (shape.size() != 3) continue; - - if (shape[0] != 0 || shape[2] != -1) continue; - - int groups = shape[1]; - - if (i + 4 >= node_count) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); - - if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || - node4->op_type() != "Mul" || node5->op_type() != "Add") - continue; - - if (node_reference[node2->output(0)] != 1) continue; - - if (node_reference[node3->output(0)] != 1) continue; - - if (node_reference[node4->output(0)] != 1) continue; - - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || - node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0)) - continue; - - // +eps - float eps = get_node_attr_f(*node2, "epsilon", 1e-05f); - - // InstanceNormalization S=1 B=0 - std::vector S = get_node_attr_from_input_af(weights[node2->input(1)]); - std::vector B = get_node_attr_from_input_af(weights[node2->input(2)]); - if ((int)S.size() != groups || (int)B.size() != groups) continue; - - bool instancenorm_affine = false; - for (int j = 0; j < groups; j++) { - if (S[j] != 1.f || B[j] != 0.f) { - instancenorm_affine = true; - break; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add + if (node->op_type() == "Reshape") + { + if (node_reference[node->output(0)] != 1) continue; + + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else + { + // skip weight reshape + if 
(weights.find(node->input(1)) == weights.end()) continue; + + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } + + // 0, group, -1 + if (shape.size() != 3) continue; + + if (shape[0] != 0 || shape[2] != -1) continue; + + int groups = shape[1]; + + if (i + 4 >= node_count) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); + + if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || + node4->op_type() != "Mul" || node5->op_type() != "Add") + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + if (node_reference[node3->output(0)] != 1) continue; + + if (node_reference[node4->output(0)] != 1) continue; + + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || + node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0)) + continue; + + // +eps + float eps = get_node_attr_f(*node2, "epsilon", 1e-05f); + + // InstanceNormalization S=1 B=0 + std::vector S = get_node_attr_from_input_af(weights[node2->input(1)]); + std::vector B = get_node_attr_from_input_af(weights[node2->input(2)]); + if ((int)S.size() != groups || (int)B.size() != groups) continue; + + bool instancenorm_affine = false; + for (int j = 0; j < groups; j++) + { + if (S[j] != 1.f || B[j] != 0.f) + { + instancenorm_affine = true; + break; + } + } + + if (instancenorm_affine) continue; + + std::vector shape2; + if (node3->input_size() == 1) + { + shape2 = get_node_attr_ai(*node3, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node3->input(1)) == weights.end()) continue; + + shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]); + } + + // 1, channels, w, h + if (shape2.size() != 4) continue; + + if (shape2[0] != 1) continue; + + int channels = shape2[1]; + + // affine + std::vector affine_S = get_node_attr_from_input_af(weights[node4->input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node5->input(1)]); + if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && + affine_B[0] == 0.f) + { + // no affine + } + else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels) + { + // we only allow per-channel affine + continue; + } + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); + node4->set_op_type("noop_reducedncnn"); + + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + node_reference[node3->output(0)] -= 1; + node_reference[node4->output(0)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + blob_names.erase(node3->output(0)); + blob_names.erase(node4->output(0)); + + std::string affine_scale = node4->input(1); + std::string affine_bias = node5->input(1); + + node5->set_op_type("GroupNorm"); + node5->clear_input(); + node5->add_input(node->input(0)); + node5->add_input(affine_scale); + node5->add_input(affine_bias); + + onnx::AttributeProto* attr_groups = node5->add_attribute(); + attr_groups->set_name("groups"); + attr_groups->set_i(groups); + + 
onnx::AttributeProto* attr_channels = node5->add_attribute(); + attr_channels->set_name("channels"); + attr_channels->set_i(channels); + + onnx::AttributeProto* attr_eps = node5->add_attribute(); + attr_eps->set_name("epsilon"); + attr_eps->set_f(eps); + + onnx::AttributeProto* attr_affine = node5->add_attribute(); + attr_affine->set_name("affine"); + attr_affine->set_i(1); + + reduced_node_count += 4; + i += 4; } - } - - if (instancenorm_affine) continue; - - std::vector shape2; - if (node3->input_size() == 1) { - shape2 = get_node_attr_ai(*node3, "shape"); - } else { - // skip weight reshape - if (weights.find(node3->input(1)) == weights.end()) continue; - - shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]); - } - - // 1, channels, w, h - if (shape2.size() != 4) continue; - - if (shape2[0] != 1) continue; - - int channels = shape2[1]; - - // affine - std::vector affine_S = get_node_attr_from_input_af(weights[node4->input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node5->input(1)]); - if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && - affine_B[0] == 0.f) { - // no affine - } else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels) { - // we only allow per-channel affine - continue; - } - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); - node4->set_op_type("noop_reducedncnn"); - - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - node_reference[node->output(0)] -= 1; - node_reference[node2->input(1)] -= 1; - node_reference[node2->input(2)] -= 1; - node_reference[node2->output(0)] -= 1; - if (node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } - node_reference[node3->output(0)] -= 1; - node_reference[node4->output(0)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - blob_names.erase(node4->output(0)); - - std::string affine_scale = node4->input(1); - std::string affine_bias = node5->input(1); - - node5->set_op_type("GroupNorm"); - node5->clear_input(); - node5->add_input(node->input(0)); - node5->add_input(affine_scale); - node5->add_input(affine_bias); - - onnx::AttributeProto* attr_groups = node5->add_attribute(); - attr_groups->set_name("groups"); - attr_groups->set_i(groups); - - onnx::AttributeProto* attr_channels = node5->add_attribute(); - attr_channels->set_name("channels"); - attr_channels->set_i(channels); - - onnx::AttributeProto* attr_eps = node5->add_attribute(); - attr_eps->set_name("epsilon"); - attr_eps->set_f(eps); - - onnx::AttributeProto* attr_affine = node5->add_attribute(); - attr_affine->set_name("affine"); - attr_affine->set_i(1); - - reduced_node_count += 4; - i += 4; } - } } -void fuse_layernorm(onnx::GraphProto* mutable_graph, +void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - - // Mul - Add - if (node->op_type() == "ReduceMean") { - if (node_reference[node->output(0)] != 1) continue; - - std::vector axes = get_node_attr_ai(*node, "axes"); - - // -1 - // -2 -1 - if (axes.size() != 
                    std::map<std::string, int>& node_reference,
                    std::set<std::string>& blob_names,
                    int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div -
        //              Mul - Add
        if (node->op_type() == "ReduceMean")
        {
            if (node_reference[node->output(0)] != 1) continue;

            std::vector<int> axes = get_node_attr_ai(*node, "axes");

            // -1
            // -2 -1
            if (axes.size() != 1 && axes.size() != 2) continue;

            int normed_axes = (int)axes.size();
            if (normed_axes == 1 && axes[0] != -1) continue;
            if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1)) continue;

            if (i + 6 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Sub" || node3->op_type() != "Pow" ||
                node4->op_type() != "ReduceMean" || node5->op_type() != "Add" ||
                node6->op_type() != "Sqrt" || node7->op_type() != "Div")
                continue;

            if (node_reference[node2->output(0)] != 2) continue;
            if (node_reference[node3->output(0)] != 1) continue;
            if (node_reference[node4->output(0)] != 1) continue;
            if (node_reference[node5->output(0)] != 1) continue;
            if (node_reference[node6->output(0)] != 1) continue;

            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0) ||
                node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) ||
                node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) ||
                node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
                continue;

            if (weights.find(node3->input(1)) == weights.end()) continue;

            const onnx::TensorProto& pow_two = weights[node3->input(1)];
            if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue;

            float constant_pow_two = get_node_attr_from_input<float>(pow_two);
            if (constant_pow_two != 2.f) continue;

            std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");

            // -1
            // -2 -1
            if ((int)axes4.size() != normed_axes) continue;

            if (normed_axes == 1 && axes4[0] != -1) continue;
            if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1)) continue;

            if (weights.find(node5->input(1)) == weights.end()) continue;

            const onnx::TensorProto& add_eps = weights[node5->input(1)];
            if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue;

            float eps = get_node_attr_from_input<float>(add_eps);

            int affine = 0;
            while (i + 8 < node_count)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                if (node8->op_type() != "Mul" || node9->op_type() != "Add") break;

                if (node_reference[node7->output(0)] != 1) break;
                if (node_reference[node8->output(0)] != 1) break;

                if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0)) break;

                // affine
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
                if (affine_S.size() != affine_B.size()) break;

                affine = 1;
                break;
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node2->input(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node3->input(1)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->input(1)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            node_reference[node->input(0)] += 1;

            if (affine == 0)
            {
                node7->set_op_type("LayerNorm");
                node7->clear_input();
                node7->add_input(node->input(0));

                onnx::AttributeProto* attr_eps = node7->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node7->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 6;
                i += 6;
            }
            else // if (affine == 1)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                node7->set_op_type("noop_reducedncnn");
                node8->set_op_type("noop_reducedncnn");

                node_reference[node8->input(0)] -= 1;
                node_reference[node9->input(0)] -= 1;

                blob_names.erase(node7->output(0));
                blob_names.erase(node8->output(0));

                std::string affine_scale = node8->input(1);
                std::string affine_bias = node9->input(1);

                node9->set_op_type("LayerNorm");
                node9->clear_input();
                node9->add_input(node->input(0));
                node9->add_input(affine_scale);
                node9->add_input(affine_bias);

                onnx::AttributeProto* attr_eps = node9->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node9->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 8;
                i += 8;
            }
        }
    }
}
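The subgraph matched above is textbook layer normalization: the first ReduceMean computes the mean, the Sub - Pow(2) - ReduceMean chain computes the variance, Add(eps) - Sqrt - Div normalizes, and the optional Mul - Add tail is the affine transform. A minimal scalar sketch of what the fused LayerNorm node must reproduce, assuming normalization over the last axis only (function name and layout are illustrative, not ncnn API):

    #include <cmath>
    #include <vector>

    // y = (x - mean) / sqrt(var + eps) [* gamma + beta]
    std::vector<float> layernorm_ref(const std::vector<float>& x, float eps,
                                     const float* gamma = nullptr, const float* beta = nullptr)
    {
        float mean = 0.f;
        for (float v : x) mean += v;
        mean /= x.size(); // first ReduceMean

        float var = 0.f;
        for (float v : x) var += (v - mean) * (v - mean);
        var /= x.size(); // Sub - Pow(2) - second ReduceMean

        std::vector<float> y(x.size());
        for (size_t j = 0; j < x.size(); j++)
        {
            float v = (x[j] - mean) / std::sqrt(var + eps); // Add(eps) - Sqrt - Div
            if (gamma && beta) v = v * gamma[j] + beta[j];  // optional affine Mul - Add
            y[j] = v;
        }
        return y;
    }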
void fuse_flatten(onnx::GraphProto* mutable_graph,
                  std::map<std::string, onnx::TensorProto>& weights,
                  std::map<std::string, int>& node_reference,
                  std::set<std::string>& blob_names,
                  int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat
        //            - Reshape
        if (node->op_type() == "Shape")
        {
            if (node_reference[node->output(0)] != 1) continue;

            if (i + 6 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Gather" || node3->op_type() != "Constant" ||
                node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze" ||
                node6->op_type() != "Concat" || node7->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1) continue;

            // if (node_reference[node3->output(0)] != 1)
            //     continue;

            if (node_reference[node4->output(0)] != 1) continue;
            if (node_reference[node5->output(0)] != 1) continue;
            if (node_reference[node6->output(0)] != 1) continue;

            if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) ||
                node5->input(0) != node3->output(0) || node6->input(0) != node4->output(0) ||
                node6->input(1) != node5->output(0) || node7->input(0) != node->input(0) ||
                node7->input(1) != node6->output(0))
                continue;

            // axis = 0
            int gather_axis = get_node_attr_i(*node2, "axis");
            if (gather_axis != 0) continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end()) continue;

            std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather_indices.size() != 1 || gather_indices[0] != 0) continue;

            // axes = (0)
            std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
            if (unsqueeze_axes.size() != 1) continue;
            if (unsqueeze_axes[0] != 0) continue;

            // axes = (0)
            std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
            if (unsqueeze2_axes.size() != 1) continue;
            if (unsqueeze2_axes[0] != 0) continue;

            // data = -1
            if (weights.find(node5->input(0)) == weights.end()) continue;

            std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
            if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1) continue;

            // axis = 0
            int concat_axis = get_node_attr_i(*node6, "axis");
            if (concat_axis != 0) continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            // node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->output(0)] -= 1;
            // node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->output(0)] -= 1;
            node_reference[node6->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            // blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            node7->set_op_type("Flatten");
            node7->clear_input();
            node7->add_input(node->input(0));

            reduced_node_count += 5;
            i += 5;
        }
    }
}
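The seven matched nodes exist only to build the reshape target {N, -1} at runtime; once the constants are verified, the whole chain collapses into a single Flatten. A sketch of the shape arithmetic being folded away (values are hypothetical):

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<long> input_shape = {2, 3, 16, 16};  // hypothetical blob shape
        long batch = input_shape[0];                     // Shape - Gather(axis=0, indices=0)
        std::vector<long> target = {batch, -1};          // Unsqueeze x2 - Concat(axis=0)
        // Reshape(X, target) is exactly Flatten(X)
        printf("reshape target: (%ld, %ld)\n", target[0], target[1]);
        return 0;
    }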
void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names,
                       int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // PixelShuffle <= Reshape - Transpose - Reshape
        // PixelShuffle <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1) continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end()) continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, upscale_factor, upscale_factor, height, width
            if (shape.size() != 6) continue;

            if (shape[0] != 1 && shape[0] != -1) continue;
            if (shape[2] != shape[3]) continue;

            if (i + 2 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count) continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue;

            if (node_reference[node2->output(0)] != 1) continue;

            // 0 1 4 2 5 3
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6) continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 ||
                perm[5] != 3)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end()) continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, 3, height, width
            if (shape3.size() != 4) continue;

            if (shape3[0] != 1 && shape3[0] != -1) continue;

            if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] ||
                shape3[3] != shape[3] * shape[5])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            node3->set_op_type("PixelShuffle");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("scale_factor");
            attr_group->set_i(shape[2]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
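The Reshape - Transpose(0 1 4 2 5 3) - Reshape chain is the standard depth-to-space trick. A reference loop for the semantics the fused node must match, assuming a single batch and the channel ordering implied by the matched shapes (illustrative, not converter code):

    #include <vector>

    // PixelShuffle with upscale factor r: (C*r*r, H, W) -> (C, H*r, W*r)
    void pixel_shuffle_ref(const std::vector<float>& in, std::vector<float>& out,
                           int C, int H, int W, int r)
    {
        out.resize((size_t)C * H * r * W * r);
        for (int c = 0; c < C; c++)
            for (int y = 0; y < H * r; y++)
                for (int x = 0; x < W * r; x++)
                {
                    // input channel layout matches the matched Reshape:
                    // (C, r, r, H, W), then Transpose perm 0 1 4 2 5 3
                    int ic = c * r * r + (y % r) * r + (x % r);
                    out[((size_t)c * H * r + y) * W * r + x] =
                        in[((size_t)ic * H + y / r) * W + x / r];
                }
    }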
void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Reorg <= Reshape - Transpose - Reshape
        // Reorg <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1) continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end()) continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, out_height, block_size, out_width, block_size
            if (shape.size() != 6) continue;

            if (shape[0] != 1 && shape[0] != -1) continue;
            if (shape[3] != shape[5]) continue;

            if (i + 2 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count) continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue;

            if (node_reference[node2->output(0)] != 1) continue;

            // 0 1 3 5 2 4
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6) continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 ||
                perm[5] != 4)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end()) continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, out_channels, out_height, out_width
            if (shape3.size() != 4) continue;

            if (shape3[0] != 1 && shape3[0] != -1) continue;

            if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] ||
                shape3[3] != shape[4])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            node3->set_op_type("Reorg");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("stride");
            attr_group->set_i(shape[3]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
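Reorg is the inverse direction: space-to-depth with stride s, matched here via the Transpose perm 0 1 3 5 2 4. A reference loop for the fused semantics, under the same single-batch assumption as the PixelShuffle sketch above (illustrative):

    #include <vector>

    // Reorg with stride s: (C, H, W) -> (C*s*s, H/s, W/s)
    void reorg_ref(const std::vector<float>& in, std::vector<float>& out,
                   int C, int H, int W, int s)
    {
        out.resize((size_t)C * H * W);
        int outH = H / s, outW = W / s;
        for (int c = 0; c < C; c++)
            for (int sy = 0; sy < s; sy++)
                for (int sx = 0; sx < s; sx++)
                    for (int h = 0; h < outH; h++)
                        for (int w = 0; w < outW; w++)
                        {
                            int oc = (c * s + sy) * s + sx; // Transpose perm 0 1 3 5 2 4
                            out[((size_t)oc * outH + h) * outW + w] =
                                in[((size_t)c * H + h * s + sy) * W + w * s + sx];
                        }
    }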
void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
                           std::map<std::string, onnx::TensorProto>& weights,
                           std::map<std::string, int>& node_reference,
                           std::set<std::string>& blob_names,
                           int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
        if (node->op_type() == "Expand")
        {
            if (node_reference[node->output(0)] != 1) continue;

            if (i + 1 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" &&
                node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
                continue;

            if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0)) continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;
            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));

            if (node2->input(0) == node->output(0))
            {
                node2->set_input(0, node->input(0));
            }
            else
            {
                node2->set_input(1, node->input(0));
            }

            reduced_node_count += 1;
            i += 1;
        }
    }
}
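Dropping the Expand is safe because ONNX binary element-wise ops already apply numpy-style broadcasting, so feeding the un-expanded input straight into the following Add/Sub/Mul/Div/Min/Max produces the same result. A tiny illustration with hypothetical shapes:

    #include <cstdio>

    int main()
    {
        // Expand(x, {2,3}) + y  ==  x + y  when x is {1,3} and y is {2,3}
        float x[1][3] = {{1, 2, 3}};
        float y[2][3] = {{10, 20, 30}, {40, 50, 60}};
        float z[2][3];
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 3; j++)
                z[i][j] = x[0][j] + y[i][j]; // the row of x is reused: implicit Expand
        printf("%g %g\n", z[0][0], z[1][2]);
        return 0;
    }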
void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
                       std::map<std::string, int>& node_reference,
                       std::set<std::string>& blob_names,
                       int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
        // or LSTM(bi) <= LSTM(bi) - Transpose - Constant - Reshape - Transpose
        if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
        {
            if (node_reference[node->output(0)] != 1) continue;

            if (i + 2 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // skip if second op is a constant
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count) continue;
                node3 = mutable_graph->mutable_node(i + 3);
                i += 1;
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue;

            if (node_reference[node2->output(0)] != 1) continue;

            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue;

            std::string direction = get_node_attr_s(*node, "direction");
            if (direction != "bidirectional") continue;

            // 0 2 1 3
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 4) continue;

            if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3) continue;

            std::vector<int> shape;
            if (node3->input_size() == 1)
            {
                shape = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end()) continue;

                shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 0 0 -1
            if (shape.size() != 3) continue;

            if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1) continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            node->set_output(0, node3->output(0));

            reduced_node_count += 2;
            i += 2;

            if (i + 1 < node_count)
            {
                if (node_reference[node3->output(0)] != 1) continue;

                onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);

                if (node4->op_type() != "Transpose") continue;

                if (node4->input(0) != node->output(0)) continue;

                // 1 0 2
                std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
                if (perm4.size() != 3) continue;

                if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue;

                // reduce
                node4->set_op_type("noop_reducedncnn");

                node_reference[node->output(0)] -= 1;

                blob_names.erase(node->output(0));

                node->set_output(0, node4->output(0));

                reduced_node_count += 1;
                i += 1;
            }
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
        if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
        {
            if (node_reference[node->output(0)] != 1) continue;

            if (i + 1 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Squeeze") continue;

            if (node2->input(0) != node->output(0)) continue;

            std::string direction = get_node_attr_s(*node, "direction");
            if (direction == "bidirectional") continue;

            // 1
            std::vector<int> axes = get_node_attr_ai(*node2, "axes");
            if (axes.size() != 1) continue;

            if (axes[0] != 1) continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node->set_output(0, node2->output(0));

            reduced_node_count += 1;
            i += 1;

            if (i + 1 < node_count)
            {
                if (node_reference[node2->output(0)] != 1) continue;

                onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);

                if (node3->op_type() != "Transpose") continue;

                if (node3->input(0) != node->output(0)) continue;

                // 1 0 2
                std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
                if (perm4.size() != 3) continue;

                if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue;

                // reduce
                node3->set_op_type("noop_reducedncnn");

                node_reference[node->output(0)] -= 1;

                blob_names.erase(node->output(0));

                node->set_output(0, node3->output(0));

                reduced_node_count += 1;
                i += 1;
            }
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LSTM <= Transpose - LSTM
        if (node->op_type() == "Transpose")
        {
            if (node_reference[node->output(0)] != 1) continue;

            // 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 3) continue;

            if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2) continue;

            if (i + 1 >= node_count) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "LSTM" && node2->op_type() != "GRU" && node2->op_type() != "RNN")
                continue;

            if (node2->input(0) != node->output(0)) continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_input(0, node->input(0));

            reduced_node_count += 1;
            i += 1;
        }
    }
}
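All three passes above absorb layout shuffles around recurrent layers. ONNX emits the recurrent output Y as [seq, num_directions, batch, hidden]; the erased Transpose - Reshape (bidirectional) or Squeeze (unidirectional) nodes convert that to the [seq, batch, dir * hidden] layout the fused layer then produces directly. A sketch of the shape bookkeeping (sizes are hypothetical):

    #include <array>
    #include <cstdio>

    int main()
    {
        std::array<int, 4> Y = {/*seq*/ 8, /*dir*/ 2, /*batch*/ 1, /*hidden*/ 64};

        // bidirectional: Transpose perm 0 2 1 3, then Reshape {0, 0, -1}
        std::array<int, 3> fused = {Y[0], Y[2], Y[1] * Y[3]}; // {8, 1, 128}

        // unidirectional: Squeeze(axes={1}) just removes the dir=1 axis instead
        printf("%d %d %d\n", fused[0], fused[1], fused[2]);
        return 0;
    }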
void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
                             std::map<std::string, onnx::TensorProto>& weights,
                             std::map<std::string, int>& node_reference,
                             std::set<std::string>& blob_names,
                             int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // MultiHeadAttention <= MatMul(q) - Add
        //                       - MatMul(k) - Add
        //                       - MatMul(v) - Add
        //                       - Mul
        //                       - Reshape - Transpose
        //                       - Reshape - Reshape - Transpose - Transpose
        //                       - Gemm - Softmax - Gemm - Transpose - Reshape -
        //                       MatMul - Add
        if (node->op_type() == "MatMul")
        {
            if (i + 19 >= node_count) continue;

            if (node_reference[node->output(0)] != 1) continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
            onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
            onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
            onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
            onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
            onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
            onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
            onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
            onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
            onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
            onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
            onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
            onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
            onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);

            if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" ||
                node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" ||
                node8->op_type() != "Reshape" || node9->op_type() != "Transpose" ||
                node10->op_type() != "Reshape" || node11->op_type() != "Reshape" ||
                node12->op_type() != "Transpose" || node13->op_type() != "Transpose" ||
                node14->op_type() != "MatMul" || node15->op_type() != "Softmax" ||
                node16->op_type() != "MatMul" || node17->op_type() != "Transpose" ||
                node18->op_type() != "Reshape" || node19->op_type() != "MatMul" ||
                node20->op_type() != "Add")
                continue;

            if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 ||
                node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 ||
                node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 ||
                node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 ||
                node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 ||
                node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 ||
                node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 ||
                node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 ||
                node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
                continue;

            if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) ||
                node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) ||
                node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) ||
                node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) ||
                node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) ||
                node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) ||
                node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) ||
                node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) ||
                node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) ||
                node20->input(0) != node19->output(0))
                continue;

            std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
            std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
            std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
            std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);

            if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
                continue;

            int embed_dim = q_B.size();

            // 1 0 2
            std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
            std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
            if (perm9.size() != 3 || perm12.size() != 3) continue;

            if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 ||
                perm12[2] != 2)
                continue;

            // 1 2 0
            std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
            if (perm13.size() != 3) continue;

            if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0) continue;

            // 1 0 2
            std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
            if (perm17.size() != 3) continue;

            if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2) continue;

            int softmax_axis = get_node_attr_i(*node15, "axis");
            if (softmax_axis != 2) continue;

            // 1/-1, seqlen * num_heads, embed_dim / num_heads
            std::vector<int> shape8;
            std::vector<int> shape10;
            std::vector<int> shape11;
            if (node8->input_size() == 1)
            {
                shape8 = get_node_attr_ai(*node8, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node8->input(1)) == weights.end()) continue;

                shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
            }
            if (node10->input_size() == 1)
            {
                shape10 = get_node_attr_ai(*node10, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node10->input(1)) == weights.end()) continue;

                shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
            }
            if (node11->input_size() == 1)
            {
                shape11 = get_node_attr_ai(*node11, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node11->input(1)) == weights.end()) continue;

                shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
            }

            if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3) continue;

            if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] ||
                shape8[2] != shape11[2])
                continue;

            int num_heads = embed_dim / shape8[2];

            // 1, seqlen, embed_dim
            std::vector<int> shape18;
            if (node18->input_size() == 1)
            {
                shape18 = get_node_attr_ai(*node18, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node18->input(1)) == weights.end()) continue;

                shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
            }

            if (shape18.size() != 3) continue;

            if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1]) continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");
            node7->set_op_type("noop_reducedncnn");
            node8->set_op_type("noop_reducedncnn");
            node9->set_op_type("noop_reducedncnn");
            node10->set_op_type("noop_reducedncnn");
            node11->set_op_type("noop_reducedncnn");
            node12->set_op_type("noop_reducedncnn");
            node13->set_op_type("noop_reducedncnn");
            node14->set_op_type("noop_reducedncnn");
            node15->set_op_type("noop_reducedncnn");
            node16->set_op_type("noop_reducedncnn");
            node17->set_op_type("noop_reducedncnn");
            node18->set_op_type("noop_reducedncnn");
            node19->set_op_type("noop_reducedncnn");

            node_reference[node2->input(0)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;
            node_reference[node8->input(0)] -= 1;
            if (node8->input_size() == 2)
            {
                node_reference[node8->input(1)] -= 1;
            }
            node_reference[node9->input(0)] -= 1;
            node_reference[node10->input(0)] -= 1;
            if (node10->input_size() == 2)
            {
                node_reference[node10->input(1)] -= 1;
            }
            node_reference[node11->input(0)] -= 1;
            if (node11->input_size() == 2)
            {
                node_reference[node11->input(1)] -= 1;
            }
            node_reference[node12->input(0)] -= 1;
            node_reference[node13->input(0)] -= 1;
            node_reference[node14->input(0)] -= 1;
            node_reference[node14->input(1)] -= 1;
            node_reference[node15->input(0)] -= 1;
            node_reference[node16->input(0)] -= 1;
            node_reference[node16->input(1)] -= 1;
            node_reference[node17->input(0)] -= 1;
            node_reference[node18->input(0)] -= 1;
            if (node18->input_size() == 2)
            {
                node_reference[node18->input(1)] -= 1;
            }
            node_reference[node19->input(0)] -= 1;
            node_reference[node20->input(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));
            blob_names.erase(node7->output(0));
            blob_names.erase(node8->output(0));
            blob_names.erase(node9->output(0));
            blob_names.erase(node10->output(0));
            blob_names.erase(node11->output(0));
            blob_names.erase(node12->output(0));
            blob_names.erase(node13->output(0));
            blob_names.erase(node14->output(0));
            blob_names.erase(node15->output(0));
            blob_names.erase(node16->output(0));
            blob_names.erase(node17->output(0));
            blob_names.erase(node18->output(0));
            blob_names.erase(node19->output(0));

            std::string qw = node->input(1);
            std::string qb = node2->input(1);
            std::string kw = node3->input(1);
            std::string kb = node4->input(1);
            std::string vw = node5->input(1);
            std::string vb = node6->input(1);
            std::string ow = node19->input(1);
            std::string ob = node20->input(1);

            node20->set_op_type("MultiHeadAttention");
            node20->clear_input();
            node20->add_input(node->input(0));
            node20->add_input(node3->input(0));
            node20->add_input(node5->input(0));
            // q
            node20->add_input(qw);
            node20->add_input(qb);
            // k
            node20->add_input(kw);
            node20->add_input(kb);
            // v
            node20->add_input(vw);
            node20->add_input(vb);
            // out linear
            node20->add_input(ow);
            node20->add_input(ob);

            onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
            attr_embed_dim->set_name("embed_dim");
            attr_embed_dim->set_i(embed_dim);

            onnx::AttributeProto* attr_num_heads = node20->add_attribute();
            attr_num_heads->set_name("num_heads");
            attr_num_heads->set_i(num_heads);

            reduced_node_count += 19;
            i += 19;
        }
    }
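    // Second variant (summary): some exporters pack the q/k/v projections into a
    // single MatMul - Add followed by a three-way Split, so q, k and v share one
    // weight and the packed bias must satisfy qkv_B.size() == 3 * o_B.size().
    // Apart from that, the matching below mirrors the loop above, with num_heads
    // again derived from the Reshape target as embed_dim / per-head size.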
node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
-        continue;
-
-      std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
-      std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);
-
-      if (qkv_B.size() != o_B.size() * 3) continue;
-
-      int embed_dim = o_B.size();
-
-      // 1 0 2
-      std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
-      std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
-      if (perm6.size() != 3 || perm9.size() != 3) continue;
-
-      if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 ||
-          perm9[2] != 2)
-        continue;
-
-      // 1 2 0
-      std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
-      if (perm10.size() != 3) continue;
-
-      if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0) continue;
-
-      // 1 0 2
-      std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
-      if (perm14.size() != 3) continue;
-
-      if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2) continue;
-
-      int softmax_axis = get_node_attr_i(*node12, "axis");
-      if (softmax_axis != 2) continue;
-
-      // 1/-1, seqlen * num_heads, embed_dim / num_heads
-      std::vector<int> shape5;
-      std::vector<int> shape7;
-      std::vector<int> shape8;
-      if (node5->input_size() == 1) {
-        shape5 = get_node_attr_ai(*node5, "shape");
-      } else {
-        // skip weight reshape
-        if (weights.find(node5->input(1)) == weights.end()) continue;
-
-        shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
-      }
-      if (node7->input_size() == 1) {
-        shape7 = get_node_attr_ai(*node7, "shape");
-      } else {
-        // skip weight reshape
-        if (weights.find(node7->input(1)) == weights.end()) continue;
-
-        shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
-      }
-      if (node8->input_size() == 1) {
-        shape8 = get_node_attr_ai(*node8, "shape");
-      } else {
-        // skip weight reshape
-        if (weights.find(node8->input(1)) == weights.end()) continue;
-
-        shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
-      }
-
-      if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3) continue;
-
-      if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] ||
-          shape5[2] != shape8[2])
-        continue;
-
-      int num_heads = embed_dim / shape5[2];
-
-      // 1, seqlen, embed_dim
-      std::vector<int> shape15;
-      if (node15->input_size() == 1) {
-        shape15 = get_node_attr_ai(*node15, "shape");
-      } else {
-        // skip weight reshape
-        if (weights.find(node15->input(1)) == weights.end()) continue;
-
-        shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
-      }
-
-      if (shape15.size() != 3) continue;
-
-      if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1]) continue;
-
-      // reduce
-      node->set_op_type("noop_reducedncnn");
-      node2->set_op_type("noop_reducedncnn");
-      node3->set_op_type("noop_reducedncnn");
-      node4->set_op_type("noop_reducedncnn");
-      node5->set_op_type("noop_reducedncnn");
-      node6->set_op_type("noop_reducedncnn");
-      node7->set_op_type("noop_reducedncnn");
-      node8->set_op_type("noop_reducedncnn");
-      node9->set_op_type("noop_reducedncnn");
-      node10->set_op_type("noop_reducedncnn");
-      node11->set_op_type("noop_reducedncnn");
-      node12->set_op_type("noop_reducedncnn");
-      node13->set_op_type("noop_reducedncnn");
-      node14->set_op_type("noop_reducedncnn");
-      node15->set_op_type("noop_reducedncnn");
-      node16->set_op_type("noop_reducedncnn");
-
-      node_reference[node2->input(0)] -= 1;
-      node_reference[node3->input(0)] -= 1;
-      node_reference[node4->input(0)] -= 1;
-      node_reference[node4->input(1)] -= 1;
-      node_reference[node5->input(0)] -= 1;
-      if (node5->input_size() 
== 2) { - node_reference[node5->input(1)] -= 1; - } - node_reference[node6->input(0)] -= 1; - node_reference[node7->input(0)] -= 1; - if (node7->input_size() == 2) { - node_reference[node7->input(1)] -= 1; - } - node_reference[node8->input(0)] -= 1; - if (node8->input_size() == 2) { - node_reference[node8->input(1)] -= 1; - } - node_reference[node9->input(0)] -= 1; - node_reference[node10->input(0)] -= 1; - node_reference[node11->input(0)] -= 1; - node_reference[node11->input(1)] -= 1; - node_reference[node12->input(0)] -= 1; - node_reference[node13->input(0)] -= 1; - node_reference[node13->input(1)] -= 1; - node_reference[node14->input(0)] -= 1; - node_reference[node15->input(0)] -= 1; - if (node15->input_size() == 2) { - node_reference[node15->input(1)] -= 1; - } - node_reference[node16->input(0)] -= 1; - node_reference[node17->input(0)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - blob_names.erase(node3->output(1)); - blob_names.erase(node3->output(2)); - blob_names.erase(node4->output(0)); - blob_names.erase(node5->output(0)); - blob_names.erase(node6->output(0)); - blob_names.erase(node7->output(0)); - blob_names.erase(node8->output(0)); - blob_names.erase(node9->output(0)); - blob_names.erase(node10->output(0)); - blob_names.erase(node11->output(0)); - blob_names.erase(node12->output(0)); - blob_names.erase(node13->output(0)); - blob_names.erase(node14->output(0)); - blob_names.erase(node15->output(0)); - blob_names.erase(node16->output(0)); - - std::string qkvw = node->input(1); - std::string qkvb = node2->input(1); - std::string ow = node16->input(1); - std::string ob = node17->input(1); - - node17->set_op_type("MultiHeadAttention"); - node17->clear_input(); - node17->add_input(node->input(0)); - // qkv - node17->add_input(qkvw); - node17->add_input(qkvb); - // out linear - node17->add_input(ow); - node17->add_input(ob); - - onnx::AttributeProto* attr_embed_dim = node17->add_attribute(); - attr_embed_dim->set_name("embed_dim"); - attr_embed_dim->set_i(embed_dim); - - onnx::AttributeProto* attr_num_heads = node17->add_attribute(); - attr_num_heads->set_name("num_heads"); - attr_num_heads->set_i(num_heads); - - reduced_node_count += 16; - i += 16; + + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // MultiHeadAttention <= MatMul(qkv) - Add - Split + // - Mul + // - Reshape - Transpose + // - Reshape - Reshape - Transpose - Transpose + // - Gemm - Softmax - Gemm - Transpose - Reshape - + // MatMul - Add + if (node->op_type() == "MatMul") + { + if (i + 16 >= node_count) continue; + + if (node_reference[node->output(0)] != 1) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); + onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); + onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); + onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); + onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); + onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9); + onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10); + onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11); + onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12); + onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13); + 
onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
+            onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
+            onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
+
+            if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" ||
+                node5->op_type() != "Reshape" || node6->op_type() != "Transpose" ||
+                node7->op_type() != "Reshape" || node8->op_type() != "Reshape" ||
+                node9->op_type() != "Transpose" || node10->op_type() != "Transpose" ||
+                node11->op_type() != "MatMul" || node12->op_type() != "Softmax" ||
+                node13->op_type() != "MatMul" || node14->op_type() != "Transpose" ||
+                node15->op_type() != "Reshape" || node16->op_type() != "MatMul" ||
+                node17->op_type() != "Add")
+                continue;
+
+            if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 ||
+                node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 ||
+                node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 ||
+                node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 ||
+                node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 ||
+                node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 ||
+                node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 ||
+                node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 ||
+                node_reference[node16->output(0)] != 1)
+                continue;
+
+            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) ||
+                node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) ||
+                node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) ||
+                node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) ||
+                node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) ||
+                node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) ||
+                node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) ||
+                node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) ||
+                node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
+                continue;
+
+            std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
+            std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);
+
+            if (qkv_B.size() != o_B.size() * 3) continue;
+
+            int embed_dim = o_B.size();
+
+            // 1 0 2
+            std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
+            std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
+            if (perm6.size() != 3 || perm9.size() != 3) continue;
+
+            if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 ||
+                perm9[2] != 2)
+                continue;
+
+            // 1 2 0
+            std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
+            if (perm10.size() != 3) continue;
+
+            if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0) continue;
+
+            // 1 0 2
+            std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
+            if (perm14.size() != 3) continue;
+
+            if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2) continue;
+
+            int softmax_axis = get_node_attr_i(*node12, "axis");
+            if (softmax_axis != 2) continue;
+
+            // 1/-1, seqlen * num_heads, embed_dim / num_heads
+            std::vector<int> shape5;
+            std::vector<int> shape7;
+            std::vector<int> shape8;
+            if (node5->input_size() == 1)
+            {
+                shape5 = get_node_attr_ai(*node5, "shape");
+            }
+            else
+            {
+                // skip weight reshape
+                if (weights.find(node5->input(1)) == weights.end()) continue;
+
+                shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
+            }
+            if (node7->input_size() == 1)
+            {
+                shape7 = get_node_attr_ai(*node7, "shape");
+            }
+            else
+            {
+                // skip weight reshape
+                if (weights.find(node7->input(1)) == weights.end()) continue;
+
+                shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
+            }
+            if (node8->input_size() == 1)
+            {
+                shape8 = get_node_attr_ai(*node8, "shape");
+            }
+            else
+            {
+                // skip weight reshape
+                if (weights.find(node8->input(1)) == weights.end()) continue;
+
+                shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
+            }
+
+            if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3) continue;
+
+            if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] ||
+                shape5[2] != shape8[2])
+                continue;
+
+            int num_heads = embed_dim / shape5[2];
+
+            // 1, seqlen, embed_dim
+            std::vector<int> shape15;
+            if (node15->input_size() == 1)
+            {
+                shape15 = get_node_attr_ai(*node15, "shape");
+            }
+            else
+            {
+                // skip weight reshape
+                if (weights.find(node15->input(1)) == weights.end()) continue;
+
+                shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
+            }
+
+            if (shape15.size() != 3) continue;
+
+            if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1]) continue;
+
+            // reduce
+            node->set_op_type("noop_reducedncnn");
+            node2->set_op_type("noop_reducedncnn");
+            node3->set_op_type("noop_reducedncnn");
+            node4->set_op_type("noop_reducedncnn");
+            node5->set_op_type("noop_reducedncnn");
+            node6->set_op_type("noop_reducedncnn");
+            node7->set_op_type("noop_reducedncnn");
+            node8->set_op_type("noop_reducedncnn");
+            node9->set_op_type("noop_reducedncnn");
+            node10->set_op_type("noop_reducedncnn");
+            node11->set_op_type("noop_reducedncnn");
+            node12->set_op_type("noop_reducedncnn");
+            node13->set_op_type("noop_reducedncnn");
+            node14->set_op_type("noop_reducedncnn");
+            node15->set_op_type("noop_reducedncnn");
+            node16->set_op_type("noop_reducedncnn");
+
+            node_reference[node2->input(0)] -= 1;
+            node_reference[node3->input(0)] -= 1;
+            node_reference[node4->input(0)] -= 1;
+            node_reference[node4->input(1)] -= 1;
+            node_reference[node5->input(0)] -= 1;
+            if (node5->input_size() == 2)
+            {
+                node_reference[node5->input(1)] -= 1;
+            }
+            node_reference[node6->input(0)] -= 1;
+            node_reference[node7->input(0)] -= 1;
+            if (node7->input_size() == 2)
+            {
+                node_reference[node7->input(1)] -= 1;
+            }
+            node_reference[node8->input(0)] -= 1;
+            if (node8->input_size() == 2)
+            {
+                node_reference[node8->input(1)] -= 1;
+            }
+            node_reference[node9->input(0)] -= 1;
+            node_reference[node10->input(0)] -= 1;
+            node_reference[node11->input(0)] -= 1;
+            node_reference[node11->input(1)] -= 1;
+            node_reference[node12->input(0)] -= 1;
+            node_reference[node13->input(0)] -= 1;
+            node_reference[node13->input(1)] -= 1;
+            node_reference[node14->input(0)] -= 1;
+            node_reference[node15->input(0)] -= 1;
+            if (node15->input_size() == 2)
+            {
+                node_reference[node15->input(1)] -= 1;
+            }
+            node_reference[node16->input(0)] -= 1;
+            node_reference[node17->input(0)] -= 1;
+
+            blob_names.erase(node->output(0));
+            blob_names.erase(node2->output(0));
+            blob_names.erase(node3->output(0));
+            blob_names.erase(node3->output(1));
+            blob_names.erase(node3->output(2));
+            blob_names.erase(node4->output(0));
+            blob_names.erase(node5->output(0));
+            blob_names.erase(node6->output(0));
+            blob_names.erase(node7->output(0));
+            blob_names.erase(node8->output(0));
+            blob_names.erase(node9->output(0));
+            blob_names.erase(node10->output(0));
+            blob_names.erase(node11->output(0));
+            
blob_names.erase(node12->output(0)); + blob_names.erase(node13->output(0)); + blob_names.erase(node14->output(0)); + blob_names.erase(node15->output(0)); + blob_names.erase(node16->output(0)); + + std::string qkvw = node->input(1); + std::string qkvb = node2->input(1); + std::string ow = node16->input(1); + std::string ob = node17->input(1); + + node17->set_op_type("MultiHeadAttention"); + node17->clear_input(); + node17->add_input(node->input(0)); + // qkv + node17->add_input(qkvw); + node17->add_input(qkvb); + // out linear + node17->add_input(ow); + node17->add_input(ob); + + onnx::AttributeProto* attr_embed_dim = node17->add_attribute(); + attr_embed_dim->set_name("embed_dim"); + attr_embed_dim->set_i(embed_dim); + + onnx::AttributeProto* attr_num_heads = node17->add_attribute(); + attr_num_heads->set_name("num_heads"); + attr_num_heads->set_i(num_heads); + + reduced_node_count += 16; + i += 16; + } } - } } diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h index 31dc6f5b93..ec4575b51a 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h @@ -4,30 +4,35 @@ #include "shape_inference.h" #include "utils.h" -void fuse_identity(onnx::GraphProto* mutable_graph, +void fuse_identity(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); -void fuse_rewrite_gather(onnx::GraphProto* mutable_graph, +void fuse_rewrite_gather(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); -void fuse_weight_reshape(onnx::GraphProto* mutable_graph, +void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); -void fuse_shufflechannel(onnx::GraphProto* mutable_graph, +void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); -void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, +void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); /** * @brief fuse subgraph @@ -46,85 +51,104 @@ void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, * @param blob_names * @param reduced_node_count */ -void fuse_conv_reshape(onnx::GraphProto* mutable_graph, +void fuse_conv_reshape(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); -void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, +void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); -void fuse_hardswish(onnx::GraphProto* 
mutable_graph,
+void fuse_hardswish(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
+void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
-                      std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                      int& reduced_node_count);
+                      std::map<std::string, int>& node_reference,
+                      std::set<std::string>& blob_names,
+                      int& reduced_node_count);

-void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
+void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
                                         std::map<std::string, onnx::TensorProto>& weights,
-                                        std::map<std::string, int>& node_reference,
-                                        std::set<std::string>& blob_names, int& reduced_node_count);
+                                        std::map<std::string, int>& node_reference,
+                                        std::set<std::string>& blob_names,
+                                        int& reduced_node_count);

-void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
+void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
                           std::map<std::string, onnx::TensorProto>& weights,
-                          std::map<std::string, int>& node_reference,
-                          std::set<std::string>& blob_names, int& reduced_node_count);
+                          std::map<std::string, int>& node_reference,
+                          std::set<std::string>& blob_names,
+                          int& reduced_node_count);

-void fuse_normalize(onnx::GraphProto* mutable_graph,
+void fuse_normalize(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_groupnorm(onnx::GraphProto* mutable_graph,
+void fuse_groupnorm(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_layernorm(onnx::GraphProto* mutable_graph,
+void fuse_layernorm(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_flatten(onnx::GraphProto* mutable_graph,
+void fuse_flatten(onnx::GraphProto* mutable_graph,
                   std::map<std::string, onnx::TensorProto>& weights,
-                  std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                  int& reduced_node_count);
+                  std::map<std::string, int>& node_reference,
+                  std::set<std::string>& blob_names,
+                  int& reduced_node_count);

-void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
+void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
                        std::map<std::string, onnx::TensorProto>& weights,
-                       std::map<std::string, int>& node_reference,
-                       std::set<std::string>& blob_names, int& reduced_node_count);
+                       std::map<std::string, int>& node_reference,
+                       std::set<std::string>& blob_names,
+                       int& reduced_node_count);

-void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
-                std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                int& reduced_node_count);
+void fuse_reorg(onnx::GraphProto* mutable_graph,
+                std::map<std::string, onnx::TensorProto>& weights,
+                std::map<std::string, int>& node_reference,
+                std::set<std::string>& blob_names,
+                int& reduced_node_count);

-void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
+void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
                            std::map<std::string, onnx::TensorProto>& weights,
-                           std::map<std::string, int>& node_reference,
-                           std::set<std::string>& blob_names, int& reduced_node_count);
+                           std::map<std::string, int>& node_reference,
+                           std::set<std::string>& blob_names,
+                           int& reduced_node_count);

-void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
+void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
                        std::map<std::string, onnx::TensorProto>& weights,
-                       std::map<std::string, int>& node_reference,
-                       std::set<std::string>& blob_names, int& reduced_node_count);
+                       std::map<std::string, int>& node_reference,
+                       std::set<std::string>& blob_names,
+                       int& reduced_node_count);

-void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
+void 
fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); -void fuse_weight_transpose(onnx::GraphProto* mutable_graph, +void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count); - -void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); + +void fuse_swish(onnx::GraphProto* mutable_graph, + std::map& weights, + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count); diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp index ca8cd628ad..bc38599b63 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -26,2719 +26,3551 @@ #include "shape_inference.h" #include "utils.h" -int main(int argc, char** argv) { - if (!(argc == 2 || argc == 4)) { - fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]); - return -1; - } - - const char* onnxpb = argv[1]; - const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param"; - const char* ncnn_modelbin = argc == 4 ? argv[3] : "ncnn.bin"; - - onnx::ModelProto model; - - // load - bool s1 = read_proto_from_binary(onnxpb, &model); - if (!s1) { - fprintf(stderr, "read_proto_from_binary failed\n"); - return -1; - } - FILE* pp = fopen(ncnn_prototxt, "wb"); - FILE* bp = fopen(ncnn_modelbin, "wb"); - // magic - fprintf(pp, "7767517\n"); - onnx::GraphProto* mutable_graph = model.mutable_graph(); - int node_count = mutable_graph->node_size(); - - // node reference - std::map node_reference; - - // weight node and weight reshape node - std::map weights; - for (int j = 0; j < mutable_graph->initializer_size(); j++) { - const onnx::TensorProto& initializer = mutable_graph->initializer(j); - - // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), - // initializer.data_type()); - - weights[initializer.name()] = initializer; - } - // topological sort - { - // name -> producer node index - std::set producers; - for (int j = 0; j < mutable_graph->input_size(); j++) { - const std::string& input_name = mutable_graph->input(j).name(); - producers.insert(input_name); +int main(int argc, char** argv) +{ + if (!(argc == 2 || argc == 4)) + { + fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]); + return -1; } - for (int i = 0; i < node_count;) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + const char* onnxpb = argv[1]; + const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param"; + const char* ncnn_modelbin = argc == 4 ? 
argv[3] : "ncnn.bin"; - bool swapnode = false; - std::string missing_input_name; - for (int j = 0; j < (int)node->input_size(); j++) { - const std::string& input_name = node->input(j); - if (input_name.empty()) continue; + onnx::ModelProto model; - if (producers.find(input_name) == producers.end() && - weights.find(input_name) == weights.end()) { - swapnode = true; - missing_input_name = input_name; - break; - } - } + // load + bool s1 = read_proto_from_binary(onnxpb, &model); + if (!s1) + { + fprintf(stderr, "read_proto_from_binary failed\n"); + return -1; + } + FILE* pp = fopen(ncnn_prototxt, "wb"); + FILE* bp = fopen(ncnn_modelbin, "wb"); + // magic + fprintf(pp, "7767517\n"); + onnx::GraphProto* mutable_graph = model.mutable_graph(); + int node_count = mutable_graph->node_size(); + + // node reference + std::map node_reference; + + // weight node and weight reshape node + std::map weights; + for (int j = 0; j < mutable_graph->initializer_size(); j++) + { + const onnx::TensorProto& initializer = mutable_graph->initializer(j); - if (!swapnode) { - for (int j = 0; j < (int)node->output_size(); j++) { - const std::string& output_name = node->output(j); - if (output_name.empty()) continue; + // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), + // initializer.data_type()); - producers.insert(output_name); + weights[initializer.name()] = initializer; + } + // topological sort + { + // name -> producer node index + std::set producers; + for (int j = 0; j < mutable_graph->input_size(); j++) + { + const std::string& input_name = mutable_graph->input(j).name(); + producers.insert(input_name); } - i++; - continue; - } - - // find node that produce missing_input_name - int q = i + 1; - for (; q < node_count; q++) { - onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); - bool found = false; - for (int j = 0; j < (int)nodeq->output_size(); j++) { - const std::string& output_name = nodeq->output(j); - if (output_name == missing_input_name) { - found = true; - break; - } - } + for (int i = 0; i < node_count;) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + bool swapnode = false; + std::string missing_input_name; + for (int j = 0; j < (int)node->input_size(); j++) + { + const std::string& input_name = node->input(j); + if (input_name.empty()) continue; + + if (producers.find(input_name) == producers.end() && + weights.find(input_name) == weights.end()) + { + swapnode = true; + missing_input_name = input_name; + break; + } + } - if (found) break; - } + if (!swapnode) + { + for (int j = 0; j < (int)node->output_size(); j++) + { + const std::string& output_name = node->output(j); + if (output_name.empty()) continue; - if (q == node_count) { - fprintf(stderr, "cannot find node produces %s but node %d requires it\n", - missing_input_name.c_str(), i); - return -1; - } - - // fprintf(stderr, "swap %d %d\n", i, q); - // swap this node with q - onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); - onnx::NodeProto tmp = *node; - *node = *nodeq; - *nodeq = tmp; - } - } - // global definition line - // [layer count] [blob count] - std::set blob_names; - for (int i = 0; i < node_count; i++) { - const onnx::NodeProto& node = mutable_graph->node(i); - - const std::string& op = node.op_type(); - - std::string name = node.name(); - if (name.empty()) { - name = node.output(0); - } + producers.insert(output_name); + } - if (op == "Constant") { - onnx::TensorProto tensor = get_node_attr_tensor(node, "value"); - weights[node.output(0)] = tensor; - } + i++; + continue; + } - for 
(int j = 0; j < (int)node.input_size(); j++) { - const std::string& input_name = node.input(j); + // find node that produce missing_input_name + int q = i + 1; + for (; q < node_count; q++) + { + onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); + bool found = false; + for (int j = 0; j < (int)nodeq->output_size(); j++) + { + const std::string& output_name = nodeq->output(j); + if (output_name == missing_input_name) + { + found = true; + break; + } + } + + if (found) break; + } - blob_names.insert(input_name); + if (q == node_count) + { + fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i); + return -1; + } - if (node_reference.find(input_name) == node_reference.end()) { - node_reference[input_name] = 1; - } else { - node_reference[input_name] = node_reference[input_name] + 1; - } + // fprintf(stderr, "swap %d %d\n", i, q); + // swap this node with q + onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); + onnx::NodeProto tmp = *node; + *node = *nodeq; + *nodeq = tmp; + } } + // global definition line + // [layer count] [blob count] + std::set blob_names; + for (int i = 0; i < node_count; i++) + { + const onnx::NodeProto& node = mutable_graph->node(i); - if (op == "Dropout") { - const std::string& output_name = node.output(0); - blob_names.insert(output_name); - node_reference[output_name] = 0; - continue; - } + const std::string& op = node.op_type(); - for (int j = 0; j < (int)node.output_size(); j++) { - const std::string& output_name = node.output(j); + std::string name = node.name(); + if (name.empty()) + { + name = node.output(0); + } - blob_names.insert(output_name); + if (op == "Constant") + { + onnx::TensorProto tensor = get_node_attr_tensor(node, "value"); + weights[node.output(0)] = tensor; + } - node_reference[output_name] = 0; - } - } - // include Input node - int input_node_count = 0; - for (int j = 0; j < mutable_graph->input_size(); j++) { - const std::string& input_name = mutable_graph->input(j).name(); - - // check weight - if (weights.find(input_name) != weights.end()) continue; - - blob_names.insert(input_name); - - input_node_count++; - } - - // for (auto a: node_reference) - // { - // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second); - // } - - // op chain fusion - int reduced_node_count = 0; - { - fuse_identity(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, - reduced_node_count); - fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, - reduced_node_count); - fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, 
reduced_node_count); - fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, - reduced_node_count); - fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - } - // reduce common const weight node_reference - for (int i = 0; i < node_count; i++) { - const onnx::NodeProto& node = mutable_graph->node(i); - - const std::string& op = node.op_type(); - - if (op == "BatchNormalization") { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - } else if (op == "BiasGelu") { - node_reference[node.input(1)] -= 1; - } else if (op == "Clip") { - if (node.input_size() == 3) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } - } else if (op == "Conv") { - node_reference[node.input(1)] -= 1; - if (node.input_size() == 3) { - node_reference[node.input(2)] -= 1; - } - } else if (op == "ConvTranspose") { - node_reference[node.input(1)] -= 1; - if (node.input_size() == 3) { - node_reference[node.input(2)] -= 1; - } - } else if (op == "EmbedLayerNormalization") { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - node_reference[node.input(5)] -= 1; - node_reference[node.input(6)] -= 1; - } else if (op == "Gemm") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 1.f); - int transA = get_node_attr_i(node, "transA", 0); - int transB = get_node_attr_i(node, "transB", 0); - - if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { - // InnerProduct-like A * B + C, C is optional. 
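(Editor's note: the Gemm handling here folds the B and C inputs into the layer only in the InnerProduct-shaped case, alpha == 1, beta == 1, transA == 0, transB == 1, i.e. Y = X * W^T + C. A minimal self-contained sketch of that computation follows; the helper name and the row-major N x K weight layout are illustrative assumptions, not part of the patch:

    #include <vector>

    // InnerProduct-style Gemm: Y = X * W^T + C, the only configuration the
    // converter maps to ncnn InnerProduct (see the alpha/beta/trans checks).
    static std::vector<float> innerproduct_gemm(const std::vector<float>& x,  // 1 x K input row
                                                const std::vector<float>& W,  // N x K weights, row-major
                                                const std::vector<float>& C,  // N bias values
                                                int K, int N)
    {
        std::vector<float> y(N, 0.f);
        for (int n = 0; n < N; n++)
        {
            float sum = C[n];  // beta == 1: bias added unscaled
            for (int k = 0; k < K; k++)
            {
                sum += x[k] * W[n * K + k];  // transB == 1: multiply by W^T
            }
            y[n] = sum;
        }
        return y;
    }

Any other alpha/beta/transpose combination falls through to the generic Gemm layer instead, so its weights stay ordinary blobs and are not decremented here.)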
- node_reference[node.input(1)] -= 1; - if (node.input_size() == 3) { - node_reference[node.input(2)] -= 1; - } - } - } else if (op == "GroupNorm") { - int affine = get_node_attr_i(node, "affine", 1); - if (affine) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } - } else if (op == "GRU") { - for (int j = 1; j < node.input_size(); j++) { - node_reference[node.input(j)] -= 1; - } - } else if (op == "InstanceNormalization") { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } else if (op == "LayerNorm") { - int affine = get_node_attr_i(node, "affine", 1); - if (affine) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } - } else if (op == "LSTM") { - for (int j = 1; j < node.input_size(); j++) { - node_reference[node.input(j)] -= 1; - } - } else if (op == "MatMul") { - if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) { - // InnerProduct - node_reference[node.input(1)] -= 1; - } - } else if (op == "MultiHeadAttention") { - if (node.input_size() == 5) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - } else { - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - node_reference[node.input(5)] -= 1; - node_reference[node.input(6)] -= 1; - node_reference[node.input(7)] -= 1; - node_reference[node.input(8)] -= 1; - node_reference[node.input(9)] -= 1; - node_reference[node.input(10)] -= 1; - } - } else if (op == "NonMaxSuppression") { - if (node.input_size() >= 3) { - node_reference[node.input(2)] -= 1; - } - if (node.input_size() >= 4) { - node_reference[node.input(3)] -= 1; - } - if (node.input_size() >= 5) { - node_reference[node.input(4)] -= 1; - } - } else if (op == "Pad") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - } - } else if (op == "PRelu") { - node_reference[node.input(1)] -= 1; - } else if (op == "Reshape") { - if (node.input_size() == 2) { - if (weights[node.input(1)].data_type() != 0) { - node_reference[node.input(1)] -= 1; - } - } - } else if (op == "Resize") { - if (node.input_size() == 2) { - // opset 10 - node_reference[node.input(1)] -= 1; - } else { - // opset 11+ - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - if (node.input_size() >= 4) { - node_reference[node.input(3)] -= 1; - } - } - } else if (op == "RNN") { - for (int j = 1; j < node.input_size(); j++) { - node_reference[node.input(j)] -= 1; - } - } else if (op == "SkipLayerNormalization") { - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - } else if (op == "Slice") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - if (node.input_size() >= 4) node_reference[node.input(3)] -= 1; - if (node.input_size() >= 5) node_reference[node.input(4)] -= 1; - } - } else if (op == "Upsample") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - } - } else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || - op == "adaptive_max_pool2d") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - } - } - } + for (int j = 0; j < (int)node.input_size(); j++) + { + const std::string& input_name = node.input(j); - // for (auto a: node_reference) - // { - // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second); - // } + 
blob_names.insert(input_name); - // count all weight node with zero reference - int zero_reference_weight_node_count = 0; - for (std::map::iterator it = weights.begin(); it != weights.end(); - it++) { - const std::string& input_name = it->first; + if (node_reference.find(input_name) == node_reference.end()) + { + node_reference[input_name] = 1; + } + else + { + node_reference[input_name] = node_reference[input_name] + 1; + } + } - int refcount = node_reference[input_name]; - if (refcount == 0) zero_reference_weight_node_count++; - } + if (op == "Dropout") + { + const std::string& output_name = node.output(0); + blob_names.insert(output_name); + node_reference[output_name] = 0; + continue; + } - // we always treat constant node as weight or binaryop_weights - // do not count it twice for layer_count - int constant_node_count_moved_to_weight = 0; - for (int i = 0; i < node_count; i++) { - const onnx::NodeProto& node = mutable_graph->node(i); + for (int j = 0; j < (int)node.output_size(); j++) + { + const std::string& output_name = node.output(j); - const std::string& op = node.op_type(); + blob_names.insert(output_name); - if (op == "Constant") { - constant_node_count_moved_to_weight++; - } - } - - // some op may have anonymous input - // LSTM sequence_lens - blob_names.erase(""); - node_reference.erase(""); - - // remove node_reference entry with reference equals to one - int split_layer_count = 0; - int splitncnn_blob_count = 0; - // split node reference - std::map split_node_reference; - for (std::map::iterator it = node_reference.begin(); it != node_reference.end(); - it++) { - if (it->second > 1) { - split_layer_count++; - splitncnn_blob_count += it->second; - - split_node_reference[it->first] = it->second; + node_reference[output_name] = 0; + } } - } - - fprintf(pp, "%zu %zu\n", - node_count - constant_node_count_moved_to_weight + weights.size() - - zero_reference_weight_node_count - reduced_node_count + input_node_count + - split_layer_count, - blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count); - - int internal_split = 0; - - // place Input at the beginning - for (int j = 0; j < mutable_graph->input_size(); j++) { - const std::string& input_name = mutable_graph->input(j).name(); + // include Input node + int input_node_count = 0; + for (int j = 0; j < mutable_graph->input_size(); j++) + { + const std::string& input_name = mutable_graph->input(j).name(); - // check weight - if (weights.find(input_name) != weights.end()) continue; + // check weight + if (weights.find(input_name) != weights.end()) continue; - fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str()); + blob_names.insert(input_name); - int refcount = node_reference[input_name]; - if (refcount <= 1) { - continue; + input_node_count++; } - char splitname[256]; - sprintf(splitname, "splitncnn_input%d", j); - fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); - fprintf(pp, " %s", input_name.c_str()); + // for (auto a: node_reference) + // { + // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second); + // } - for (int k = 0; k < refcount; k++) { - fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k); + // op chain fusion + int reduced_node_count = 0; + { + fuse_identity(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + 
fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count); } - fprintf(pp, "\n"); - } + // reduce common const weight node_reference + for (int i = 0; i < node_count; i++) + { + const onnx::NodeProto& node = mutable_graph->node(i); - // place MemoryData next - for (std::map::iterator weight_it = weights.begin(); - weight_it != weights.end(); weight_it++) { - const std::string& input_name = weight_it->first; + const std::string& op = node.op_type(); - int refcount = node_reference[input_name]; - if (refcount == 0) { - continue; + if (op == "BatchNormalization") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + } + else if (op == "BiasGelu") + { + node_reference[node.input(1)] -= 1; + } + else if (op == "Clip") + { + if (node.input_size() == 3) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "Conv") + { + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + else if (op == "ConvTranspose") + { + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + else if (op == "EmbedLayerNormalization") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + node_reference[node.input(5)] -= 1; + node_reference[node.input(6)] -= 1; + } + else if (op == "Gemm") + { + float alpha = get_node_attr_f(node, "alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 1.f); + int transA = get_node_attr_i(node, "transA", 0); + int transB = get_node_attr_i(node, "transB", 0); + + if (alpha == 1.f && beta == 1.f && transA == 0 && 
transB == 1) + { + // InnerProduct-like A * B + C, C is optional. + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + } + else if (op == "GroupNorm") + { + int affine = get_node_attr_i(node, "affine", 1); + if (affine) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "GRU") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } + else if (op == "InstanceNormalization") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + else if (op == "LayerNorm") + { + int affine = get_node_attr_i(node, "affine", 1); + if (affine) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "LSTM") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } + else if (op == "MatMul") + { + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) + { + // InnerProduct + node_reference[node.input(1)] -= 1; + } + } + else if (op == "MultiHeadAttention") + { + if (node.input_size() == 5) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + } + else + { + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + node_reference[node.input(5)] -= 1; + node_reference[node.input(6)] -= 1; + node_reference[node.input(7)] -= 1; + node_reference[node.input(8)] -= 1; + node_reference[node.input(9)] -= 1; + node_reference[node.input(10)] -= 1; + } + } + else if (op == "NonMaxSuppression") + { + if (node.input_size() >= 3) + { + node_reference[node.input(2)] -= 1; + } + if (node.input_size() >= 4) + { + node_reference[node.input(3)] -= 1; + } + if (node.input_size() >= 5) + { + node_reference[node.input(4)] -= 1; + } + } + else if (op == "Pad") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } + else if (op == "PRelu") + { + node_reference[node.input(1)] -= 1; + } + else if (op == "Reshape") + { + if (node.input_size() == 2) + { + if (weights[node.input(1)].data_type() != 0) + { + node_reference[node.input(1)] -= 1; + } + } + } + else if (op == "Resize") + { + if (node.input_size() == 2) + { + // opset 10 + node_reference[node.input(1)] -= 1; + } + else + { + // opset 11+ + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + if (node.input_size() >= 4) + { + node_reference[node.input(3)] -= 1; + } + } + } + else if (op == "RNN") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } + else if (op == "SkipLayerNormalization") + { + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + } + else if (op == "Slice") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + if (node.input_size() >= 4) node_reference[node.input(3)] -= 1; + if (node.input_size() >= 5) node_reference[node.input(4)] -= 1; + } + } + else if (op == "Upsample") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } + else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || + op == "adaptive_max_pool2d") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } } - fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", 
input_name.c_str(), input_name.c_str()); - - const onnx::TensorProto& M = weights[input_name]; - - if (M.dims_size() == 0) { - fprintf(pp, " 0=%d", get_tensor_proto_data_size(M)); - } else if (M.dims_size() == 1) { - fprintf(pp, " 0=%d", (int)M.dims(0)); - } else if (M.dims_size() == 2) { - fprintf(pp, " 0=%d", (int)M.dims(1)); - if (M.dims(0) != 1) { - fprintf(pp, " 1=%d", (int)M.dims(0)); - } - } else if (M.dims_size() == 3) { - fprintf(pp, " 0=%d", (int)M.dims(2)); - fprintf(pp, " 1=%d", (int)M.dims(1)); - if (M.dims(0) != 1) { - fprintf(pp, " 2=%d", (int)M.dims(0)); - } - } else if (M.dims_size() == 4) { - fprintf(pp, " 0=%d", (int)M.dims(3)); - fprintf(pp, " 1=%d", (int)M.dims(2)); - fprintf(pp, " 2=%d", (int)M.dims(1)); - } + // for (auto a: node_reference) + // { + // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second); + // } - fprintf(pp, "\n"); - if (M.data_type() == 1) { - fwrite_tensor_proto_data(M, bp); - } else if (M.data_type() == 7 || M.data_type() == 6 || M.data_type() == 9 || - M.data_type() == 11) { - fwrite_tensor_proto_data_to_float(M, bp); - } else { - fwrite_tensor_proto_data(M, bp); - } + // count all weight node with zero reference + int zero_reference_weight_node_count = 0; + for (std::map::iterator it = weights.begin(); it != weights.end(); + it++) + { + const std::string& input_name = it->first; - if (refcount <= 1) { - continue; + int refcount = node_reference[input_name]; + if (refcount == 0) zero_reference_weight_node_count++; } - char splitname[256]; - sprintf(splitname, "splitncnn_%d", internal_split); - fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); + // we always treat constant node as weight or binaryop_weights + // do not count it twice for layer_count + int constant_node_count_moved_to_weight = 0; + for (int i = 0; i < node_count; i++) + { + const onnx::NodeProto& node = mutable_graph->node(i); - fprintf(pp, " %s", input_name.c_str()); + const std::string& op = node.op_type(); - for (int k = 0; k < refcount; k++) { - fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k); + if (op == "Constant") + { + constant_node_count_moved_to_weight++; + } } - fprintf(pp, "\n"); - internal_split++; - } + // some op may have anonymous input + // LSTM sequence_lens + blob_names.erase(""); + node_reference.erase(""); + + // remove node_reference entry with reference equals to one + int split_layer_count = 0; + int splitncnn_blob_count = 0; + // split node reference + std::map split_node_reference; + for (std::map::iterator it = node_reference.begin(); it != node_reference.end(); + it++) + { + if (it->second > 1) + { + split_layer_count++; + splitncnn_blob_count += it->second; - for (int i = 0; i < node_count; i++) { - const onnx::NodeProto& node = mutable_graph->node(i); - const std::string& op = node.op_type(); + split_node_reference[it->first] = it->second; + } + } - // fprintf(stderr, "op = %s\n", op.c_str()); + fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count); - if (op == "noop_reducedncnn") { - continue; - } + int internal_split = 0; - std::string name = node.name(); - if (name.empty()) { - name = node.output(0); - } + // place Input at the beginning + for (int j = 0; j < mutable_graph->input_size(); j++) + { + const std::string& input_name = mutable_graph->input(j).name(); - int input_size = node.input_size(); - int 
output_size = node.output_size(); + // check weight + if (weights.find(input_name) != weights.end()) continue; - for (int j = 0; j < (int)node.input_size(); j++) { - const std::string& input_name = node.input(j); + fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str()); - // check weight - if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) { - input_size--; - } + int refcount = node_reference[input_name]; + if (refcount <= 1) + { + continue; + } - if (input_name.empty()) { - input_size--; - } + char splitname[256]; + sprintf(splitname, "splitncnn_input%d", j); + fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); + fprintf(pp, " %s", input_name.c_str()); - // fprintf(stderr, " input = %s\n", input_name.c_str()); + for (int k = 0; k < refcount; k++) + { + fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k); + } + fprintf(pp, "\n"); } - /* - for (int j=0; j<(int)node.output_size(); j++) + + // place MemoryData next + for (std::map::iterator weight_it = weights.begin(); + weight_it != weights.end(); + weight_it++) { - const std::string& output_name = node.output(j); - fprintf(stderr, " output = %s\n", output_name.c_str()); - } - */ - - if (op == "Abs") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Acos") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Add") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "ArgMax") { - fprintf(pp, "%-16s", "TopK"); - } else if (op == "Asin") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Atan") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "AveragePool" || op == "MaxPool") { - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - if (kernel_shape.size() == 1) { - fprintf(pp, "%-16s", "Pooling1D"); - } else { - fprintf(pp, "%-16s", "Pooling"); - } - } else if (op == "BatchNormalization") { - fprintf(pp, "%-16s", "BatchNorm"); - } else if (op == "BiasGelu") { - fprintf(pp, "%-16s", "BiasGelu"); - } else if (op == "Cast") { - fprintf(pp, "%-16s", "Noop"); - } else if (op == "Ceil") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Clip") { - fprintf(pp, "%-16s", "Clip"); - } else if (op == "Concat") { - fprintf(pp, "%-16s", "Concat"); - } else if (op == "Constant") { - continue; - } else if (op == "ConstantOfShape") { - fprintf(pp, "%-16s", "ConstantOfShape"); - } else if (op == "Conv") { - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - if (kernel_shape.size() == 1) { - fprintf(pp, "%-16s", "Convolution1D"); - } else { - int group = get_node_attr_i(node, "group", 1); - if (group > 1) { - fprintf(pp, "%-16s", "ConvolutionDepthWise"); - } else { - fprintf(pp, "%-16s", "Convolution"); - } - } - } else if (op == "ConvTranspose") { - int group = get_node_attr_i(node, "group", 1); - if (group > 1) { - fprintf(pp, "%-16s", "DeconvolutionDepthWise"); - } else { - fprintf(pp, "%-16s", "Deconvolution"); - } - } else if (op == "Cos") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Crop") { - fprintf(pp, "%-16s", "Crop"); - } else if (op == "DepthToSpace") { - fprintf(pp, "%-16s", "PixelShuffle"); - } else if (op == "DetectionOutput") { - fprintf(pp, "%-16s", "DetectionOutput"); - } else if (op == "Div") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "Dropout") { - fprintf(pp, "%-16s", "Dropout"); - output_size = 1; - } else if (op == "Elu") { - fprintf(pp, "%-16s", "ELU"); - } else if (op == "EmbedLayerNormalization") { - fprintf(pp, "%-16s", "EmbedLayerNormalization"); - } else if (op 
== "Equal") { - fprintf(pp, "%-16s", "Compare"); - } else if (op == "Exp") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Expand") { - fprintf(pp, "%-16s", "Expand"); - } else if (op == "Flatten") { - fprintf(pp, "%-16s", "Flatten"); - } else if (op == "Floor") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Gather") { - fprintf(pp, "%-16s", "Gather"); - } else if (op == "Gelu") { - fprintf(pp, "%-16s", "GELU"); - } else if (op == "Gemm") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 1.f); - int transA = get_node_attr_i(node, "transA", 0); - int transB = get_node_attr_i(node, "transB", 0); - - if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { - // InnerProduct-like A * B + C - fprintf(pp, "%-16s", "InnerProduct"); - } else { - fprintf(pp, "%-16s", "Gemm"); - } - } else if (op == "GlobalAveragePool") { - fprintf(pp, "%-16s", "Pooling"); - } else if (op == "GlobalMaxPool") { - fprintf(pp, "%-16s", "Pooling"); - } else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || - op == "adaptive_max_pool2d") { - fprintf(pp, "%-16s", "Pooling"); - } else if (op == "GroupNorm") { - fprintf(pp, "%-16s", "GroupNorm"); - } else if (op == "GRU") { - fprintf(pp, "%-16s", "GRU"); - } else if (op == "HardSigmoid") { - fprintf(pp, "%-16s", "HardSigmoid"); - } else if (op == "HardSwish") { - fprintf(pp, "%-16s", "HardSwish"); - } else if (op == "ImageScaler") { - fprintf(pp, "%-16s", "Scale"); - } else if (op == "InstanceNormalization") { - fprintf(pp, "%-16s", "InstanceNorm"); - } else if (op == "LayerNorm") { - fprintf(pp, "%-16s", "LayerNorm"); - } else if (op == "LeakyRelu") { - fprintf(pp, "%-16s", "ReLU"); - } else if (op == "Threshold") { - fprintf(pp, "%-16s", "Threshold"); - } else if (op == "Log") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "LRN") { - fprintf(pp, "%-16s", "LRN"); - } else if (op == "LSTM") { - fprintf(pp, "%-16s", "LSTM"); - } else if (op == "MatMul") { - if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) { - fprintf(pp, "%-16s", "InnerProduct"); - } else { - fprintf(pp, "%-16s", "Gemm"); - } - } else if (op == "Max") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "Min") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "Mul") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "MultiHeadAttention") { - fprintf(pp, "%-16s", "MultiHeadAttention"); - } else if (op == "Neg") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "NonMaxSuppression") { - fprintf(pp, "%-16s", "NonMaxSuppression"); - } else if (op == "Normalize") { - fprintf(pp, "%-16s", "Normalize"); - } else if (op == "Pad") { - fprintf(pp, "%-16s", "Padding"); - } else if (op == "PixelShuffle") { - fprintf(pp, "%-16s", "PixelShuffle"); - } else if (op == "Pow") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "PriorBox") { - fprintf(pp, "%-16s", "PriorBox"); - } else if (op == "PRelu") { - fprintf(pp, "%-16s", "PReLU"); - } else if (op == "Range") { - fprintf(pp, "%-16s", "Range"); - } else if (op == "Reciprocal") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || - op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || - op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") { - fprintf(pp, "%-16s", "Reduction"); - } else if (op == "Relu") { - fprintf(pp, "%-16s", "ReLU"); - } else if (op == "Reorg") { - fprintf(pp, "%-16s", 
"Reorg"); - } else if (op == "Reshape") { - fprintf(pp, "%-16s", "Reshape"); - } else if (op == "RNN") { - fprintf(pp, "%-16s", "RNN"); - } else if (op == "RDiv") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "RSub") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "RoiAlign") { - fprintf(pp, "%-16s", "ROIAlign"); - } else if (op == "ScatterND") { - fprintf(pp, "%-16s", "ScatterND"); - } else if (op == "Shape") { - fprintf(pp, "%-16s", "Shape"); - } else if (op == "ShuffleChannel") { - fprintf(pp, "%-16s", "ShuffleChannel"); - } else if (op == "Sigmoid") { - fprintf(pp, "%-16s", "Sigmoid"); - } else if (op == "Sin") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "SkipLayerNormalization") { - fprintf(pp, "%-16s", "SkipLayerNormalization"); - } else if (op == "Slice") { - std::vector ends; - std::vector steps; - bool use_crop = true; - - if (node.input_size() == 1) { - ends = get_node_attr_ai(node, "ends"); - steps = get_node_attr_ai(node, "steps"); // TODO - } else { - ends = get_node_attr_from_input_ai(weights[node.input(2)]); - if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); - } - - // assert step == 1 - for (int i = 0; i < (int)steps.size(); i++) { - if (steps[i] != 1 && steps[i] < ends[i]) { - use_crop = false; - break; - } - } - - if (use_crop) { - fprintf(pp, "%-16s", "Crop"); - } else { - fprintf(pp, "%-16s", "TensorSlice"); - } - } else if (op == "Softmax") { - fprintf(pp, "%-16s", "Softmax"); - } else if (op == "Softplus") { - fprintf(pp, "%-16s", "Softplus"); - } else if (op == "Split") { - fprintf(pp, "%-16s", "Slice"); - } else if (op == "Sqrt") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Squeeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - // fprintf(stderr, "axes[0]: %d\n",axes[0]); - if (axes[0] == 0) { - fprintf(pp, "%-16s", "Noop"); - } else { - fprintf(pp, "%-16s", "Squeeze"); - } - } else if (op == "Sub") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "Sum") { - fprintf(pp, "%-16s", "Eltwise"); - } else if (op == "Swish") { - fprintf(pp, "%-16s", "Swish"); - } else if (op == "Tan") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Tanh") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Tile") { - fprintf(pp, "%-16s", "TileOnnx"); - } else if (op == "TopK") { - fprintf(pp, "%-16s", "TopK"); - } else if (op == "Transpose") { - fprintf(pp, "%-16s", "Permute"); - } else if (op == "Upsample" || op == "Resize") { - fprintf(pp, "%-16s", "Interp"); - } else if (op == "Unsqueeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - // fprintf(stderr, "axes[0]: %d\n",axes[0]); - if (axes[0] == 0) { - fprintf(pp, "%-16s", "Noop"); - } else { - fprintf(pp, "%-16s", "ExpandDims"); - } - } else if (op == "Where") { - fprintf(pp, "%-16s", "Where"); - } else if (op == "Yolov3DetectionOutput") { - fprintf(pp, "%-16s", "Yolov3DetectionOutput"); - } else { - // TODO - fprintf(stderr, "%s not supported yet!\n", op.c_str()); - fprintf(pp, "%-16s", op.c_str()); - } + const std::string& input_name = weight_it->first; - fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size); + int refcount = node_reference[input_name]; + if (refcount == 0) + { + continue; + } - for (int j = 0; j < (int)node.input_size(); j++) { - std::string input_name = node.input(j); + fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str()); - // check weight - if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) { - continue; - } 
+ const onnx::TensorProto& M = weights[input_name]; - if (input_name.empty()) { - continue; - } + if (M.dims_size() == 0) + { + fprintf(pp, " 0=%d", get_tensor_proto_data_size(M)); + } + else if (M.dims_size() == 1) + { + fprintf(pp, " 0=%d", (int)M.dims(0)); + } + else if (M.dims_size() == 2) + { + fprintf(pp, " 0=%d", (int)M.dims(1)); + if (M.dims(0) != 1) + { + fprintf(pp, " 1=%d", (int)M.dims(0)); + } + } + else if (M.dims_size() == 3) + { + fprintf(pp, " 0=%d", (int)M.dims(2)); + fprintf(pp, " 1=%d", (int)M.dims(1)); + if (M.dims(0) != 1) + { + fprintf(pp, " 2=%d", (int)M.dims(0)); + } + } + else if (M.dims_size() == 4) + { + fprintf(pp, " 0=%d", (int)M.dims(3)); + fprintf(pp, " 1=%d", (int)M.dims(2)); + fprintf(pp, " 2=%d", (int)M.dims(1)); + } - if (split_node_reference.find(input_name) != split_node_reference.end()) { - int refidx = split_node_reference[input_name] - 1; - split_node_reference[input_name] = refidx; + fprintf(pp, "\n"); + if (M.data_type() == 1) + { + fwrite_tensor_proto_data(M, bp); + } + else if (M.data_type() == 7 || M.data_type() == 6 || M.data_type() == 9 || + M.data_type() == 11) + { + fwrite_tensor_proto_data_to_float(M, bp); + } + else + { + fwrite_tensor_proto_data(M, bp); + } - char splitsuffix[256]; - sprintf(splitsuffix, "_splitncnn_%d", refidx); - input_name = input_name + splitsuffix; - } + if (refcount <= 1) + { + continue; + } - fprintf(pp, " %s", input_name.c_str()); - } + char splitname[256]; + sprintf(splitname, "splitncnn_%d", internal_split); + fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); - for (int j = 0; j < output_size; j++) { - const std::string& output_name = node.output(j); + fprintf(pp, " %s", input_name.c_str()); - fprintf(pp, " %s", output_name.c_str()); - } + for (int k = 0; k < refcount; k++) + { + fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k); + } + fprintf(pp, "\n"); - if (op == "Abs") { - int op_type = 0; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Acos") { - int op_type = 13; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Add") { - int op_type = 0; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "ArgMax") { - int axis = get_node_attr_i(node, "axis"); - int keepdims = get_node_attr_i(node, "keepdims"); - fprintf(pp, " 0=%d", axis - 1); - fprintf(pp, " 3=%d", keepdims); - } else if (op == "Asin") { - int op_type = 12; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Atan") { - int op_type = 14; - fprintf(pp, " 0=%d", op_type); - } else if (op == "AveragePool" || op == "MaxPool") { - std::string auto_pad = get_node_attr_s(node, "auto_pad"); - int ceil_mode = get_node_attr_i(node, "ceil_mode", 0); - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - std::vector strides = get_node_attr_ai(node, "strides"); - std::vector pads = get_node_attr_ai(node, "pads"); - - int pool = op == "AveragePool" ? 
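// [Editorial sketch, not part of the patch] The MemoryData hunk above turns a
// leftover initializer into a constant blob, mapping proto dims onto ncnn's
// w/h/c parameters innermost-dimension-first. A simplified restatement; the
// real hunk additionally drops a leading dimension of 1 in the 2-D and 3-D
// cases and converts int64/int32/bool/double payloads to float:
#include <cstdio>

static void write_memorydata_shape(FILE* pp, const int* dims, int dims_size)
{
    if (dims_size >= 1) fprintf(pp, " 0=%d", dims[dims_size - 1]); // w
    if (dims_size >= 2) fprintf(pp, " 1=%d", dims[dims_size - 2]); // h
    if (dims_size >= 3) fprintf(pp, " 2=%d", dims[dims_size - 3]); // c
}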
1 : 0; - int pad_mode = 1; - - if (auto_pad == "SAME_UPPER") { - pad_mode = 2; - } else if (auto_pad == "SAME_LOWER") { - pad_mode = 3; - } - - if (ceil_mode == 1) { - pad_mode = 0; - } - - fprintf(pp, " 0=%d", pool); - - if (kernel_shape.size() == 1) { - fprintf(pp, " 1=%d", kernel_shape[0]); - } else if (kernel_shape.size() == 2) { - fprintf(pp, " 1=%d", kernel_shape[1]); - fprintf(pp, " 11=%d", kernel_shape[0]); - } - - if (strides.size() == 1) { - fprintf(pp, " 2=%d", strides[0]); - } else if (strides.size() == 2) { - fprintf(pp, " 2=%d", strides[1]); - fprintf(pp, " 12=%d", strides[0]); - } - - if (pads.size() == 1) { - fprintf(pp, " 3=%d", pads[0]); - } else if (pads.size() == 2) { - fprintf(pp, " 3=%d", pads[1]); - fprintf(pp, " 13=%d", pads[0]); - } else if (pads.size() == 4) { - fprintf(pp, " 3=%d", pads[1]); - fprintf(pp, " 13=%d", pads[0]); - fprintf(pp, " 14=%d", pads[3]); - fprintf(pp, " 15=%d", pads[2]); - } - - fprintf(pp, " 5=%d", pad_mode); - - if (op == "AveragePool") { - int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0); - fprintf(pp, " 6=%d", avgpool_count_include_pad); - } - } else if (op == "BatchNormalization") { - float epsilon = get_node_attr_f(node, "epsilon", 1e-5f); - - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - const onnx::TensorProto& mean = weights[node.input(3)]; - const onnx::TensorProto& var = weights[node.input(4)]; - - int channels = get_tensor_proto_data_size(scale); - - fprintf(pp, " 0=%d", channels); - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(mean, bp); - // apply epsilon to var - { - const float* v = - var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data(); - - for (int j = 0; j < channels; j++) { - float ve = v[j] + epsilon; - fwrite(&ve, sizeof(float), 1, bp); - } - } - fwrite_tensor_proto_data(B, bp); - } else if (op == "BiasGelu") { - const onnx::TensorProto& B = weights[node.input(1)]; - - fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B, bp); - } else if (op == "Ceil") { - int op_type = 3; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Clip") { - float min; - float max; - if (node.input_size() == 1) { - min = get_node_attr_f(node, "min", -FLT_MAX); - max = get_node_attr_f(node, "max", FLT_MAX); - } else { - min = weights.find(node.input(1)) != weights.end() - ? get_node_attr_from_input(weights[node.input(1)]) - : -FLT_MAX; - max = weights.find(node.input(2)) != weights.end() - ? get_node_attr_from_input(weights[node.input(2)]) - : FLT_MAX; - } - - fprintf(pp, " 0=%e", min); - fprintf(pp, " 1=%e", max); - } else if (op == "Concat") { - int axis = get_node_attr_i(node, "axis", 1); - fprintf(pp, " 0=%d", axis - 1); - } else if (op == "Constant") { - // never reach here - } else if (op == "ConstantOfShape") { - float value = 0.f; - value = get_node_attr_f(node, "value", 0.f); - fprintf(pp, " 0=%f", value); - - } else if (op == "Conv") { - const onnx::TensorProto& W = weights[node.input(1)]; - - int num_filter = W.dims(0); - int has_bias = node.input_size() == 3 ? 
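// [Editorial sketch, not part of the patch] The BatchNormalization hunk above
// does not store epsilon as a layer parameter; it folds it into the saved
// variance (var' = var + eps) so the runtime can use the tensor as-is. The
// per-channel fold in isolation:
#include <cstdio>

static void fwrite_var_plus_eps(FILE* bp, const float* var, int channels, float eps)
{
    for (int j = 0; j < channels; j++)
    {
        float ve = var[j] + eps;
        fwrite(&ve, sizeof(float), 1, bp);
    }
}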
1 : 0; - - std::string auto_pad = get_node_attr_s(node, "auto_pad"); - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - std::vector dilations = get_node_attr_ai(node, "dilations"); - std::vector strides = get_node_attr_ai(node, "strides"); - std::vector pads = get_node_attr_ai(node, "pads"); - int group = get_node_attr_i(node, "group", 1); - - fprintf(pp, " 0=%d", num_filter); - - if (kernel_shape.size() == 1) { - fprintf(pp, " 1=%d", kernel_shape[0]); - } else if (kernel_shape.size() == 2) { - fprintf(pp, " 1=%d", kernel_shape[1]); - fprintf(pp, " 11=%d", kernel_shape[0]); - } - - if (dilations.size() == 1) { - fprintf(pp, " 2=%d", dilations[0]); - } else if (dilations.size() == 2) { - fprintf(pp, " 2=%d", dilations[1]); - fprintf(pp, " 12=%d", dilations[0]); - } - - if (strides.size() == 1) { - fprintf(pp, " 3=%d", strides[0]); - } else if (strides.size() == 2) { - fprintf(pp, " 3=%d", strides[1]); - fprintf(pp, " 13=%d", strides[0]); - } - - if (auto_pad == "SAME_UPPER") { - fprintf(pp, " 4=-233"); - } else if (auto_pad == "SAME_LOWER") { - fprintf(pp, " 4=-234"); - } else { - if (pads.size() == 1) { - fprintf(pp, " 4=%d", pads[0]); - } else if (pads.size() == 2) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - } else if (pads.size() == 4) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - fprintf(pp, " 15=%d", pads[3]); - fprintf(pp, " 16=%d", pads[2]); - } - } - - fprintf(pp, " 5=%d", has_bias); - - fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); - - if (group > 1) { - fprintf(pp, " 7=%d", group); - } - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(W, bp); - - if (has_bias) { - const onnx::TensorProto& B = weights[node.input(2)]; - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "ConvTranspose") { - const onnx::TensorProto& W = weights[node.input(1)]; - - int has_bias = node.input_size() == 3 ? 
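// [Editorial sketch, not part of the patch] In the Conv hunk above the pad
// parameter doubles as a mode flag: 4=-233 and 4=-234 request SAME_UPPER /
// SAME_LOWER auto-padding, while explicit ONNX pads [top, left, bottom,
// right] land on ids 4/14/15/16 (left/top/right/bottom). A restatement:
#include <cstdio>
#include <string>
#include <vector>

static void write_conv_pads(FILE* pp, const std::string& auto_pad, const std::vector<int>& pads)
{
    if (auto_pad == "SAME_UPPER")      fprintf(pp, " 4=-233");
    else if (auto_pad == "SAME_LOWER") fprintf(pp, " 4=-234");
    else if (pads.size() == 4)
    {
        fprintf(pp, " 4=%d", pads[1]);  // left
        fprintf(pp, " 14=%d", pads[0]); // top
        fprintf(pp, " 15=%d", pads[3]); // right
        fprintf(pp, " 16=%d", pads[2]); // bottom
    }
}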
1 : 0; - - std::string auto_pad = get_node_attr_s(node, "auto_pad"); - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - std::vector dilations = get_node_attr_ai(node, "dilations"); - std::vector strides = get_node_attr_ai(node, "strides"); - std::vector output_padding = get_node_attr_ai(node, "output_padding"); - std::vector output_shape = get_node_attr_ai(node, "output_shape"); - std::vector pads = get_node_attr_ai(node, "pads"); - int group = get_node_attr_i(node, "group", 1); - int num_filter = W.dims(1) * group; - - fprintf(pp, " 0=%d", num_filter); - - if (kernel_shape.size() == 1) { - fprintf(pp, " 1=%d", kernel_shape[0]); - } else if (kernel_shape.size() == 2) { - fprintf(pp, " 1=%d", kernel_shape[1]); - fprintf(pp, " 11=%d", kernel_shape[0]); - } - - if (dilations.size() == 1) { - fprintf(pp, " 2=%d", dilations[0]); - } else if (dilations.size() == 2) { - fprintf(pp, " 2=%d", dilations[1]); - fprintf(pp, " 12=%d", dilations[0]); - } - - if (strides.size() == 1) { - fprintf(pp, " 3=%d", strides[0]); - } else if (strides.size() == 2) { - fprintf(pp, " 3=%d", strides[1]); - fprintf(pp, " 13=%d", strides[0]); - } - - if (auto_pad == "SAME_UPPER") { - fprintf(pp, " 4=-233"); - } else if (auto_pad == "SAME_LOWER") { - fprintf(pp, " 4=-234"); - } else { - if (pads.size() == 1) { - fprintf(pp, " 4=%d", pads[0]); - } else if (pads.size() == 2) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - } else if (pads.size() == 4) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - fprintf(pp, " 15=%d", pads[3]); - fprintf(pp, " 16=%d", pads[2]); - } - } - - if (output_padding.size() == 1) { - fprintf(pp, " 18=%d", output_padding[0]); - } else if (output_padding.size() == 2) { - fprintf(pp, " 18=%d", output_padding[1]); - fprintf(pp, " 19=%d", output_padding[0]); - } - - if (output_shape.size() == 1) { - fprintf(pp, " 20=%d", output_shape[0]); - } else if (output_shape.size() == 2) { - fprintf(pp, " 20=%d", output_shape[1]); - fprintf(pp, " 21=%d", output_shape[0]); - } - - fprintf(pp, " 5=%d", has_bias); - - fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); - - if (group > 1) { - fprintf(pp, " 7=%d", group); - } - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int maxk = 0; - if (kernel_shape.size() == 2) { - maxk = kernel_shape[1] * kernel_shape[0]; - } else { - maxk = kernel_shape[0] * kernel_shape[0]; - } - int weight_data_size = get_tensor_proto_data_size(W); - const float* weight_data = 0; - if (W.has_raw_data()) { - weight_data = (const float*)W.raw_data().data(); - } else if (W.data_type() == 1) { - weight_data = W.float_data().data(); - } - for (int g = 0; g < group; g++) { - // reorder weight from inch-outch to outch-inch - int num_filter_g = num_filter / group; - int num_input = weight_data_size / maxk / num_filter_g / group; - const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input; - for (int k = 0; k < num_filter_g; k++) { - for (int j = 0; j < num_input; j++) { - fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp); - } - } - } - - if (has_bias) { - const onnx::TensorProto& B = weights[node.input(2)]; - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "Cos") { - int op_type = 10; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Crop") { - auto starts = get_node_attr_ai(node, "starts"); - fprintf(pp, " -23309=%zu", starts.size()); - for (size_t j = 0; j < starts.size(); ++j) { - fprintf(pp, ",%i", starts[j]); - } - auto ends = 
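// [Editorial sketch, not part of the patch] The deconvolution weight shuffle
// above in isolation: ONNX ConvTranspose weights are laid out
// input-channel-major within each group, while ncnn expects
// output-channel-major, so each group is rewritten as outch-inch-k:
#include <cstdio>

static void reorder_deconv_group(FILE* bp, const float* w,
                                 int num_input, int num_filter_g, int maxk)
{
    for (int k = 0; k < num_filter_g; k++)  // output channel
        for (int j = 0; j < num_input; j++) // input channel
            fwrite(w + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
}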
get_node_attr_ai(node, "ends"); - fprintf(pp, " -23310=%zu", ends.size()); - for (size_t j = 0; j < ends.size(); ++j) { - fprintf(pp, ",%i", ends[j]); - } - auto axis = get_node_attr_ai(node, "axis"); - fprintf(pp, " -23311=%zu", axis.size()); - for (size_t j = 0; j < axis.size(); ++j) { - fprintf(pp, ",%i", axis[j]); - } - } else if (op == "DepthToSpace") { - // pixelshuffle - int scale_factor = get_node_attr_i(node, "blocksize", 1); - std::string mode = get_node_attr_s(node, "mode"); - fprintf(pp, " 0=%d", scale_factor); - if (mode == "CRD") { - fprintf(pp, " 1=0"); - } else if (mode == "DCR") { - fprintf(pp, " 1=1"); - } - } else if (op == "DetectionOutput") { - float score_threshold = get_node_attr_f(node, "score_threshold"); - float nms_threshold = get_node_attr_f(node, "nms_threshold"); - int nms_top_k = get_node_attr_i(node, "nms_top_k"); - int keep_top_k = get_node_attr_i(node, "keep_top_k"); - int num_class = get_node_attr_i(node, "num_class"); - std::vector vars = get_node_attr_af(node, "vars"); - fprintf(pp, " 0=%d", num_class); - fprintf(pp, " 1=%f", nms_threshold); - fprintf(pp, " 2=%d", nms_top_k); - fprintf(pp, " 3=%d", keep_top_k); - fprintf(pp, " 4=%f", score_threshold); - fprintf(pp, " 5=%f", vars[0]); - fprintf(pp, " 6=%f", vars[1]); - fprintf(pp, " 7=%f", vars[2]); - fprintf(pp, " 8=%f", vars[3]); - } else if (op == "Div") { - int op_type = 3; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Dropout") { - // no-op - } else if (op == "Elu") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - fprintf(pp, " 0=%e", alpha); - } else if (op == "EmbedLayerNormalization") { - const onnx::TensorProto& words = weights[node.input(2)]; - const onnx::TensorProto& positions = weights[node.input(3)]; - const onnx::TensorProto& W = weights[node.input(5)]; - const onnx::TensorProto& B = weights[node.input(6)]; - - fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); - fprintf(pp, " 1=%d", get_tensor_proto_data_size(words)); - fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(words, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(positions, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(W, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B, bp); - } else if (op == "Equal") { - int op_type = 0; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Exp") { - int op_type = 7; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Flatten") { - int axis = get_node_attr_i(node, "axis", 1); - if (axis != 1) { - fprintf(stderr, "Unsupported Flatten axis %d!\n", axis); - } - } else if (op == "Floor") { - int op_type = 2; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Gather") { - if (weights[node.input(1)].dims_size() > 1) { - fprintf(stderr, "Unsupported indice dims > 1"); - } - int axis = get_node_attr_i(node, "axis", 1) - 1; - if (axis < 0) { - fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1); - } - fprintf(pp, " 0=%d", axis); - } else if (op == "Gelu") { - fprintf(pp, " 0=1"); - } else if (op == "Gemm") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 1.f); - int transA = get_node_attr_i(node, "transA", 0); - int transB 
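// [Editorial sketch, not part of the patch] Several hunks above write
// `int quantize_tag = 0;` immediately before a weight blob: weight arrays in
// the ncnn .bin stream carry a 4-byte storage tag, and 0 announces a plain
// float32 payload. The recurring idiom as one helper:
#include <cstdio>

static void fwrite_float_blob(FILE* bp, const float* data, size_t count)
{
    int quantize_tag = 0; // 0 = unquantized float32 data follows
    fwrite(&quantize_tag, sizeof(int), 1, bp);
    fwrite(data, sizeof(float), count, bp);
}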
= get_node_attr_i(node, "transB", 0); - - if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { - // InnerProduct-like A * B + C - const onnx::TensorProto& B = weights[node.input(1)]; - // B has transposed. - int num_output = B.dims(0); - fprintf(pp, " 0=%d", num_output); - if (node.input_size() == 3) { - fprintf(pp, " 1=1"); - } else { - fprintf(pp, " 1=0"); - } - fprintf(pp, " 2=%d", get_tensor_proto_data_size(B)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - fwrite_tensor_proto_data(B, bp); - if (node.input_size() == 3) { - const onnx::TensorProto& C = weights[node.input(2)]; - fwrite_tensor_proto_data(C, bp); - } - } else { - // gemm - fprintf(pp, " 0=%e", alpha); - fprintf(pp, " 1=%e", beta); - fprintf(pp, " 2=%d", transA); - fprintf(pp, " 3=%d", transB); - } - } else if (op == "GlobalAveragePool") { - int pool = 1; - int global_pool = 1; - - fprintf(pp, " 0=%d", pool); - fprintf(pp, " 4=%d", global_pool); - } else if (op == "GlobalMaxPool") { - int pool = 0; - int global_pool = 1; - - fprintf(pp, " 0=%d", pool); - fprintf(pp, " 4=%d", global_pool); - } else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || - op == "adaptive_max_pool2d") { - int pool = 0; - if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d") { - pool = 1; - } - int adaptive_pooling = 1; - const onnx::TensorProto& out_shape_tp = weights[node.input(1)]; - std::vector out_shape = get_node_attr_from_input_ai(out_shape_tp); - - fprintf(pp, " 0=%d", pool); - fprintf(pp, " 7=%d", adaptive_pooling); - if (out_shape.size() == 1) { - fprintf(pp, " 8=%d", out_shape[0]); - } else if (out_shape.size() == 2) { - // out_w - fprintf(pp, " 8=%d", out_shape[1]); - // out_h - fprintf(pp, " 18=%d", out_shape[0]); - } - } else if (op == "GroupNorm") { - int groups = get_node_attr_i(node, "groups", 1); - int channels = get_node_attr_i(node, "channels", 1); - float eps = get_node_attr_f(node, "epsilon", 1e-5f); - int affine = get_node_attr_i(node, "affine", 1); - - if (affine) { - // discard affine-less S=1 B=0 - std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); - if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && - affine_B[0] == 0.f) { - affine = 0; - } else { - affine = 0; - { - for (int j = 0; j < channels; j++) { - if (affine_S[j] != 1.f || affine_B[j] != 0.f) { - affine = 1; - break; - } - } - } - } - } - - fprintf(pp, " 0=%d", groups); - fprintf(pp, " 1=%d", channels); - fprintf(pp, " 2=%e", eps); - fprintf(pp, " 3=%d", affine); - if (affine) { - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "GRU") { - const onnx::TensorProto& W = weights[node.input(1)]; - const onnx::TensorProto& R = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - - int hidden_size = get_node_attr_i(node, "hidden_size", 0); - std::string direction = get_node_attr_s(node, "direction"); - - int direction_type = 0; - if (direction == "forward") { - direction_type = 0; - } else if (direction == "reverse") { - direction_type = 1; - } else if (direction == "bidirectional") { - direction_type = 2; - } - - int weight_data_size = get_tensor_proto_data_size(W); - - fprintf(pp, " 0=%d", hidden_size); - fprintf(pp, " 1=%d", weight_data_size); - fprintf(pp, " 2=%d", direction_type); - - int 
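// [Editorial sketch, not part of the patch] The GroupNorm affine scan above
// (the same test also guards InstanceNormalization and LayerNorm) keeps the
// affine weights only if some channel deviates from the identity transform
// S=1, B=0. As a standalone predicate:
#include <vector>

static bool needs_affine(const std::vector<float>& S, const std::vector<float>& B)
{
    for (size_t j = 0; j < S.size(); j++)
        if (S[j] != 1.f || B[j] != 0.f)
            return true;
    return false;
}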
num_directions = direction_type == 2 ? 2 : 1; - - int quantize_tag = 0; - - // reorder num_directions-URN-hidden-size to - // num_directions-RUN-hidden-size - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions; - const float* wptr = - W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); - - const float* uptr = wptr; - const float* rptr = wptr + weight_data_size_g; - const float* nptr = wptr + weight_data_size_g * 2; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - uptr += weight_data_size_g * 3; - rptr += weight_data_size_g * 3; - nptr += weight_data_size_g * 3; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - } - } - - // reduce U and R bias except N - // reorder num_directions-URN-hidden to num_directions-RUN-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions; - const float* bptr = - B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); - const float* wuptr = bptr; - const float* wrptr = bptr + bias_data_size_g; - const float* wnptr = bptr + bias_data_size_g * 2; - const float* buptr = bptr + bias_data_size_g * 3; - const float* brptr = bptr + bias_data_size_g * 4; - const float* bnptr = bptr + bias_data_size_g * 5; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wrptr[j] + brptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wuptr[j] + buptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - fwrite(wnptr, sizeof(float), bias_data_size_g, bp); - fwrite(bnptr, sizeof(float), bias_data_size_g, bp); - - if (direction_type == 2) { - wuptr += bias_data_size_g * 6; - wrptr += bias_data_size_g * 6; - wnptr += bias_data_size_g * 6; - buptr += bias_data_size_g * 6; - brptr += bias_data_size_g * 6; - bnptr += bias_data_size_g * 6; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wrptr[j] + brptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wuptr[j] + buptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - fwrite(wnptr, sizeof(float), bias_data_size_g, bp); - fwrite(bnptr, sizeof(float), bias_data_size_g, bp); - } - } - - // reorder num_directions-URN-hidden-hidden to - // num_directions-RUN-hidden-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions; - const float* Rptr = - R.has_raw_data() ? 
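// [Editorial sketch, not part of the patch] As I read the GRU bias hunk
// above: ONNX packs biases as [Wb(u,r,n), Rb(u,r,n)]; ncnn wants the r and u
// biases pre-summed but keeps the two n biases separate, since the candidate
// gate applies them to different terms. One direction, with g = hidden size:
#include <cstdio>

static void write_gru_bias(FILE* bp, const float* wb, const float* rb, int g)
{
    const float *wu = wb, *wr = wb + g, *wn = wb + 2 * g;
    const float *bu = rb, *br = rb + g, *bn = rb + 2 * g;
    for (int j = 0; j < g; j++) { float v = wr[j] + br[j]; fwrite(&v, sizeof(float), 1, bp); }
    for (int j = 0; j < g; j++) { float v = wu[j] + bu[j]; fwrite(&v, sizeof(float), 1, bp); }
    fwrite(wn, sizeof(float), g, bp);
    fwrite(bn, sizeof(float), g, bp);
}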
(const float*)R.raw_data().data() : R.float_data().data(); - - const float* uptr = Rptr; - const float* rptr = Rptr + weight_data_size_g; - const float* nptr = Rptr + weight_data_size_g * 2; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - uptr += weight_data_size_g * 3; - rptr += weight_data_size_g * 3; - nptr += weight_data_size_g * 3; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - } - } - } else if (op == "HardSigmoid") { - float alpha = get_node_attr_f(node, "alpha", 0.2f); - float beta = get_node_attr_f(node, "beta", 0.5f); - - fprintf(pp, " 0=%e", alpha); - fprintf(pp, " 1=%e", beta); - } else if (op == "HardSwish") { - float alpha = get_node_attr_f(node, "alpha", 0.2f); - float beta = get_node_attr_f(node, "beta", 0.5f); - - fprintf(pp, " 0=%e", alpha); - fprintf(pp, " 1=%e", beta); - } else if (op == "ImageScaler") { - std::vector bias = get_node_attr_af(node, "bias"); - float scale = get_node_attr_f(node, "scale", 1.f); - - int channels = (int)bias.size(); - - fprintf(pp, " 0=%d", channels); - fprintf(pp, " 1=1"); - - for (int j = 0; j < channels; j++) { - fwrite(&scale, sizeof(float), 1, bp); - } - fwrite(&bias[0], sizeof(float), channels, bp); - } else if (op == "InstanceNormalization") { - float eps = get_node_attr_f(node, "epsilon", 1e-5f); - - // discard affine-less S=1 B=0 - std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); - int channels = (int)affine_S.size(); - int affine = 0; - { - for (int j = 0; j < channels; j++) { - if (affine_S[j] != 1.f || affine_B[j] != 0.f) { - affine = 1; - break; - } - } - } - - fprintf(pp, " 0=%d", channels); - fprintf(pp, " 1=%e", eps); - fprintf(pp, " 2=%d", affine); - if (affine) { - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "LayerNorm") { - float eps = get_node_attr_f(node, "epsilon", 1e-5f); - int affine = get_node_attr_i(node, "affine", 1); - - if (affine) { - // discard affine-less S=1 B=0 - std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); - int affine_size = (int)affine_S.size(); - affine = 0; - { - for (int j = 0; j < affine_size; j++) { - if (affine_S[j] != 1.f || affine_B[j] != 0.f) { - affine = 1; - break; - } - } - } - - if (affine) { - fprintf(pp, " 0=%d", affine_size); - } - } - - fprintf(pp, " 1=%e", eps); - fprintf(pp, " 2=%d", affine); - - if (affine) { - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "LeakyRelu") { - float alpha = get_node_attr_f(node, "alpha", 0.01f); - fprintf(pp, " 0=%e", alpha); - } else if (op == "Threshold") { - float threshold = get_node_attr_f(node, "threshold", 0.f); - fprintf(pp, " 0=%e", threshold); - } else if (op == "Log") { - int op_type = 8; - fprintf(pp, " 0=%d", op_type); - } else if (op == "LRN") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 
0.5f); - float bias = get_node_attr_f(node, "bias", 1.f); - int size = get_node_attr_i(node, "size", 1); - - int norm_region = 0; - - fprintf(pp, " 0=%d", norm_region); - fprintf(pp, " 1=%d", size); - fprintf(pp, " 2=%e", alpha); - fprintf(pp, " 3=%e", beta); - fprintf(pp, " 4=%e", bias); - } else if (op == "LSTM") { - const onnx::TensorProto& W = weights[node.input(1)]; - const onnx::TensorProto& R = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - - int hidden_size = get_node_attr_i(node, "hidden_size", 0); - std::string direction = get_node_attr_s(node, "direction"); - - int direction_type = 0; - if (direction == "forward") { - direction_type = 0; - } else if (direction == "reverse") { - direction_type = 1; - } else if (direction == "bidirectional") { - direction_type = 2; - } - - int weight_data_size = get_tensor_proto_data_size(W); - - fprintf(pp, " 0=%d", hidden_size); - fprintf(pp, " 1=%d", weight_data_size); - fprintf(pp, " 2=%d", direction_type); - - int num_directions = direction_type == 2 ? 2 : 1; - - int quantize_tag = 0; - - // reorder num_directions-IOFG-hidden-size to - // num_directions-IFOG-hidden-size - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions; - const float* wptr = - W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); - - const float* iptr = wptr; - const float* optr = wptr + weight_data_size_g; - const float* fptr = wptr + weight_data_size_g * 2; - const float* gptr = wptr + weight_data_size_g * 3; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - iptr += weight_data_size_g * 4; - optr += weight_data_size_g * 4; - fptr += weight_data_size_g * 4; - gptr += weight_data_size_g * 4; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - } - } - - // reduce xc and hc bias - // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions; - const float* xcbptr = - B.has_raw_data() ? 
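// [Editorial sketch, not part of the patch] The LSTM gate shuffle above as a
// standalone helper: ONNX orders the gates input, output, forget, cell
// (IOFG); ncnn expects IFOG, so the O and F blocks swap places. Here g is
// the per-gate element count:
#include <cstdio>

static void write_ifog(FILE* bp, const float* w, int g)
{
    fwrite(w,         sizeof(float), g, bp); // I
    fwrite(w + 2 * g, sizeof(float), g, bp); // F
    fwrite(w + g,     sizeof(float), g, bp); // O
    fwrite(w + 3 * g, sizeof(float), g, bp); // G
}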
(const float*)B.raw_data().data() : B.float_data().data(); - const float* xiptr = xcbptr; - const float* xoptr = xcbptr + bias_data_size_g; - const float* xfptr = xcbptr + bias_data_size_g * 2; - const float* xgptr = xcbptr + bias_data_size_g * 3; - const float* hiptr = xcbptr + bias_data_size_g * 4; - const float* hoptr = xcbptr + bias_data_size_g * 5; - const float* hfptr = xcbptr + bias_data_size_g * 6; - const float* hgptr = xcbptr + bias_data_size_g * 7; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xfptr[j] + hfptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xoptr[j] + hoptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xgptr[j] + hgptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - - if (direction_type == 2) { - xiptr += bias_data_size_g * 8; - xoptr += bias_data_size_g * 8; - xfptr += bias_data_size_g * 8; - xgptr += bias_data_size_g * 8; - hiptr += bias_data_size_g * 8; - hoptr += bias_data_size_g * 8; - hfptr += bias_data_size_g * 8; - hgptr += bias_data_size_g * 8; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xfptr[j] + hfptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xoptr[j] + hoptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xgptr[j] + hgptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - - // reorder num_directions-IOFG-hidden-hidden to - // num_directions-IFOG-hidden-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions; - const float* rptr = - R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data(); - - const float* iptr = rptr; - const float* optr = rptr + weight_data_size_g; - const float* fptr = rptr + weight_data_size_g * 2; - const float* gptr = rptr + weight_data_size_g * 3; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - iptr += weight_data_size_g * 4; - optr += weight_data_size_g * 4; - fptr += weight_data_size_g * 4; - gptr += weight_data_size_g * 4; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - } - } - } else if (op == "MatMul") { - if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) { - // InnerProduct - const onnx::TensorProto& B = weights[node.input(1)]; - - int weight_data_size = get_tensor_proto_data_size(B); - - int num_output = B.dims(B.dims_size() - 1); - int num_input = weight_data_size / num_output; - - fprintf(pp, " 0=%d", num_output); - fprintf(pp, " 1=0"); - fprintf(pp, " 2=%d", weight_data_size); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - // reorder num_input-num_output to num_output-num_input - { - const float* bptr = - B.has_raw_data() ? 
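// [Editorial sketch, not part of the patch] The MatMul-to-InnerProduct weight
// rewrite above: the ONNX initializer B is (num_input x num_output) row-major,
// while ncnn InnerProduct stores weights as (num_output x num_input), hence
// the index swap:
#include <cstdio>

static void write_transposed(FILE* bp, const float* b, int num_input, int num_output)
{
    for (int j = 0; j < num_output; j++)
        for (int k = 0; k < num_input; k++)
        {
            float v = b[k * num_output + j];
            fwrite(&v, sizeof(float), 1, bp);
        }
}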
(const float*)B.raw_data().data() : B.float_data().data(); - - for (int j = 0; j < num_output; j++) { - for (int k = 0; k < num_input; k++) { - float vb = bptr[k * num_output + j]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - - // fwrite_tensor_proto_data(B, bp) - } else { - // default matrix multiplication - } - } else if (op == "Max") { - int op_type = 4; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Min") { - int op_type = 5; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Mul") { - int op_type = 2; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "MultiHeadAttention") { - int embed_dim = get_node_attr_i(node, "embed_dim", 0); - int num_heads = get_node_attr_i(node, "num_heads", 0); + internal_split++; + } - fprintf(pp, " 0=%d", embed_dim); - fprintf(pp, " 1=%d", num_heads); + for (int i = 0; i < node_count; i++) + { + const onnx::NodeProto& node = mutable_graph->node(i); + const std::string& op = node.op_type(); - if (node.input_size() == 5) { - const onnx::TensorProto& qkvw = weights[node.input(1)]; - const onnx::TensorProto& qkvb = weights[node.input(2)]; - const onnx::TensorProto& ow = weights[node.input(3)]; - const onnx::TensorProto& ob = weights[node.input(4)]; + // fprintf(stderr, "op = %s\n", op.c_str()); - int weight_data_size = get_tensor_proto_data_size(ow); + if (op == "noop_reducedncnn") + { + continue; + } - fprintf(pp, " 2=%d", weight_data_size); + std::string name = node.name(); + if (name.empty()) + { + name = node.output(0); + } - int quantize_tag = 0; + int input_size = node.input_size(); + int output_size = node.output_size(); - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose qw + for (int j = 0; j < (int)node.input_size(); j++) { - const float* wptr = - qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); - const float* bptr = - qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + const std::string& input_name = node.input(j); - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim * 3 + k]; - fwrite(&vb, sizeof(float), 1, bp); + // check weight + if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) + { + input_size--; } - } - fwrite(bptr, sizeof(float), embed_dim, bp); - } + if (input_name.empty()) + { + input_size--; + } - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose kw + // fprintf(stderr, " input = %s\n", input_name.c_str()); + } + /* + for (int j=0; j<(int)node.output_size(); j++) { - const float* wptr = - qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); - const float* bptr = - qkvb.has_raw_data() ? 
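// [Editorial sketch, not part of the patch] Add/Div/Max/Min/Mul above all
// share one idiom: a "with_scalar" attribute folds a scalar right-hand
// operand into the BinaryOp layer itself (1=with_scalar, 2=b) instead of
// feeding it as a second blob. The idiom as one helper:
#include <cstdio>

static void write_binaryop_scalar(FILE* pp, int op_type, int with_scalar, float b)
{
    fprintf(pp, " 0=%d", op_type);
    if (with_scalar)
    {
        fprintf(pp, " 1=%d", with_scalar);
        fprintf(pp, " 2=%e", b);
    }
}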
(const float*)qkvb.raw_data().data() : qkvb.float_data().data(); - bptr += embed_dim; + const std::string& output_name = node.output(j); + fprintf(stderr, " output = %s\n", output_name.c_str()); + } + */ - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim * 3 + k + embed_dim]; - fwrite(&vb, sizeof(float), 1, bp); + if (op == "Abs") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Acos") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Add") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "ArgMax") + { + fprintf(pp, "%-16s", "TopK"); + } + else if (op == "Asin") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Atan") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "AveragePool" || op == "MaxPool") + { + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + if (kernel_shape.size() == 1) + { + fprintf(pp, "%-16s", "Pooling1D"); + } + else + { + fprintf(pp, "%-16s", "Pooling"); } - } - - fwrite(bptr, sizeof(float), embed_dim, bp); } - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose vw + else if (op == "BatchNormalization") { - const float* wptr = - qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); - const float* bptr = - qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); - bptr += embed_dim * 2; - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2]; - fwrite(&vb, sizeof(float), 1, bp); + fprintf(pp, "%-16s", "BatchNorm"); + } + else if (op == "BiasGelu") + { + fprintf(pp, "%-16s", "BiasGelu"); + } + else if (op == "Cast") + { + fprintf(pp, "%-16s", "Noop"); + } + else if (op == "Ceil") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Clip") + { + fprintf(pp, "%-16s", "Clip"); + } + else if (op == "Concat") + { + fprintf(pp, "%-16s", "Concat"); + } + else if (op == "Constant") + { + continue; + } + else if (op == "ConstantOfShape") + { + fprintf(pp, "%-16s", "ConstantOfShape"); + } + else if (op == "Conv") + { + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + if (kernel_shape.size() == 1) + { + fprintf(pp, "%-16s", "Convolution1D"); + } + else + { + int group = get_node_attr_i(node, "group", 1); + if (group > 1) + { + fprintf(pp, "%-16s", "ConvolutionDepthWise"); + } + else + { + fprintf(pp, "%-16s", "Convolution"); + } } - } - - fwrite(bptr, sizeof(float), embed_dim, bp); } - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose ow + else if (op == "ConvTranspose") { - const float* wptr = - ow.has_raw_data() ? 
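// [Editorial sketch, not part of the patch] My reading of the fused
// MultiHeadAttention case above: the packed weight is
// (embed_dim x 3*embed_dim) with the Q, K and V blocks side by side, and each
// block is copied out column-band by column-band (the fused bias splits at
// the same embed_dim offsets):
#include <cstdio>

static void write_qkv_slice(FILE* bp, const float* qkvw, int embed_dim, int which /*0=Q,1=K,2=V*/)
{
    for (int j = 0; j < embed_dim; j++)
        for (int k = 0; k < embed_dim; k++)
        {
            float v = qkvw[j * embed_dim * 3 + which * embed_dim + k];
            fwrite(&v, sizeof(float), 1, bp);
        }
}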
(const float*)ow.raw_data().data() : ow.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); + int group = get_node_attr_i(node, "group", 1); + if (group > 1) + { + fprintf(pp, "%-16s", "DeconvolutionDepthWise"); + } + else + { + fprintf(pp, "%-16s", "Deconvolution"); } - } } - fwrite_tensor_proto_data(ob, bp); - } else { - const onnx::TensorProto& qw = weights[node.input(3)]; - const onnx::TensorProto& qb = weights[node.input(4)]; - const onnx::TensorProto& kw = weights[node.input(5)]; - const onnx::TensorProto& kb = weights[node.input(6)]; - const onnx::TensorProto& vw = weights[node.input(7)]; - const onnx::TensorProto& vb = weights[node.input(8)]; - const onnx::TensorProto& ow = weights[node.input(9)]; - const onnx::TensorProto& ob = weights[node.input(10)]; - - int weight_data_size = get_tensor_proto_data_size(qw); - - fprintf(pp, " 2=%d", weight_data_size); - - int quantize_tag = 0; - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose qw + else if (op == "Cos") { - const float* wptr = - qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Crop") + { + fprintf(pp, "%-16s", "Crop"); + } + else if (op == "DepthToSpace") + { + fprintf(pp, "%-16s", "PixelShuffle"); + } + else if (op == "DetectionOutput") + { + fprintf(pp, "%-16s", "DetectionOutput"); + } + else if (op == "Div") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Dropout") + { + fprintf(pp, "%-16s", "Dropout"); + output_size = 1; + } + else if (op == "Elu") + { + fprintf(pp, "%-16s", "ELU"); + } + else if (op == "EmbedLayerNormalization") + { + fprintf(pp, "%-16s", "EmbedLayerNormalization"); + } + else if (op == "Equal") + { + fprintf(pp, "%-16s", "Compare"); + } + else if (op == "Exp") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Expand") + { + fprintf(pp, "%-16s", "Expand"); + } + else if (op == "Flatten") + { + fprintf(pp, "%-16s", "Flatten"); + } + else if (op == "Floor") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Gather") + { + fprintf(pp, "%-16s", "Gather"); + } + else if (op == "Gelu") + { + fprintf(pp, "%-16s", "GELU"); + } + else if (op == "Gemm") + { + float alpha = get_node_attr_f(node, "alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 1.f); + int transA = get_node_attr_i(node, "transA", 0); + int transB = get_node_attr_i(node, "transB", 0); + + if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) + { + // InnerProduct-like A * B + C + fprintf(pp, "%-16s", "InnerProduct"); + } + else + { + fprintf(pp, "%-16s", "Gemm"); + } + } + else if (op == "GlobalAveragePool") + { + fprintf(pp, "%-16s", "Pooling"); + } + else if (op == "GlobalMaxPool") + { + fprintf(pp, "%-16s", "Pooling"); + } + else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || + op == "adaptive_max_pool2d") + { + fprintf(pp, "%-16s", "Pooling"); + } + else if (op == "GroupNorm") + { + fprintf(pp, "%-16s", "GroupNorm"); + } + else if (op == "GRU") + { + fprintf(pp, "%-16s", "GRU"); + } + else if (op == "HardSigmoid") + { + fprintf(pp, "%-16s", "HardSigmoid"); + } + else if (op == "HardSwish") + { + fprintf(pp, "%-16s", "HardSwish"); + } + else if (op == "ImageScaler") + { + fprintf(pp, "%-16s", "Scale"); + 
} + else if (op == "InstanceNormalization") + { + fprintf(pp, "%-16s", "InstanceNorm"); + } + else if (op == "LayerNorm") + { + fprintf(pp, "%-16s", "LayerNorm"); + } + else if (op == "LeakyRelu") + { + fprintf(pp, "%-16s", "ReLU"); + } + else if (op == "Threshold") + { + fprintf(pp, "%-16s", "Threshold"); + } + else if (op == "Log") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "LRN") + { + fprintf(pp, "%-16s", "LRN"); + } + else if (op == "LSTM") + { + fprintf(pp, "%-16s", "LSTM"); + } + else if (op == "MatMul") + { + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) + { + fprintf(pp, "%-16s", "InnerProduct"); + } + else + { + fprintf(pp, "%-16s", "Gemm"); + } + } + else if (op == "Max") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Min") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Mul") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "MultiHeadAttention") + { + fprintf(pp, "%-16s", "MultiHeadAttention"); + } + else if (op == "Neg") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "NonMaxSuppression") + { + fprintf(pp, "%-16s", "NonMaxSuppression"); + } + else if (op == "Normalize") + { + fprintf(pp, "%-16s", "Normalize"); + } + else if (op == "Pad") + { + fprintf(pp, "%-16s", "Padding"); + } + else if (op == "PixelShuffle") + { + fprintf(pp, "%-16s", "PixelShuffle"); + } + else if (op == "Pow") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "PriorBox") + { + fprintf(pp, "%-16s", "PriorBox"); + } + else if (op == "PRelu") + { + fprintf(pp, "%-16s", "PReLU"); + } + else if (op == "Range") + { + fprintf(pp, "%-16s", "Range"); + } + else if (op == "Reciprocal") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || + op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || + op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") + { + fprintf(pp, "%-16s", "Reduction"); + } + else if (op == "Relu") + { + fprintf(pp, "%-16s", "ReLU"); + } + else if (op == "Reorg") + { + fprintf(pp, "%-16s", "Reorg"); + } + else if (op == "Reshape") + { + fprintf(pp, "%-16s", "Reshape"); + } + else if (op == "RNN") + { + fprintf(pp, "%-16s", "RNN"); + } + else if (op == "RDiv") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "RSub") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "RoiAlign") + { + fprintf(pp, "%-16s", "ROIAlign"); + } + else if (op == "ScatterND") + { + fprintf(pp, "%-16s", "ScatterND"); + } + else if (op == "Shape") + { + fprintf(pp, "%-16s", "Shape"); + } + else if (op == "ShuffleChannel") + { + fprintf(pp, "%-16s", "ShuffleChannel"); + } + else if (op == "Sigmoid") + { + fprintf(pp, "%-16s", "Sigmoid"); + } + else if (op == "Sin") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "SkipLayerNormalization") + { + fprintf(pp, "%-16s", "SkipLayerNormalization"); + } + else if (op == "Slice") + { + std::vector ends; + std::vector steps; + bool use_crop = true; + + if (node.input_size() == 1) + { + ends = get_node_attr_ai(node, "ends"); + steps = get_node_attr_ai(node, "steps"); // TODO + } + else + { + ends = get_node_attr_from_input_ai(weights[node.input(2)]); + if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); + } + + // assert step == 1 + for (int i = 0; i < (int)steps.size(); i++) + { + if (steps[i] != 1 && steps[i] < ends[i]) + { + use_crop = false; + break; + } + } + + if 
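// [Editorial sketch, not part of the patch] The step check above decides
// between two ncnn layers: a Slice whose steps are effectively 1 maps to the
// cheap Crop layer, anything else falls back to the general TensorSlice. A
// compact restatement of the patch's exact guard, including its
// steps[i] < ends[i] exemption:
#include <vector>

static bool slice_maps_to_crop(const std::vector<int>& steps, const std::vector<int>& ends)
{
    for (size_t i = 0; i < steps.size(); i++)
        if (steps[i] != 1 && steps[i] < ends[i])
            return false;
    return true;
}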
(use_crop) + { + fprintf(pp, "%-16s", "Crop"); + } + else + { + fprintf(pp, "%-16s", "TensorSlice"); + } + } + else if (op == "Softmax") + { + fprintf(pp, "%-16s", "Softmax"); + } + else if (op == "Softplus") + { + fprintf(pp, "%-16s", "Softplus"); + } + else if (op == "Split") + { + fprintf(pp, "%-16s", "Slice"); + } + else if (op == "Sqrt") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Squeeze") + { + std::vector axes = get_node_attr_ai(node, "axes"); + // fprintf(stderr, "axes[0]: %d\n",axes[0]); + if (axes[0] == 0) + { + fprintf(pp, "%-16s", "Noop"); + } + else + { + fprintf(pp, "%-16s", "Squeeze"); + } + } + else if (op == "Sub") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Sum") + { + fprintf(pp, "%-16s", "Eltwise"); + } + else if (op == "Swish") + { + fprintf(pp, "%-16s", "Swish"); + } + else if (op == "Tan") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Tanh") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Tile") + { + fprintf(pp, "%-16s", "TileOnnx"); + } + else if (op == "TopK") + { + fprintf(pp, "%-16s", "TopK"); + } + else if (op == "Transpose") + { + fprintf(pp, "%-16s", "Permute"); + } + else if (op == "Upsample" || op == "Resize") + { + fprintf(pp, "%-16s", "Interp"); + } + else if (op == "Unsqueeze") + { + std::vector axes = get_node_attr_ai(node, "axes"); + // fprintf(stderr, "axes[0]: %d\n",axes[0]); + if (axes[0] == 0) + { + fprintf(pp, "%-16s", "Noop"); + } + else + { + fprintf(pp, "%-16s", "ExpandDims"); + } + } + else if (op == "Where") + { + fprintf(pp, "%-16s", "Where"); + } + else if (op == "Yolov3DetectionOutput") + { + fprintf(pp, "%-16s", "Yolov3DetectionOutput"); + } + else + { + // TODO + fprintf(stderr, "%s not supported yet!\n", op.c_str()); + fprintf(pp, "%-16s", op.c_str()); + } + + fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size); + + for (int j = 0; j < (int)node.input_size(); j++) + { + std::string input_name = node.input(j); + + // check weight + if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) + { + continue; + } + + if (input_name.empty()) + { + continue; + } + + if (split_node_reference.find(input_name) != split_node_reference.end()) + { + int refidx = split_node_reference[input_name] - 1; + split_node_reference[input_name] = refidx; + + char splitsuffix[256]; + sprintf(splitsuffix, "_splitncnn_%d", refidx); + input_name = input_name + splitsuffix; + } + + fprintf(pp, " %s", input_name.c_str()); + } + + for (int j = 0; j < output_size; j++) + { + const std::string& output_name = node.output(j); + + fprintf(pp, " %s", output_name.c_str()); + } + + if (op == "Abs") + { + int op_type = 0; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Acos") + { + int op_type = 13; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Add") + { + int op_type = 0; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "ArgMax") + { + int axis = get_node_attr_i(node, "axis"); + int keepdims = get_node_attr_i(node, "keepdims"); + fprintf(pp, " 0=%d", axis - 1); + fprintf(pp, " 3=%d", keepdims); + } + else if (op == "Asin") + { + int op_type = 12; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Atan") + { + int op_type = 14; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "AveragePool" || op == "MaxPool") + { + std::string auto_pad = 
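// [Editorial sketch, not part of the patch] The recurring "axis - 1" above
// (ArgMax, and later Concat and Gather) converts an ONNX NCHW axis to an
// ncnn axis: ncnn blobs carry no batch dimension, so every axis shifts down
// by one. Trivially:
static int onnx_axis_to_ncnn(int onnx_axis)
{
    return onnx_axis - 1; // assumes onnx_axis >= 1, i.e. not the batch axis
}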
get_node_attr_s(node, "auto_pad"); + int ceil_mode = get_node_attr_i(node, "ceil_mode", 0); + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + std::vector strides = get_node_attr_ai(node, "strides"); + std::vector pads = get_node_attr_ai(node, "pads"); + + int pool = op == "AveragePool" ? 1 : 0; + int pad_mode = 1; + + if (auto_pad == "SAME_UPPER") + { + pad_mode = 2; + } + else if (auto_pad == "SAME_LOWER") + { + pad_mode = 3; + } + + if (ceil_mode == 1) + { + pad_mode = 0; + } + + fprintf(pp, " 0=%d", pool); + + if (kernel_shape.size() == 1) + { + fprintf(pp, " 1=%d", kernel_shape[0]); + } + else if (kernel_shape.size() == 2) + { + fprintf(pp, " 1=%d", kernel_shape[1]); + fprintf(pp, " 11=%d", kernel_shape[0]); + } + + if (strides.size() == 1) + { + fprintf(pp, " 2=%d", strides[0]); + } + else if (strides.size() == 2) + { + fprintf(pp, " 2=%d", strides[1]); + fprintf(pp, " 12=%d", strides[0]); + } + + if (pads.size() == 1) + { + fprintf(pp, " 3=%d", pads[0]); + } + else if (pads.size() == 2) + { + fprintf(pp, " 3=%d", pads[1]); + fprintf(pp, " 13=%d", pads[0]); + } + else if (pads.size() == 4) + { + fprintf(pp, " 3=%d", pads[1]); + fprintf(pp, " 13=%d", pads[0]); + fprintf(pp, " 14=%d", pads[3]); + fprintf(pp, " 15=%d", pads[2]); + } + + fprintf(pp, " 5=%d", pad_mode); + + if (op == "AveragePool") + { + int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0); + fprintf(pp, " 6=%d", avgpool_count_include_pad); + } + } + else if (op == "BatchNormalization") + { + float epsilon = get_node_attr_f(node, "epsilon", 1e-5f); + + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + const onnx::TensorProto& mean = weights[node.input(3)]; + const onnx::TensorProto& var = weights[node.input(4)]; + + int channels = get_tensor_proto_data_size(scale); + + fprintf(pp, " 0=%d", channels); + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(mean, bp); + // apply epsilon to var + { + const float* v = + var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data(); + + for (int j = 0; j < channels; j++) + { + float ve = v[j] + epsilon; + fwrite(&ve, sizeof(float), 1, bp); + } + } + fwrite_tensor_proto_data(B, bp); + } + else if (op == "BiasGelu") + { + const onnx::TensorProto& B = weights[node.input(1)]; + + fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B, bp); + } + else if (op == "Ceil") + { + int op_type = 3; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Clip") + { + float min; + float max; + if (node.input_size() == 1) + { + min = get_node_attr_f(node, "min", -FLT_MAX); + max = get_node_attr_f(node, "max", FLT_MAX); + } + else + { + min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input(weights[node.input(1)]) : -FLT_MAX; + max = weights.find(node.input(2)) != weights.end() ? 
get_node_attr_from_input(weights[node.input(2)]) : FLT_MAX; + } + + fprintf(pp, " 0=%e", min); + fprintf(pp, " 1=%e", max); + } + else if (op == "Concat") + { + int axis = get_node_attr_i(node, "axis", 1); + fprintf(pp, " 0=%d", axis - 1); + } + else if (op == "Constant") + { + // never reach here + } + else if (op == "ConstantOfShape") + { + float value = 0.f; + value = get_node_attr_f(node, "value", 0.f); + fprintf(pp, " 0=%f", value); + } + else if (op == "Conv") + { + const onnx::TensorProto& W = weights[node.input(1)]; + + int num_filter = W.dims(0); + int has_bias = node.input_size() == 3 ? 1 : 0; + + std::string auto_pad = get_node_attr_s(node, "auto_pad"); + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + std::vector dilations = get_node_attr_ai(node, "dilations"); + std::vector strides = get_node_attr_ai(node, "strides"); + std::vector pads = get_node_attr_ai(node, "pads"); + int group = get_node_attr_i(node, "group", 1); + + fprintf(pp, " 0=%d", num_filter); + + if (kernel_shape.size() == 1) + { + fprintf(pp, " 1=%d", kernel_shape[0]); + } + else if (kernel_shape.size() == 2) + { + fprintf(pp, " 1=%d", kernel_shape[1]); + fprintf(pp, " 11=%d", kernel_shape[0]); + } + + if (dilations.size() == 1) + { + fprintf(pp, " 2=%d", dilations[0]); + } + else if (dilations.size() == 2) + { + fprintf(pp, " 2=%d", dilations[1]); + fprintf(pp, " 12=%d", dilations[0]); + } + + if (strides.size() == 1) + { + fprintf(pp, " 3=%d", strides[0]); + } + else if (strides.size() == 2) + { + fprintf(pp, " 3=%d", strides[1]); + fprintf(pp, " 13=%d", strides[0]); + } + + if (auto_pad == "SAME_UPPER") + { + fprintf(pp, " 4=-233"); + } + else if (auto_pad == "SAME_LOWER") + { + fprintf(pp, " 4=-234"); + } + else + { + if (pads.size() == 1) + { + fprintf(pp, " 4=%d", pads[0]); + } + else if (pads.size() == 2) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + } + else if (pads.size() == 4) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + fprintf(pp, " 15=%d", pads[3]); + fprintf(pp, " 16=%d", pads[2]); + } + } + + fprintf(pp, " 5=%d", has_bias); + + fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); + + if (group > 1) + { + fprintf(pp, " 7=%d", group); + } + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(W, bp); + + if (has_bias) + { + const onnx::TensorProto& B = weights[node.input(2)]; + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "ConvTranspose") + { + const onnx::TensorProto& W = weights[node.input(1)]; + + int has_bias = node.input_size() == 3 ? 
1 : 0; + + std::string auto_pad = get_node_attr_s(node, "auto_pad"); + std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape"); + std::vector<int> dilations = get_node_attr_ai(node, "dilations"); + std::vector<int> strides = get_node_attr_ai(node, "strides"); + std::vector<int> output_padding = get_node_attr_ai(node, "output_padding"); + std::vector<int> output_shape = get_node_attr_ai(node, "output_shape"); + std::vector<int> pads = get_node_attr_ai(node, "pads"); + int group = get_node_attr_i(node, "group", 1); + int num_filter = W.dims(1) * group; + + fprintf(pp, " 0=%d", num_filter); + + if (kernel_shape.size() == 1) + { + fprintf(pp, " 1=%d", kernel_shape[0]); + } + else if (kernel_shape.size() == 2) + { + fprintf(pp, " 1=%d", kernel_shape[1]); + fprintf(pp, " 11=%d", kernel_shape[0]); + } + + if (dilations.size() == 1) + { + fprintf(pp, " 2=%d", dilations[0]); + } + else if (dilations.size() == 2) + { + fprintf(pp, " 2=%d", dilations[1]); + fprintf(pp, " 12=%d", dilations[0]); + } + + if (strides.size() == 1) + { + fprintf(pp, " 3=%d", strides[0]); + } + else if (strides.size() == 2) + { + fprintf(pp, " 3=%d", strides[1]); + fprintf(pp, " 13=%d", strides[0]); + } + + if (auto_pad == "SAME_UPPER") + { + fprintf(pp, " 4=-233"); + } + else if (auto_pad == "SAME_LOWER") + { + fprintf(pp, " 4=-234"); + } + else + { + if (pads.size() == 1) + { + fprintf(pp, " 4=%d", pads[0]); + } + else if (pads.size() == 2) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + } + else if (pads.size() == 4) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + fprintf(pp, " 15=%d", pads[3]); + fprintf(pp, " 16=%d", pads[2]); + } + } + + if (output_padding.size() == 1) + { + fprintf(pp, " 18=%d", output_padding[0]); + } + else if (output_padding.size() == 2) + { + fprintf(pp, " 18=%d", output_padding[1]); + fprintf(pp, " 19=%d", output_padding[0]); + } + + if (output_shape.size() == 1) + { + fprintf(pp, " 20=%d", output_shape[0]); + } + else if (output_shape.size() == 2) + { + fprintf(pp, " 20=%d", output_shape[1]); + fprintf(pp, " 21=%d", output_shape[0]); + } + + fprintf(pp, " 5=%d", has_bias); + + fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); + + if (group > 1) + { + fprintf(pp, " 7=%d", group); + } + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int maxk = 0; + if (kernel_shape.size() == 2) + { + maxk = kernel_shape[1] * kernel_shape[0]; + } + else + { + maxk = kernel_shape[0] * kernel_shape[0]; + } + int weight_data_size = get_tensor_proto_data_size(W); + const float* weight_data = 0; + if (W.has_raw_data()) + { + weight_data = (const float*)W.raw_data().data(); + } + else if (W.data_type() == 1) + { + weight_data = W.float_data().data(); + } + for (int g = 0; g < group; g++) + { + // reorder weight from inch-outch to outch-inch + int num_filter_g = num_filter / group; + int num_input = weight_data_size / maxk / num_filter_g / group; + const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input; + for (int k = 0; k < num_filter_g; k++) + { + for (int j = 0; j < num_input; j++) + { + fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp); + } + } + } + + if (has_bias) + { + const onnx::TensorProto& B = weights[node.input(2)]; + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "Cos") + { + int op_type = 10; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Crop") + { + auto starts = get_node_attr_ai(node, "starts"); + fprintf(pp, " -23309=%zu", starts.size()); + for (size_t j = 0; j <
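+ /* The per-group loop above turns ONNX's (in_ch, out_ch/group, kh, kw) deconvolution weight into ncnn's (out_ch/group, in_ch, kh, kw): output block (k, j) is copied from source offset (j * num_filter_g + k) * maxk. Index sketch, assuming group=1, num_input=2, num_filter_g=2, maxk=9: block (k=1, j=0) is read starting at float offset (0 * 2 + 1) * 9 = 9. */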
starts.size(); ++j) + { + fprintf(pp, ",%i", starts[j]); + } + auto ends = get_node_attr_ai(node, "ends"); + fprintf(pp, " -23310=%zu", ends.size()); + for (size_t j = 0; j < ends.size(); ++j) + { + fprintf(pp, ",%i", ends[j]); + } + auto axis = get_node_attr_ai(node, "axis"); + fprintf(pp, " -23311=%zu", axis.size()); + for (size_t j = 0; j < axis.size(); ++j) + { + fprintf(pp, ",%i", axis[j]); + } + } + else if (op == "DepthToSpace") + { + // pixelshuffle + int scale_factor = get_node_attr_i(node, "blocksize", 1); + std::string mode = get_node_attr_s(node, "mode"); + fprintf(pp, " 0=%d", scale_factor); + if (mode == "CRD") + { + fprintf(pp, " 1=0"); + } + else if (mode == "DCR") + { + fprintf(pp, " 1=1"); + } + } + else if (op == "DetectionOutput") + { + float score_threshold = get_node_attr_f(node, "score_threshold"); + float nms_threshold = get_node_attr_f(node, "nms_threshold"); + int nms_top_k = get_node_attr_i(node, "nms_top_k"); + int keep_top_k = get_node_attr_i(node, "keep_top_k"); + int num_class = get_node_attr_i(node, "num_class"); + std::vector vars = get_node_attr_af(node, "vars"); + fprintf(pp, " 0=%d", num_class); + fprintf(pp, " 1=%f", nms_threshold); + fprintf(pp, " 2=%d", nms_top_k); + fprintf(pp, " 3=%d", keep_top_k); + fprintf(pp, " 4=%f", score_threshold); + fprintf(pp, " 5=%f", vars[0]); + fprintf(pp, " 6=%f", vars[1]); + fprintf(pp, " 7=%f", vars[2]); + fprintf(pp, " 8=%f", vars[3]); + } + else if (op == "Div") + { + int op_type = 3; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Dropout") + { + // no-op + } + else if (op == "Elu") + { + float alpha = get_node_attr_f(node, "alpha", 1.f); + fprintf(pp, " 0=%e", alpha); + } + else if (op == "EmbedLayerNormalization") + { + const onnx::TensorProto& words = weights[node.input(2)]; + const onnx::TensorProto& positions = weights[node.input(3)]; + const onnx::TensorProto& W = weights[node.input(5)]; + const onnx::TensorProto& B = weights[node.input(6)]; + + fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); + fprintf(pp, " 1=%d", get_tensor_proto_data_size(words)); + fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(words, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(positions, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(W, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B, bp); + } + else if (op == "Equal") + { + int op_type = 0; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Exp") + { + int op_type = 7; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Flatten") + { + int axis = get_node_attr_i(node, "axis", 1); + if (axis != 1) + { + fprintf(stderr, "Unsupported Flatten axis %d!\n", axis); + } + } + else if (op == "Floor") + { + int op_type = 2; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Gather") + { + if (weights[node.input(1)].dims_size() > 1) + { + fprintf(stderr, "Unsupported indice dims > 1"); + } + int axis = get_node_attr_i(node, "axis", 1) - 1; + if (axis < 0) + { + fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1); + } + fprintf(pp, " 0=%d", axis); + } + else if (op == "Gelu") + { + fprintf(pp, " 0=1"); + } + else if (op == "Gemm") + { + float alpha = 
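+ /* The bare int written ahead of each weight blob (quantize_tag here and throughout) is ncnn's weight tag field; 0 marks the tensor that follows as raw fp32, which is all this converter ever emits. */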
get_node_attr_f(node, "alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 1.f); + int transA = get_node_attr_i(node, "transA", 0); + int transB = get_node_attr_i(node, "transB", 0); + + if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) + { + // InnerProduct-like A * B + C + const onnx::TensorProto& B = weights[node.input(1)]; + // B has transposed. + int num_output = B.dims(0); + fprintf(pp, " 0=%d", num_output); + if (node.input_size() == 3) + { + fprintf(pp, " 1=1"); + } + else + { + fprintf(pp, " 1=0"); + } + fprintf(pp, " 2=%d", get_tensor_proto_data_size(B)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + fwrite_tensor_proto_data(B, bp); + if (node.input_size() == 3) + { + const onnx::TensorProto& C = weights[node.input(2)]; + fwrite_tensor_proto_data(C, bp); + } + } + else + { + // gemm + fprintf(pp, " 0=%e", alpha); + fprintf(pp, " 1=%e", beta); + fprintf(pp, " 2=%d", transA); + fprintf(pp, " 3=%d", transB); + } + } + else if (op == "GlobalAveragePool") + { + int pool = 1; + int global_pool = 1; + + fprintf(pp, " 0=%d", pool); + fprintf(pp, " 4=%d", global_pool); + } + else if (op == "GlobalMaxPool") + { + int pool = 0; + int global_pool = 1; + + fprintf(pp, " 0=%d", pool); + fprintf(pp, " 4=%d", global_pool); + } + else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || + op == "adaptive_max_pool2d") + { + int pool = 0; + if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d") + { + pool = 1; + } + int adaptive_pooling = 1; + const onnx::TensorProto& out_shape_tp = weights[node.input(1)]; + std::vector out_shape = get_node_attr_from_input_ai(out_shape_tp); + + fprintf(pp, " 0=%d", pool); + fprintf(pp, " 7=%d", adaptive_pooling); + if (out_shape.size() == 1) + { + fprintf(pp, " 8=%d", out_shape[0]); + } + else if (out_shape.size() == 2) + { + // out_w + fprintf(pp, " 8=%d", out_shape[1]); + // out_h + fprintf(pp, " 18=%d", out_shape[0]); + } + } + else if (op == "GroupNorm") + { + int groups = get_node_attr_i(node, "groups", 1); + int channels = get_node_attr_i(node, "channels", 1); + float eps = get_node_attr_f(node, "epsilon", 1e-5f); + int affine = get_node_attr_i(node, "affine", 1); + + if (affine) + { + // discard affine-less S=1 B=0 + std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); + if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && + affine_B[0] == 0.f) + { + affine = 0; + } + else + { + affine = 0; + { + for (int j = 0; j < channels; j++) + { + if (affine_S[j] != 1.f || affine_B[j] != 0.f) + { + affine = 1; + break; + } + } + } + } + } + + fprintf(pp, " 0=%d", groups); + fprintf(pp, " 1=%d", channels); + fprintf(pp, " 2=%e", eps); + fprintf(pp, " 3=%d", affine); + if (affine) + { + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "GRU") + { + const onnx::TensorProto& W = weights[node.input(1)]; + const onnx::TensorProto& R = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + + int hidden_size = get_node_attr_i(node, "hidden_size", 0); + std::string direction = get_node_attr_s(node, "direction"); + + int direction_type = 0; + if (direction == "forward") + { + direction_type = 0; + } + else if (direction == "reverse") + { + direction_type = 1; + } + else if (direction == "bidirectional") + { + 
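+ /* Only the InnerProduct-shaped Gemm above (alpha=1, beta=1, transA=0, transB=1, i.e. y = x * B^T + c with B stored pre-transposed) maps onto ncnn's fully-connected layer; a PyTorch nn.Linear(in, out), for example, exports exactly this pattern with B of shape (out, in). Anything else falls through to the generic Gemm parameters. */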
direction_type = 2; + } + + int weight_data_size = get_tensor_proto_data_size(W); + + fprintf(pp, " 0=%d", hidden_size); + fprintf(pp, " 1=%d", weight_data_size); + fprintf(pp, " 2=%d", direction_type); + + int num_directions = direction_type == 2 ? 2 : 1; + + int quantize_tag = 0; + + // reorder num_directions-URN-hidden-size to + // num_directions-RUN-hidden-size + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions; + const float* wptr = + W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); + + const float* uptr = wptr; + const float* rptr = wptr + weight_data_size_g; + const float* nptr = wptr + weight_data_size_g * 2; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + uptr += weight_data_size_g * 3; + rptr += weight_data_size_g * 3; + nptr += weight_data_size_g * 3; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + } + } + + // reduce U and R bias except N + // reorder num_directions-URN-hidden to num_directions-RUN-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions; + const float* bptr = + B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); + const float* wuptr = bptr; + const float* wrptr = bptr + bias_data_size_g; + const float* wnptr = bptr + bias_data_size_g * 2; + const float* buptr = bptr + bias_data_size_g * 3; + const float* brptr = bptr + bias_data_size_g * 4; + const float* bnptr = bptr + bias_data_size_g * 5; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wrptr[j] + brptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wuptr[j] + buptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + fwrite(wnptr, sizeof(float), bias_data_size_g, bp); + fwrite(bnptr, sizeof(float), bias_data_size_g, bp); + + if (direction_type == 2) + { + wuptr += bias_data_size_g * 6; + wrptr += bias_data_size_g * 6; + wnptr += bias_data_size_g * 6; + buptr += bias_data_size_g * 6; + brptr += bias_data_size_g * 6; + bnptr += bias_data_size_g * 6; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wrptr[j] + brptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wuptr[j] + buptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + fwrite(wnptr, sizeof(float), bias_data_size_g, bp); + fwrite(bnptr, sizeof(float), bias_data_size_g, bp); + } + } + + // reorder num_directions-URN-hidden-hidden to + // num_directions-RUN-hidden-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions; + const float* Rptr = + R.has_raw_data() ? 
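+ /* ONNX packs GRU gates in update-reset-hidden order (the U-R-N of the comments above) while ncnn consumes reset-update-hidden, so each direction's block is re-emitted with the first two gates swapped. The update and reset biases can be pre-summed (Wb + Rb) because those gates add both terms before their sigmoid, but the hidden gate's recurrent bias sits inside the reset multiplication, so Wbn and Rbn must stay separate. */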
(const float*)R.raw_data().data() : R.float_data().data(); + + const float* uptr = Rptr; + const float* rptr = Rptr + weight_data_size_g; + const float* nptr = Rptr + weight_data_size_g * 2; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + uptr += weight_data_size_g * 3; + rptr += weight_data_size_g * 3; + nptr += weight_data_size_g * 3; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + } + } + } + else if (op == "HardSigmoid") + { + float alpha = get_node_attr_f(node, "alpha", 0.2f); + float beta = get_node_attr_f(node, "beta", 0.5f); + + fprintf(pp, " 0=%e", alpha); + fprintf(pp, " 1=%e", beta); + } + else if (op == "HardSwish") + { + float alpha = get_node_attr_f(node, "alpha", 0.2f); + float beta = get_node_attr_f(node, "beta", 0.5f); + + fprintf(pp, " 0=%e", alpha); + fprintf(pp, " 1=%e", beta); + } + else if (op == "ImageScaler") + { + std::vector bias = get_node_attr_af(node, "bias"); + float scale = get_node_attr_f(node, "scale", 1.f); + + int channels = (int)bias.size(); + + fprintf(pp, " 0=%d", channels); + fprintf(pp, " 1=1"); + + for (int j = 0; j < channels; j++) + { + fwrite(&scale, sizeof(float), 1, bp); + } + fwrite(&bias[0], sizeof(float), channels, bp); + } + else if (op == "InstanceNormalization") + { + float eps = get_node_attr_f(node, "epsilon", 1e-5f); + + // discard affine-less S=1 B=0 + std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); + int channels = (int)affine_S.size(); + int affine = 0; + { + for (int j = 0; j < channels; j++) + { + if (affine_S[j] != 1.f || affine_B[j] != 0.f) + { + affine = 1; + break; + } + } + } + + fprintf(pp, " 0=%d", channels); + fprintf(pp, " 1=%e", eps); + fprintf(pp, " 2=%d", affine); + if (affine) + { + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "LayerNorm") + { + float eps = get_node_attr_f(node, "epsilon", 1e-5f); + int affine = get_node_attr_i(node, "affine", 1); + + if (affine) + { + // discard affine-less S=1 B=0 + std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); + int affine_size = (int)affine_S.size(); + affine = 0; + { + for (int j = 0; j < affine_size; j++) + { + if (affine_S[j] != 1.f || affine_B[j] != 0.f) + { + affine = 1; + break; + } + } + } + + if (affine) + { + fprintf(pp, " 0=%d", affine_size); + } + } + + fprintf(pp, " 1=%e", eps); + fprintf(pp, " 2=%d", affine); + + if (affine) + { + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "LeakyRelu") + { + float alpha = get_node_attr_f(node, "alpha", 0.01f); + fprintf(pp, " 0=%e", alpha); + } + else if (op == "Threshold") + { + float threshold = get_node_attr_f(node, "threshold", 0.f); + fprintf(pp, " 0=%e", threshold); + } + else if (op == "Log") + { + int op_type = 8; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "LRN") + { + float alpha = get_node_attr_f(node, 
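+ /* InstanceNormalization and LayerNorm (like GroupNorm above) probe the exported scale/bias constants: when every element is exactly S=1, B=0 the affine step is dropped and no weights are written; otherwise affine=1 and both tensors follow in the .bin. */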
"alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 0.5f); + float bias = get_node_attr_f(node, "bias", 1.f); + int size = get_node_attr_i(node, "size", 1); + + int norm_region = 0; + + fprintf(pp, " 0=%d", norm_region); + fprintf(pp, " 1=%d", size); + fprintf(pp, " 2=%e", alpha); + fprintf(pp, " 3=%e", beta); + fprintf(pp, " 4=%e", bias); + } + else if (op == "LSTM") + { + const onnx::TensorProto& W = weights[node.input(1)]; + const onnx::TensorProto& R = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + + int hidden_size = get_node_attr_i(node, "hidden_size", 0); + std::string direction = get_node_attr_s(node, "direction"); + + int direction_type = 0; + if (direction == "forward") + { + direction_type = 0; + } + else if (direction == "reverse") + { + direction_type = 1; + } + else if (direction == "bidirectional") + { + direction_type = 2; + } + + int weight_data_size = get_tensor_proto_data_size(W); + + fprintf(pp, " 0=%d", hidden_size); + fprintf(pp, " 1=%d", weight_data_size); + fprintf(pp, " 2=%d", direction_type); + + int num_directions = direction_type == 2 ? 2 : 1; + + int quantize_tag = 0; + + // reorder num_directions-IOFG-hidden-size to + // num_directions-IFOG-hidden-size + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions; + const float* wptr = + W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); + + const float* iptr = wptr; + const float* optr = wptr + weight_data_size_g; + const float* fptr = wptr + weight_data_size_g * 2; + const float* gptr = wptr + weight_data_size_g * 3; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + iptr += weight_data_size_g * 4; + optr += weight_data_size_g * 4; + fptr += weight_data_size_g * 4; + gptr += weight_data_size_g * 4; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + } + } + + // reduce xc and hc bias + // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions; + const float* xcbptr = + B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); + const float* xiptr = xcbptr; + const float* xoptr = xcbptr + bias_data_size_g; + const float* xfptr = xcbptr + bias_data_size_g * 2; + const float* xgptr = xcbptr + bias_data_size_g * 3; + const float* hiptr = xcbptr + bias_data_size_g * 4; + const float* hoptr = xcbptr + bias_data_size_g * 5; + const float* hfptr = xcbptr + bias_data_size_g * 6; + const float* hgptr = xcbptr + bias_data_size_g * 7; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xfptr[j] + hfptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xoptr[j] + hoptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xgptr[j] + hgptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + + if (direction_type == 2) + { + xiptr += bias_data_size_g * 8; + xoptr += bias_data_size_g * 8; + xfptr += bias_data_size_g * 8; + xgptr += bias_data_size_g * 8; + hiptr += bias_data_size_g * 8; + hoptr += bias_data_size_g * 8; + hfptr += bias_data_size_g * 8; + hgptr += bias_data_size_g * 8; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xfptr[j] + hfptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xoptr[j] + hoptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xgptr[j] + hgptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + + // reorder num_directions-IOFG-hidden-hidden to + // num_directions-IFOG-hidden-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions; + const float* rptr = + R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data(); + + const float* iptr = rptr; + const float* optr = rptr + weight_data_size_g; + const float* fptr = rptr + weight_data_size_g * 2; + const float* gptr = rptr + weight_data_size_g * 3; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + iptr += weight_data_size_g * 4; + optr += weight_data_size_g * 4; + fptr += weight_data_size_g * 4; + gptr += weight_data_size_g * 4; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + } + } + } + else if (op == "MatMul") + { + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) + { + // InnerProduct + const onnx::TensorProto& B = weights[node.input(1)]; + + int weight_data_size = get_tensor_proto_data_size(B); + + int num_output = B.dims(B.dims_size() - 1); + int num_input = weight_data_size / num_output; + + fprintf(pp, " 0=%d", num_output); + fprintf(pp, " 1=0"); + fprintf(pp, " 2=%d", weight_data_size); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + // reorder num_input-num_output to num_output-num_input + { + const float* bptr = + B.has_raw_data() ? 
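+ /* ONNX packs LSTM gates in input-output-forget-cell (I-O-F-G) order while ncnn expects I-F-O-G, so the o and f blocks swap places in W, B and R above. Unlike the GRU case, all four LSTM gates add their input and recurrent biases directly, so each pair is pre-summed into one bias vector (vb = xiptr[j] + hiptr[j] and so on per gate). */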
(const float*)B.raw_data().data() : B.float_data().data(); + + for (int j = 0; j < num_output; j++) + { + for (int k = 0; k < num_input; k++) + { + float vb = bptr[k * num_output + j]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + + // fwrite_tensor_proto_data(B, bp) + } + else + { + // default matrix multiplication + } + } + else if (op == "Max") + { + int op_type = 4; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Min") + { + int op_type = 5; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Mul") + { + int op_type = 2; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "MultiHeadAttention") + { + int embed_dim = get_node_attr_i(node, "embed_dim", 0); + int num_heads = get_node_attr_i(node, "num_heads", 0); + + fprintf(pp, " 0=%d", embed_dim); + fprintf(pp, " 1=%d", num_heads); + + if (node.input_size() == 5) + { + const onnx::TensorProto& qkvw = weights[node.input(1)]; + const onnx::TensorProto& qkvb = weights[node.input(2)]; + const onnx::TensorProto& ow = weights[node.input(3)]; + const onnx::TensorProto& ob = weights[node.input(4)]; + + int weight_data_size = get_tensor_proto_data_size(ow); + + fprintf(pp, " 2=%d", weight_data_size); + + int quantize_tag = 0; + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose qw + { + const float* wptr = + qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); + const float* bptr = + qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim * 3 + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + + fwrite(bptr, sizeof(float), embed_dim, bp); + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose kw + { + const float* wptr = + qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); + const float* bptr = + qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + bptr += embed_dim; + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim * 3 + k + embed_dim]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + + fwrite(bptr, sizeof(float), embed_dim, bp); + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose vw + { + const float* wptr = + qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); + const float* bptr = + qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + bptr += embed_dim * 2; + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + + fwrite(bptr, sizeof(float), embed_dim, bp); + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose ow + { + const float* wptr = + ow.has_raw_data() ? 
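+ /* In the 5-input MultiHeadAttention form the q/k/v projections arrive fused as one (embed_dim x 3*embed_dim) matrix; the three copy loops above pull columns [0,E), [E,2E) and [2E,3E) out of every row (offsets k, k + embed_dim, k + 2*embed_dim) and emit them as three separate E x E weights, with the packed bias split at the same embed_dim strides. */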
(const float*)ow.raw_data().data() : ow.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(ob, bp); + } + else + { + const onnx::TensorProto& qw = weights[node.input(3)]; + const onnx::TensorProto& qb = weights[node.input(4)]; + const onnx::TensorProto& kw = weights[node.input(5)]; + const onnx::TensorProto& kb = weights[node.input(6)]; + const onnx::TensorProto& vw = weights[node.input(7)]; + const onnx::TensorProto& vb = weights[node.input(8)]; + const onnx::TensorProto& ow = weights[node.input(9)]; + const onnx::TensorProto& ob = weights[node.input(10)]; + + int weight_data_size = get_tensor_proto_data_size(qw); + + fprintf(pp, " 2=%d", weight_data_size); + + int quantize_tag = 0; + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose qw + { + const float* wptr = + qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(qb, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose kw + { + const float* wptr = + kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(kb, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose vw + { + const float* wptr = + vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(vb, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose ow + { + const float* wptr = + ow.has_raw_data() ? 
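+ /* The alternative branch handles the unfused export, where inputs 3..10 carry qw/qb, kw/kb, vw/vb, ow/ob as eight separate tensors and each E x E weight is copied through element by element in the same row-major order. */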
(const float*)ow.raw_data().data() : ow.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(ob, bp); + } + } + else if (op == "Neg") + { + int op_type = 1; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "NonMaxSuppression") + { + int max_dets = 0; + float iou_thre = 0.f; + float score_thre = 0.f; + // fprintf(stderr, "%s\n", node.name().c_str()); + // fprintf(stderr, "node.input_size(): %d\n", node.input_size()); + if (node.input_size() >= 3) + { + // fprintf(stderr, "ok12!\n"); + max_dets = (int)(get_node_attr_from_input(weights[node.input(2)]) + 0.5); + } + if (node.input_size() >= 4) + { + // fprintf(stderr, "iou_thre: %f\n", + // get_node_attr_from_input(weights[node.input(3)])); + iou_thre = get_node_attr_from_input(weights[node.input(3)]); + } + if (node.input_size() >= 5) + { + // fprintf(stderr, "score_thre: %f\n", + // get_node_attr_from_input(weights[node.input(4)])); + score_thre = get_node_attr_from_input(weights[node.input(4)]); + } + fprintf(pp, " 0=%d", max_dets); + fprintf(pp, " 1=%f", iou_thre); + fprintf(pp, " 2=%f", score_thre); + } + else if (op == "Normalize") + { + float eps = get_node_attr_f(node, "eps", 0.f); + int scale_data_size = 1; + + fprintf(pp, " 1=1"); // channel_shared + fprintf(pp, " 2=%e", eps); + fprintf(pp, " 3=%d", scale_data_size); + fprintf(pp, " 9=1"); // TODO hardcode pytorch style + + const float scale_data[1] = {1.f}; + fwrite(scale_data, sizeof(float), 1, bp); + } + else if (op == "Pad") + { + std::string mode = get_node_attr_s(node, "mode"); + float value = get_node_attr_f(node, "value", 0.f); + + std::vector pads; + if (node.input_size() == 1) + { + pads = get_node_attr_ai(node, "pads"); + } + else + { + pads = get_node_attr_from_input_ai(weights[node.input(1)]); + } + int type = 0; + if (mode == "constant") + { + type = 0; + } + else if (mode == "edge") + { + type = 1; + } + else if (mode == "reflect") + { + type = 2; + } + + int pad_size = (int)pads.size(); + int top = 0; + int bottom = 0; + int left = 0; + int right = 0; + int front = 0; + int behind = 0; + if (pad_size == 8) + { + // NCHW + top = pads[2]; + bottom = pads[6]; + left = pads[3]; + right = pads[7]; + front = pads[1]; + behind = pads[5]; + } + else if (pad_size == 6) + { + // NHW + top = pads[1]; + bottom = pads[4]; + left = pads[2]; + right = pads[5]; + } + else + { + // NW + left = pads[1]; + right = pads[3]; + } + + fprintf(pp, " 0=%d", top); + fprintf(pp, " 1=%d", bottom); + fprintf(pp, " 2=%d", left); + fprintf(pp, " 3=%d", right); + fprintf(pp, " 4=%d", type); + fprintf(pp, " 5=%e", value); + fprintf(pp, " 7=%d", front); + fprintf(pp, " 8=%d", behind); + } + else if (op == "Pow") + { + int op_type = 6; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "PriorBox") + { + std::vector min_sizes = get_node_attr_af(node, "min_sizes"); + std::vector max_sizes = get_node_attr_af(node, "max_sizes"); + std::vector aspect_ratios = get_node_attr_af(node, "aspect_ratios"); + fprintf(pp, " -23300=%zu", min_sizes.size()); + for (size_t j = 0; j < min_sizes.size(); ++j) + { + fprintf(pp, ",%f", min_sizes[j]); + } + fprintf(pp, " -23301=%zu", max_sizes.size()); + for (size_t j = 0; j < max_sizes.size(); ++j) + { + 
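+ /* ONNX pads list every dimension's begin values before its end values, batch included, so for 8 entries (NCHW) top/bottom sit at indices 2/6, left/right at 3/7, and the channel pads ("front"/"behind") at 1/5; the 6- and 4-entry cases above are the same scheme read as NHW and NW. */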
fprintf(pp, ",%f", max_sizes[j]); + } + fprintf(pp, " -23302=%zu", aspect_ratios.size()); + for (size_t j = 0; j < aspect_ratios.size(); ++j) + { + fprintf(pp, ",%f", aspect_ratios[j]); + } + int image_width = get_node_attr_i(node, "image_width"); + int image_height = get_node_attr_i(node, "image_height"); + float step_width = get_node_attr_f(node, "step_width"); + float step_height = get_node_attr_f(node, "step_height"); + float offset = get_node_attr_f(node, "offset"); + int step_mmdetection = get_node_attr_i(node, "step_mmdetection"); + fprintf(pp, " 9=%d", image_width); + fprintf(pp, " 10=%d", image_height); + fprintf(pp, " 11=%f", step_width); + fprintf(pp, " 12=%f", step_height); + fprintf(pp, " 13=%f", offset); + fprintf(pp, " 14=%d", step_mmdetection); + } + else if (op == "PixelShuffle") + { + int scale_factor = get_node_attr_i(node, "scale_factor", 1); + fprintf(pp, " 0=%d", scale_factor); + } + else if (op == "PRelu") + { + const onnx::TensorProto& slope = weights[node.input(1)]; + + int num_slope = get_tensor_proto_data_size(slope); + + fprintf(pp, " 0=%d", num_slope); + + fwrite_tensor_proto_data(slope, bp); + } + else if (op == "Reciprocal") + { + int op_type = 15; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || + op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || + op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") + { + int op_type = -233; + if (op == "ReduceSum") + op_type = 0; + else if (op == "ReduceSumSquare") + op_type = 2; + else if (op == "ReduceMean") + op_type = 3; + else if (op == "ReduceMax") + op_type = 4; + else if (op == "ReduceMin") + op_type = 5; + else if (op == "ReduceProd") + op_type = 6; + else if (op == "ReduceL1") + op_type = 7; + else if (op == "ReduceL2") + op_type = 8; + else if (op == "ReduceLogSum") + op_type = 9; + else if (op == "ReduceLogSumExp") + op_type = 10; + fprintf(pp, " 0=%d", op_type); + + std::vector axes = get_node_attr_ai(node, "axes"); + int keepdims = get_node_attr_i(node, "keepdims", 1); + + if (axes.size() > 0) + { + // if axes set, reduce according to axes + fprintf(pp, " 1=%d", 0); + fprintf(pp, " -23303=%zu", axes.size()); + for (size_t j = 0; j < axes.size(); j++) + { + if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3) + fprintf(stderr, "Unsupported reduction axes !\n"); + fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]); + } + } + else + { + // if axes not set, reduce all axes by default + fprintf(pp, " 1=%d", 1); + } + fprintf(pp, " 4=%d", keepdims); + fprintf(pp, " 5=1"); + } + else if (op == "Reorg") + { + int stride = get_node_attr_i(node, "stride", 1); + fprintf(pp, " 0=%d", stride); + } + else if (op == "Reshape") + { + std::vector shape; + + if (node.input_size() == 1) + { + shape = get_node_attr_ai(node, "shape"); + } + else if (weights.find(node.input(1)) != weights.end()) + { + shape = get_node_attr_from_input_ai(weights[node.input(1)]); + } + else + { + fprintf(stderr, "Unsupported reshape weight ! 
\n"); + } + + if (shape.size() == 1) + { + fprintf(pp, " 0=%d", shape[0]); // should never reach here + } + else if (shape.size() == 2) + { + fprintf(pp, " 0=%d", shape[1]); + } + else if (shape.size() == 3) + { + fprintf(pp, " 0=%d", shape[2]); + fprintf(pp, " 1=%d", shape[1]); + } + else if (shape.size() == 4) + { + fprintf(pp, " 0=%d", shape[3]); + fprintf(pp, " 1=%d", shape[2]); + fprintf(pp, " 2=%d", shape[1]); + } + else if (shape.size() == 5) + { + fprintf(pp, " 0=%d", shape[4] * shape[3]); + fprintf(pp, " 1=%d", shape[2]); + fprintf(pp, " 2=%d", shape[1]); + } + } + else if (op == "Resize") + { + std::string mode = get_node_attr_s(node, "mode"); + std::string align = get_node_attr_s(node, "coordinate_transformation_mode"); + + std::vector scales; + std::vector sizes; + if (node.input_size() == 2) + { + // opset 10 + scales = get_node_attr_from_input_af(weights[node.input(1)]); + } + else + { + // opset 11+ + scales = get_node_attr_from_input_af(weights[node.input(2)]); + if (node.input_size() >= 4) + { + sizes = get_node_attr_from_input_ai(weights[node.input(3)]); + } + } + + int resize_type = 1; + if (mode == "nearest") + { + resize_type = 1; + } + else if (mode == "linear") + { + resize_type = 2; + } + else if (mode == "cubic") + { + resize_type = 3; + } + + if (scales.empty() && sizes.empty()) + { + fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n"); + } + + float h_scale = 1.f; + float w_scale = 1.f; + if (scales.size() == 2) + { + w_scale = scales[1]; + } + else if (scales.size() == 3) + { + h_scale = scales[1]; + w_scale = scales[2]; + } + else if (scales.size() == 4) + { + h_scale = scales[2]; + w_scale = scales[3]; + + if (scales[1] != 1.f) fprintf(stderr, "Unsupported Resize scales !\n"); + } + + int output_height = 0; + int output_width = 0; + if (sizes.size() == 2) + { + output_width = sizes[1]; + } + else if (sizes.size() == 3) + { + output_height = sizes[1]; + output_width = sizes[2]; + } + else if (sizes.size() == 4) + { + output_height = sizes[2]; + output_width = sizes[3]; + } + + int align_corner = 0; + if (align == "align_corners") + { + align_corner = 1; + } + + fprintf(pp, " 0=%d", resize_type); + fprintf(pp, " 1=%e", h_scale); + fprintf(pp, " 2=%e", w_scale); + fprintf(pp, " 3=%d", output_height); + fprintf(pp, " 4=%d", output_width); + fprintf(pp, " 6=%d", align_corner); + } + else if (op == "RNN") + { + const onnx::TensorProto& W = weights[node.input(1)]; + const onnx::TensorProto& R = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + + int hidden_size = get_node_attr_i(node, "hidden_size", 0); + std::string direction = get_node_attr_s(node, "direction"); + + int direction_type = 0; + if (direction == "forward") + { + direction_type = 0; + } + else if (direction == "reverse") + { + direction_type = 1; + } + else if (direction == "bidirectional") + { + direction_type = 2; + } + + int weight_data_size = get_tensor_proto_data_size(W); + + fprintf(pp, " 0=%d", hidden_size); + fprintf(pp, " 1=%d", weight_data_size); + fprintf(pp, " 2=%d", direction_type); + + int num_directions = direction_type == 2 ? 2 : 1; + + int quantize_tag = 0; + + fwrite(&quantize_tag, sizeof(int), 1, bp); + fwrite_tensor_proto_data(W, bp); + + // reduce xc and hc bias + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions; + const float* bptr = + B.has_raw_data() ? 
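+ /* Resize moved its inputs between opsets: opset 10 passes scales as input(1), while opset 11+ inserts roi at input(1), shifting scales to input(2) with optional sizes at input(3) -- exactly the two shapes the input_size() branch above distinguishes. */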
(const float*)B.raw_data().data() : B.float_data().data(); + const float* xiptr = bptr; + const float* hiptr = bptr + bias_data_size_g; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + + if (direction_type == 2) + { + xiptr += bias_data_size_g * 2; + hiptr += bias_data_size_g * 2; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + fwrite_tensor_proto_data(R, bp); + } + else if (op == "RDiv") + { + int op_type = 8; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "RSub") + { + int op_type = 7; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "RoiAlign") + { + int pooled_width = get_node_attr_i(node, "output_width", 1); + int pooled_height = get_node_attr_i(node, "output_height", 1); + float spatial_scale = get_node_attr_f(node, "spatial_scale", 1.f); + int sampling_ratio = get_node_attr_i(node, "sampling_ratio", 0); + fprintf(pp, " 0=%d", pooled_width); + fprintf(pp, " 1=%d", pooled_height); + fprintf(pp, " 2=%f", spatial_scale); + fprintf(pp, " 3=%d", sampling_ratio); + } + else if (op == "ShuffleChannel") + { + int group = get_node_attr_i(node, "group", 1); + int reverse = get_node_attr_i(node, "reverse", 0); + fprintf(pp, " 0=%d", group); + fprintf(pp, " 1=%d", reverse); + } + else if (op == "Sigmoid") + { + // no param + } + else if (op == "Sin") + { + int op_type = 9; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "SkipLayerNormalization") + { + const onnx::TensorProto& W = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + const onnx::TensorProto& B2 = weights[node.input(4)]; + + fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(W, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B2, bp); + } + else if (op == "Slice") + { + bool use_crop = true; + + std::vector starts; + std::vector ends; + std::vector axes; + std::vector steps; + if (node.input_size() == 1) + { + starts = get_node_attr_ai(node, "starts"); + ends = get_node_attr_ai(node, "ends"); + axes = get_node_attr_ai(node, "axes"); + steps = get_node_attr_ai(node, "steps"); // TODO + } + else + { + starts = get_node_attr_from_input_ai(weights[node.input(1)]); + ends = get_node_attr_from_input_ai(weights[node.input(2)]); + if (node.input_size() >= 4) axes = get_node_attr_from_input_ai(weights[node.input(3)]); + if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); + } + + // assert step == 1 or step >= ends + for (int i = 0; i < (int)steps.size(); i++) + { + if (steps[i] != 1 && steps[i] < ends[i]) + { + use_crop = false; + fprintf(stderr, "Unsupported slice step ! 
Use custom TensorSlice\n"); + } + } + + if (use_crop) + { + // filter out N-dim axis + if (!axes.empty()) + { + for (int i = 0; i < (int)axes.size(); i++) + { + int axis = axes[i]; + if (axis == 0) + { + starts.erase(starts.begin() + i); + ends.erase(ends.begin() + i); + axes.erase(axes.begin() + i); + break; + } + } + } + + fprintf(pp, " -23309=%d", (int)starts.size()); + for (int i = 0; i < (int)starts.size(); i++) + { + fprintf(pp, ",%d", starts[i]); + } + fprintf(pp, " -23310=%d", (int)ends.size()); + for (int i = 0; i < (int)ends.size(); i++) + { + fprintf(pp, ",%d", ends[i]); + } + if (!axes.empty()) + { + fprintf(pp, " -23311=%d", (int)axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + int axis = axes[i]; + if (axis == 0 || axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); + + if (axis > 0) axis = axis - 1; // -1 for skip N-dim + + fprintf(pp, ",%d", axis); + } + } + } + else + { + fprintf(pp, " -23300=%d", (int)starts.size()); + for (int i = 0; i < (int)starts.size(); i++) + { + fprintf(pp, ",%d", starts[i]); + } + fprintf(pp, " -23301=%d", (int)ends.size()); + for (int i = 0; i < (int)ends.size(); i++) + { + fprintf(pp, ",%d", ends[i]); + } + if (!axes.empty()) + { + fprintf(pp, " -23302=%d", (int)axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + int axis = axes[i]; + if (axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); + fprintf(pp, ",%d", axis); + } + } + if (!steps.empty()) + { + fprintf(pp, " -23303=%d", (int)steps.size()); + for (int i = 0; i < (int)steps.size(); i++) + { + int step = steps[i]; + if (step == 0) fprintf(stderr, "Unsupported slice step ! Unsupported slice step\n"); + fprintf(pp, ",%d", step); + } + } + } + } + else if (op == "Softmax") + { + int axis = get_node_attr_i(node, "axis", 1); + fprintf(pp, " 0=%d", axis - 1); + fprintf(pp, " 1=1"); + } + else if (op == "Split") + { + int axis = get_node_attr_i(node, "axis", 0); + std::vector split = get_node_attr_ai(node, "split"); + if (axis < 1) fprintf(stderr, "Unsupported split axis !\n"); + + fprintf(pp, " -23300=%d", output_size); + if (split.empty()) + { + for (int i = 0; i < output_size; i++) + { + fprintf(pp, ",-233"); + } + } + else + { + for (size_t i = 0; i < split.size() - 1; i++) + { + fprintf(pp, ",%d", split[i]); + } + fprintf(pp, ",-233"); + } + fprintf(pp, " 1=%d", axis - 1); + } + else if (op == "Sqrt") + { + int op_type = 5; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Squeeze") + { + std::vector axes = get_node_attr_ai(node, "axes"); + + if (axes.empty()) + { + fprintf(pp, " 0=1"); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=1"); + } + else + { + bool flag = true; + for (int i = 0; i < (int)axes.size(); i++) + { + if (axes[i] == 0) + { + flag = false; + break; + } + } + if (flag == true) + { + fprintf(pp, " -23303=%zu", axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3) + fprintf(stderr, "Unsupported squeeze axes !: %d, %s\n", axes[i], node.name().c_str()); + fprintf(pp, ",%d", axes[i] - 1); + } + } + } + } + else if (op == "Sub") + { + int op_type = 1; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Sum") + { + int op_type = 1; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Swish") + { + // no param + } + else if (op == "Tan") + { + int op_type = 
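+ /* Slice lowers to ncnn's Crop (-23309/-23310/-23311) when every step is 1 and falls back to the TensorSlice-style -23300..-23303 parameters otherwise. More generally, ncnn blobs carry no batch dimension, so the axis-bearing ops here (Concat, Softmax, Gather, Split, Slice, Squeeze, the Reduce family) shift positive ONNX axes down by one and warn on axis 0; Split's trailing ",-233" appears to be the usual ncnn placeholder asking the layer to infer the final slice's size from whatever remains. */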
11; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Tanh") + { + int op_type = 16; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "TopK") + { + int axis = get_node_attr_i(node, "axis", -1); + axis = axis > 0 ? axis - 1 : axis; + int largest = get_node_attr_i(node, "largest", 1); + int sorted = get_node_attr_i(node, "sorted", 1); + fprintf(pp, " 0=%d", axis); + fprintf(pp, " 1=%d", largest); + fprintf(pp, " 2=%d", sorted); + } + else if (op == "Transpose") + { + std::vector perm = get_node_attr_ai(node, "perm"); + + if (perm.size() == 3) + { + if (perm[1] == 1 && perm[2] == 2) + fprintf(pp, " 0=0"); // w h + else if (perm[1] == 2 && perm[2] == 1) + fprintf(pp, " 0=1"); // h w + else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2) + fprintf(pp, " 0=0"); // w h + else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1) + fprintf(pp, " 0=1"); // h w + } + else if (perm.size() == 4) + { + if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3) + fprintf(pp, " 0=0"); // w h c + else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2) + fprintf(pp, " 0=1"); // h w c + else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3) + fprintf(pp, " 0=2"); // w c h + else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1) + fprintf(pp, " 0=3"); // c w h + else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2) + fprintf(pp, " 0=4"); // h c w + else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1) + fprintf(pp, " 0=5"); // c h w + } + else if (perm.size() == 5) + { + if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4) + fprintf(pp, " 0=0"); // wx h c + else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2) + fprintf(pp, " 0=1"); // h wx c + else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4) + fprintf(pp, " 0=2"); // wx c h + else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1) + fprintf(pp, " 0=3"); // c wx h + else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2) + fprintf(pp, " 0=4"); // h c wx + else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1) + fprintf(pp, " 0=5"); // c h wx + else + fprintf(stderr, "Unsupported transpose type !\n"); + } + } + else if (op == "Upsample") + { + std::string mode = get_node_attr_s(node, "mode"); + std::string align = get_node_attr_s(node, "coordinate_transformation_mode"); + + std::vector scales; + + if (node.input_size() == 1) + { + scales = get_node_attr_af(node, "scales"); + } + else + { + scales = get_node_attr_from_input_af(weights[node.input(1)]); + } + + int resize_type = 1; + if (mode == "nearest") + { + resize_type = 1; + } + else if (mode == "bilinear" || mode == "linear") + { + resize_type = 2; + } + else if (mode == "trilinear") + { + fprintf(stderr, "Unsupported Upsample mode !\n"); + } + + float h_scale = 1.f; + float w_scale = 1.f; + if (scales.size() == 2) + { + w_scale = scales[1]; + } + else if (scales.size() == 3) + { + h_scale = scales[1]; + w_scale = scales[2]; + } + else if (scales.size() == 4) + { + h_scale = scales[2]; + w_scale = scales[3]; + + if (scales[1] != 1.f) fprintf(stderr, "Unsupported Upsample scales !\n"); + } + else + { + fprintf(stderr, "Unsupported Upsample scales !\n"); + } + + int align_corner = 0; + if (align == "align_corners") + { + align_corner = 1; + } + + fprintf(pp, " 0=%d", resize_type); + fprintf(pp, " 1=%e", h_scale); + fprintf(pp, " 2=%e", w_scale); + fprintf(pp, " 6=%d", align_corner); + } + else if (op == "Unsqueeze") + { + std::vector axes = get_node_attr_ai(node, "axes"); + bool flag = true; + for (int i = 0; i 
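+ /* With the batch dim implicit, a 4-entry ONNX perm is matched on its last three positions and mapped to ncnn Permute order_type 0-5, covering all six orderings of w/h/c; the 5-entry cases fold the two innermost axes into a single "wx" axis first, as the inline comments note. */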
< (int)axes.size(); i++) + { + if (axes[i] == 0) + { + flag = false; + break; + } + } + if (flag) + { + fprintf(pp, " -23303=%zu", axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4) + fprintf(stderr, "Unsupported unsqueeze axes !: %d, %s\n", axes[i], node.name().c_str()); + fprintf(pp, ",%d", axes[i] - 1); + } + } + } + else if (op == "Yolov3DetectionOutput") + { + int num_class = get_node_attr_i(node, "num_class"); + int num_box = get_node_attr_i(node, "num_box"); + float confidence_threshold = get_node_attr_f(node, "confidence_threshold"); + float nms_threshold = get_node_attr_f(node, "nms_threshold"); + fprintf(pp, " 0=%d", num_class); + fprintf(pp, " 1=%d", num_box); + fprintf(pp, " 2=%e", confidence_threshold); + fprintf(pp, " 3=%e", nms_threshold); + std::vector biases = get_node_attr_af(node, "biases"); + if (biases.size() > 0) + { + fprintf(pp, " -23304=%zu", biases.size()); + for (int i = 0; i < (int)biases.size(); i++) + { + fprintf(pp, ",%e", biases[i]); + } + } + std::vector mask = get_node_attr_af(node, "mask"); + if (mask.size() > 0) + { + fprintf(pp, " -23305=%zu", mask.size()); + for (int i = 0; i < (int)mask.size(); i++) + { + fprintf(pp, ",%e", mask[i]); + } + } + std::vector anchors_scale = get_node_attr_af(node, "anchors_scale"); + if (anchors_scale.size() > 0) + { + fprintf(pp, " -23306=%zu", anchors_scale.size()); + for (int i = 0; i < (int)anchors_scale.size(); i++) + { + fprintf(pp, ",%e", anchors_scale[i]); + } + } + } + else + { + // TODO op specific param + } + + fprintf(pp, "\n"); + for (int j = 0; j < output_size; j++) + { + const std::string& output_name = node.output(j); + if (node_reference.find(output_name) != node_reference.end()) + { + int refcount = node_reference[output_name]; + if (refcount > 1) + { + char splitname[256]; + sprintf(splitname, "splitncnn_%d", internal_split); + fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); + + fprintf(pp, " %s", output_name.c_str()); + + for (int k = 0; k < refcount; k++) + { + fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k); + } + fprintf(pp, "\n"); + + internal_split++; + } } - } - } - fwrite_tensor_proto_data(qb, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose kw - { - const float* wptr = - kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - fwrite_tensor_proto_data(kb, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose vw - { - const float* wptr = - vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - fwrite_tensor_proto_data(vb, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose ow - { - const float* wptr = - ow.has_raw_data() ? 
(const float*)ow.raw_data().data() : ow.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - fwrite_tensor_proto_data(ob, bp); - } - } else if (op == "Neg") { - int op_type = 1; - fprintf(pp, " 0=%d", op_type); - } else if (op == "NonMaxSuppression") { - int max_dets = 0; - float iou_thre = 0.f; - float score_thre = 0.f; - // fprintf(stderr, "%s\n", node.name().c_str()); - // fprintf(stderr, "node.input_size(): %d\n", node.input_size()); - if (node.input_size() >= 3) { - // fprintf(stderr, "ok12!\n"); - max_dets = (int)(get_node_attr_from_input(weights[node.input(2)]) + 0.5); - } - if (node.input_size() >= 4) { - // fprintf(stderr, "iou_thre: %f\n", - // get_node_attr_from_input(weights[node.input(3)])); - iou_thre = get_node_attr_from_input(weights[node.input(3)]); - } - if (node.input_size() >= 5) { - // fprintf(stderr, "score_thre: %f\n", - // get_node_attr_from_input(weights[node.input(4)])); - score_thre = get_node_attr_from_input(weights[node.input(4)]); - } - fprintf(pp, " 0=%d", max_dets); - fprintf(pp, " 1=%f", iou_thre); - fprintf(pp, " 2=%f", score_thre); - } else if (op == "Normalize") { - float eps = get_node_attr_f(node, "eps", 0.f); - int scale_data_size = 1; - - fprintf(pp, " 1=1"); // channel_shared - fprintf(pp, " 2=%e", eps); - fprintf(pp, " 3=%d", scale_data_size); - fprintf(pp, " 9=1"); // TODO hardcode pytorch style - - const float scale_data[1] = {1.f}; - fwrite(scale_data, sizeof(float), 1, bp); - } else if (op == "Pad") { - std::string mode = get_node_attr_s(node, "mode"); - float value = get_node_attr_f(node, "value", 0.f); - - std::vector pads; - if (node.input_size() == 1) { - pads = get_node_attr_ai(node, "pads"); - } else { - pads = get_node_attr_from_input_ai(weights[node.input(1)]); - } - int type = 0; - if (mode == "constant") { - type = 0; - } else if (mode == "edge") { - type = 1; - } else if (mode == "reflect") { - type = 2; - } - - int pad_size = (int)pads.size(); - int top = 0; - int bottom = 0; - int left = 0; - int right = 0; - int front = 0; - int behind = 0; - if (pad_size == 8) { - // NCHW - top = pads[2]; - bottom = pads[6]; - left = pads[3]; - right = pads[7]; - front = pads[1]; - behind = pads[5]; - } else if (pad_size == 6) { - // NHW - top = pads[1]; - bottom = pads[4]; - left = pads[2]; - right = pads[5]; - } else { - // NW - left = pads[1]; - right = pads[3]; - } - - fprintf(pp, " 0=%d", top); - fprintf(pp, " 1=%d", bottom); - fprintf(pp, " 2=%d", left); - fprintf(pp, " 3=%d", right); - fprintf(pp, " 4=%d", type); - fprintf(pp, " 5=%e", value); - fprintf(pp, " 7=%d", front); - fprintf(pp, " 8=%d", behind); - } else if (op == "Pow") { - int op_type = 6; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "PriorBox") { - std::vector min_sizes = get_node_attr_af(node, "min_sizes"); - std::vector max_sizes = get_node_attr_af(node, "max_sizes"); - std::vector aspect_ratios = get_node_attr_af(node, "aspect_ratios"); - fprintf(pp, " -23300=%zu", min_sizes.size()); - for (size_t j = 0; j < min_sizes.size(); ++j) { - fprintf(pp, ",%f", min_sizes[j]); - } - fprintf(pp, " -23301=%zu", max_sizes.size()); - for (size_t j = 0; j < max_sizes.size(); ++j) { - fprintf(pp, ",%f", max_sizes[j]); - } - fprintf(pp, " -23302=%zu", 
aspect_ratios.size()); - for (size_t j = 0; j < aspect_ratios.size(); ++j) { - fprintf(pp, ",%f", aspect_ratios[j]); - } - int image_width = get_node_attr_i(node, "image_width"); - int image_height = get_node_attr_i(node, "image_height"); - float step_width = get_node_attr_f(node, "step_width"); - float step_height = get_node_attr_f(node, "step_height"); - float offset = get_node_attr_f(node, "offset"); - int step_mmdetection = get_node_attr_i(node, "step_mmdetection"); - fprintf(pp, " 9=%d", image_width); - fprintf(pp, " 10=%d", image_height); - fprintf(pp, " 11=%f", step_width); - fprintf(pp, " 12=%f", step_height); - fprintf(pp, " 13=%f", offset); - fprintf(pp, " 14=%d", step_mmdetection); - } else if (op == "PixelShuffle") { - int scale_factor = get_node_attr_i(node, "scale_factor", 1); - fprintf(pp, " 0=%d", scale_factor); - } else if (op == "PRelu") { - const onnx::TensorProto& slope = weights[node.input(1)]; - - int num_slope = get_tensor_proto_data_size(slope); - - fprintf(pp, " 0=%d", num_slope); - - fwrite_tensor_proto_data(slope, bp); - } else if (op == "Reciprocal") { - int op_type = 15; - fprintf(pp, " 0=%d", op_type); - } else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || - op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || - op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") { - int op_type = -233; - if (op == "ReduceSum") - op_type = 0; - else if (op == "ReduceSumSquare") - op_type = 2; - else if (op == "ReduceMean") - op_type = 3; - else if (op == "ReduceMax") - op_type = 4; - else if (op == "ReduceMin") - op_type = 5; - else if (op == "ReduceProd") - op_type = 6; - else if (op == "ReduceL1") - op_type = 7; - else if (op == "ReduceL2") - op_type = 8; - else if (op == "ReduceLogSum") - op_type = 9; - else if (op == "ReduceLogSumExp") - op_type = 10; - fprintf(pp, " 0=%d", op_type); - - std::vector axes = get_node_attr_ai(node, "axes"); - int keepdims = get_node_attr_i(node, "keepdims", 1); - - if (axes.size() > 0) { - // if axes set, reduce according to axes - fprintf(pp, " 1=%d", 0); - fprintf(pp, " -23303=%zu", axes.size()); - for (size_t j = 0; j < axes.size(); j++) { - if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3) - fprintf(stderr, "Unsupported reduction axes !\n"); - fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]); - } - } else { - // if axes not set, reduce all axes by default - fprintf(pp, " 1=%d", 1); - } - fprintf(pp, " 4=%d", keepdims); - fprintf(pp, " 5=1"); - } else if (op == "Reorg") { - int stride = get_node_attr_i(node, "stride", 1); - fprintf(pp, " 0=%d", stride); - } else if (op == "Reshape") { - std::vector shape; - - if (node.input_size() == 1) { - shape = get_node_attr_ai(node, "shape"); - } else if (weights.find(node.input(1)) != weights.end()) { - shape = get_node_attr_from_input_ai(weights[node.input(1)]); - } else { - fprintf(stderr, "Unsupported reshape weight ! 
\n"); - } - - if (shape.size() == 1) { - fprintf(pp, " 0=%d", shape[0]); // should never reach here - } else if (shape.size() == 2) { - fprintf(pp, " 0=%d", shape[1]); - } else if (shape.size() == 3) { - fprintf(pp, " 0=%d", shape[2]); - fprintf(pp, " 1=%d", shape[1]); - } else if (shape.size() == 4) { - fprintf(pp, " 0=%d", shape[3]); - fprintf(pp, " 1=%d", shape[2]); - fprintf(pp, " 2=%d", shape[1]); - } else if (shape.size() == 5) { - fprintf(pp, " 0=%d", shape[4] * shape[3]); - fprintf(pp, " 1=%d", shape[2]); - fprintf(pp, " 2=%d", shape[1]); - } - } else if (op == "Resize") { - std::string mode = get_node_attr_s(node, "mode"); - std::string align = get_node_attr_s(node, "coordinate_transformation_mode"); - - std::vector scales; - std::vector sizes; - if (node.input_size() == 2) { - // opset 10 - scales = get_node_attr_from_input_af(weights[node.input(1)]); - } else { - // opset 11+ - scales = get_node_attr_from_input_af(weights[node.input(2)]); - if (node.input_size() >= 4) { - sizes = get_node_attr_from_input_ai(weights[node.input(3)]); - } - } - - int resize_type = 1; - if (mode == "nearest") { - resize_type = 1; - } else if (mode == "linear") { - resize_type = 2; - } else if (mode == "cubic") { - resize_type = 3; - } - - if (scales.empty() && sizes.empty()) { - fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n"); - } - - float h_scale = 1.f; - float w_scale = 1.f; - if (scales.size() == 2) { - w_scale = scales[1]; - } else if (scales.size() == 3) { - h_scale = scales[1]; - w_scale = scales[2]; - } else if (scales.size() == 4) { - h_scale = scales[2]; - w_scale = scales[3]; - - if (scales[1] != 1.f) fprintf(stderr, "Unsupported Resize scales !\n"); - } - - int output_height = 0; - int output_width = 0; - if (sizes.size() == 2) { - output_width = sizes[1]; - } else if (sizes.size() == 3) { - output_height = sizes[1]; - output_width = sizes[2]; - } else if (sizes.size() == 4) { - output_height = sizes[2]; - output_width = sizes[3]; - } - - int align_corner = 0; - if (align == "align_corners") { - align_corner = 1; - } - - fprintf(pp, " 0=%d", resize_type); - fprintf(pp, " 1=%e", h_scale); - fprintf(pp, " 2=%e", w_scale); - fprintf(pp, " 3=%d", output_height); - fprintf(pp, " 4=%d", output_width); - fprintf(pp, " 6=%d", align_corner); - } else if (op == "RNN") { - const onnx::TensorProto& W = weights[node.input(1)]; - const onnx::TensorProto& R = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - - int hidden_size = get_node_attr_i(node, "hidden_size", 0); - std::string direction = get_node_attr_s(node, "direction"); - - int direction_type = 0; - if (direction == "forward") { - direction_type = 0; - } else if (direction == "reverse") { - direction_type = 1; - } else if (direction == "bidirectional") { - direction_type = 2; - } - - int weight_data_size = get_tensor_proto_data_size(W); - - fprintf(pp, " 0=%d", hidden_size); - fprintf(pp, " 1=%d", weight_data_size); - fprintf(pp, " 2=%d", direction_type); - - int num_directions = direction_type == 2 ? 2 : 1; - - int quantize_tag = 0; - - fwrite(&quantize_tag, sizeof(int), 1, bp); - fwrite_tensor_proto_data(W, bp); - - // reduce xc and hc bias - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions; - const float* bptr = - B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); - const float* xiptr = bptr; - const float* hiptr = bptr + bias_data_size_g; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - - if (direction_type == 2) { - xiptr += bias_data_size_g * 2; - hiptr += bias_data_size_g * 2; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - - fwrite(&quantize_tag, sizeof(int), 1, bp); - fwrite_tensor_proto_data(R, bp); - } else if (op == "RDiv") { - int op_type = 8; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "RSub") { - int op_type = 7; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "RoiAlign") { - int pooled_width = get_node_attr_i(node, "output_width", 1); - int pooled_height = get_node_attr_i(node, "output_height", 1); - float spatial_scale = get_node_attr_f(node, "spatial_scale", 1.f); - int sampling_ratio = get_node_attr_i(node, "sampling_ratio", 0); - fprintf(pp, " 0=%d", pooled_width); - fprintf(pp, " 1=%d", pooled_height); - fprintf(pp, " 2=%f", spatial_scale); - fprintf(pp, " 3=%d", sampling_ratio); - } else if (op == "ShuffleChannel") { - int group = get_node_attr_i(node, "group", 1); - int reverse = get_node_attr_i(node, "reverse", 0); - fprintf(pp, " 0=%d", group); - fprintf(pp, " 1=%d", reverse); - } else if (op == "Sigmoid") { - // no param - } else if (op == "Sin") { - int op_type = 9; - fprintf(pp, " 0=%d", op_type); - } else if (op == "SkipLayerNormalization") { - const onnx::TensorProto& W = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - const onnx::TensorProto& B2 = weights[node.input(4)]; - - fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(W, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B2, bp); - } else if (op == "Slice") { - bool use_crop = true; - - std::vector starts; - std::vector ends; - std::vector axes; - std::vector steps; - if (node.input_size() == 1) { - starts = get_node_attr_ai(node, "starts"); - ends = get_node_attr_ai(node, "ends"); - axes = get_node_attr_ai(node, "axes"); - steps = get_node_attr_ai(node, "steps"); // TODO - } else { - starts = get_node_attr_from_input_ai(weights[node.input(1)]); - ends = get_node_attr_from_input_ai(weights[node.input(2)]); - if (node.input_size() >= 4) axes = get_node_attr_from_input_ai(weights[node.input(3)]); - if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); - } - - // assert step == 1 or step >= ends - for (int i = 0; i < (int)steps.size(); i++) { - if (steps[i] != 1 && steps[i] < ends[i]) { - use_crop = false; - fprintf(stderr, "Unsupported slice step ! 
Use custom TensorSlice\n"); - } - } - - if (use_crop) { - // filter out N-dim axis - if (!axes.empty()) { - for (int i = 0; i < (int)axes.size(); i++) { - int axis = axes[i]; - if (axis == 0) { - starts.erase(starts.begin() + i); - ends.erase(ends.begin() + i); - axes.erase(axes.begin() + i); - break; - } - } - } - - fprintf(pp, " -23309=%d", (int)starts.size()); - for (int i = 0; i < (int)starts.size(); i++) { - fprintf(pp, ",%d", starts[i]); - } - fprintf(pp, " -23310=%d", (int)ends.size()); - for (int i = 0; i < (int)ends.size(); i++) { - fprintf(pp, ",%d", ends[i]); - } - if (!axes.empty()) { - fprintf(pp, " -23311=%d", (int)axes.size()); - for (int i = 0; i < (int)axes.size(); i++) { - int axis = axes[i]; - if (axis == 0 || axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); - - if (axis > 0) axis = axis - 1; // -1 for skip N-dim - - fprintf(pp, ",%d", axis); - } - } - } else { - fprintf(pp, " -23300=%d", (int)starts.size()); - for (int i = 0; i < (int)starts.size(); i++) { - fprintf(pp, ",%d", starts[i]); - } - fprintf(pp, " -23301=%d", (int)ends.size()); - for (int i = 0; i < (int)ends.size(); i++) { - fprintf(pp, ",%d", ends[i]); - } - if (!axes.empty()) { - fprintf(pp, " -23302=%d", (int)axes.size()); - for (int i = 0; i < (int)axes.size(); i++) { - int axis = axes[i]; - if (axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); - fprintf(pp, ",%d", axis); - } - } - if (!steps.empty()) { - fprintf(pp, " -23303=%d", (int)steps.size()); - for (int i = 0; i < (int)steps.size(); i++) { - int step = steps[i]; - if (step == 0) fprintf(stderr, "Unsupported slice step ! Unsupported slice step\n"); - fprintf(pp, ",%d", step); - } - } - } - } else if (op == "Softmax") { - int axis = get_node_attr_i(node, "axis", 1); - fprintf(pp, " 0=%d", axis - 1); - fprintf(pp, " 1=1"); - } else if (op == "Split") { - int axis = get_node_attr_i(node, "axis", 0); - std::vector split = get_node_attr_ai(node, "split"); - if (axis < 1) fprintf(stderr, "Unsupported split axis !\n"); - - fprintf(pp, " -23300=%d", output_size); - if (split.empty()) { - for (int i = 0; i < output_size; i++) { - fprintf(pp, ",-233"); - } - } else { - for (size_t i = 0; i < split.size() - 1; i++) { - fprintf(pp, ",%d", split[i]); - } - fprintf(pp, ",-233"); - } - fprintf(pp, " 1=%d", axis - 1); - } else if (op == "Sqrt") { - int op_type = 5; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Squeeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - - if (axes.empty()) { - fprintf(pp, " 0=1"); - fprintf(pp, " 1=1"); - fprintf(pp, " 2=1"); - } else { - bool flag = true; - for (int i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0) { - flag = false; - break; - } - } - if (flag == true) { - fprintf(pp, " -23303=%zu", axes.size()); - for (int i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3) - fprintf(stderr, "Unsupported squeeze axes !: %d, %s\n", axes[i], node.name().c_str()); - fprintf(pp, ",%d", axes[i] - 1); - } - } - } - } else if (op == "Sub") { - int op_type = 1; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Sum") { - int op_type = 1; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Swish") { - // no param - } else if (op == "Tan") { - int op_type = 11; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Tanh") { - int op_type = 16; - 
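A note on the parameter lines being rewritten throughout this hunk: ncnn .param entries are positional "id=value" pairs (floats printed with %e), and array-valued parameters are encoded under the negative key -23300 - id as a count followed by comma-separated values, which is why Slice emits " -23309=%d" and then ",%d" per element. The recurring "axis - 1" adjustments (Slice, Softmax, Split, Squeeze) reflect that ncnn blobs carry no batch dimension, so ONNX axis 1 maps to ncnn axis 0 and axis 0 itself is rejected. A minimal sketch of this encoding, using hypothetical axes and the patch's new Allman style:

    #include <cstdio>
    #include <vector>

    // Write one array-valued ncnn param: " -233xx=count,v0,v1,...".
    static void write_array_param(FILE* pp, int id, const std::vector<int>& v)
    {
        fprintf(pp, " %d=%d", -23300 - id, (int)v.size());
        for (int x : v)
            fprintf(pp, ",%d", x);
    }

    int main()
    {
        std::vector<int> onnx_axes = {1, 2};  // hypothetical ONNX reduction axes
        std::vector<int> ncnn_axes;
        for (int a : onnx_axes)
            ncnn_axes.push_back(a > 0 ? a - 1 : a);  // drop the N (batch) dim
        write_array_param(stdout, 3, ncnn_axes);     // prints " -23303=2,0,1"
        fprintf(stdout, "\n");
        return 0;
    }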
fprintf(pp, " 0=%d", op_type); - } else if (op == "TopK") { - int axis = get_node_attr_i(node, "axis", -1); - axis = axis > 0 ? axis - 1 : axis; - int largest = get_node_attr_i(node, "largest", 1); - int sorted = get_node_attr_i(node, "sorted", 1); - fprintf(pp, " 0=%d", axis); - fprintf(pp, " 1=%d", largest); - fprintf(pp, " 2=%d", sorted); - } else if (op == "Transpose") { - std::vector perm = get_node_attr_ai(node, "perm"); - - if (perm.size() == 3) { - if (perm[1] == 1 && perm[2] == 2) - fprintf(pp, " 0=0"); // w h - else if (perm[1] == 2 && perm[2] == 1) - fprintf(pp, " 0=1"); // h w - else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2) - fprintf(pp, " 0=0"); // w h - else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1) - fprintf(pp, " 0=1"); // h w - } else if (perm.size() == 4) { - if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3) - fprintf(pp, " 0=0"); // w h c - else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2) - fprintf(pp, " 0=1"); // h w c - else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3) - fprintf(pp, " 0=2"); // w c h - else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1) - fprintf(pp, " 0=3"); // c w h - else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2) - fprintf(pp, " 0=4"); // h c w - else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1) - fprintf(pp, " 0=5"); // c h w - } else if (perm.size() == 5) { - if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4) - fprintf(pp, " 0=0"); // wx h c - else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2) - fprintf(pp, " 0=1"); // h wx c - else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4) - fprintf(pp, " 0=2"); // wx c h - else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1) - fprintf(pp, " 0=3"); // c wx h - else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2) - fprintf(pp, " 0=4"); // h c wx - else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1) - fprintf(pp, " 0=5"); // c h wx - else - fprintf(stderr, "Unsupported transpose type !\n"); - } - } else if (op == "Upsample") { - std::string mode = get_node_attr_s(node, "mode"); - std::string align = get_node_attr_s(node, "coordinate_transformation_mode"); - - std::vector scales; - - if (node.input_size() == 1) { - scales = get_node_attr_af(node, "scales"); - } else { - scales = get_node_attr_from_input_af(weights[node.input(1)]); - } - - int resize_type = 1; - if (mode == "nearest") { - resize_type = 1; - } else if (mode == "bilinear" || mode == "linear") { - resize_type = 2; - } else if (mode == "trilinear") { - fprintf(stderr, "Unsupported Upsample mode !\n"); - } - - float h_scale = 1.f; - float w_scale = 1.f; - if (scales.size() == 2) { - w_scale = scales[1]; - } else if (scales.size() == 3) { - h_scale = scales[1]; - w_scale = scales[2]; - } else if (scales.size() == 4) { - h_scale = scales[2]; - w_scale = scales[3]; - - if (scales[1] != 1.f) fprintf(stderr, "Unsupported Upsample scales !\n"); - } else { - fprintf(stderr, "Unsupported Upsample scales !\n"); - } - - int align_corner = 0; - if (align == "align_corners") { - align_corner = 1; - } - - fprintf(pp, " 0=%d", resize_type); - fprintf(pp, " 1=%e", h_scale); - fprintf(pp, " 2=%e", w_scale); - fprintf(pp, " 6=%d", align_corner); - } else if (op == "Unsqueeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - bool flag = true; - for (int i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0) { - flag = false; - break; - } - } - if (flag) { - fprintf(pp, " -23303=%zu", axes.size()); - for (int 
i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4) - fprintf(stderr, "Unsupported unsqueeze axes !: %d, %s\n", axes[i], node.name().c_str()); - fprintf(pp, ",%d", axes[i] - 1); - } - } - } else if (op == "Yolov3DetectionOutput") { - int num_class = get_node_attr_i(node, "num_class"); - int num_box = get_node_attr_i(node, "num_box"); - float confidence_threshold = get_node_attr_f(node, "confidence_threshold"); - float nms_threshold = get_node_attr_f(node, "nms_threshold"); - fprintf(pp, " 0=%d", num_class); - fprintf(pp, " 1=%d", num_box); - fprintf(pp, " 2=%e", confidence_threshold); - fprintf(pp, " 3=%e", nms_threshold); - std::vector biases = get_node_attr_af(node, "biases"); - if (biases.size() > 0) { - fprintf(pp, " -23304=%zu", biases.size()); - for (int i = 0; i < (int)biases.size(); i++) { - fprintf(pp, ",%e", biases[i]); - } - } - std::vector mask = get_node_attr_af(node, "mask"); - if (mask.size() > 0) { - fprintf(pp, " -23305=%zu", mask.size()); - for (int i = 0; i < (int)mask.size(); i++) { - fprintf(pp, ",%e", mask[i]); - } - } - std::vector anchors_scale = get_node_attr_af(node, "anchors_scale"); - if (anchors_scale.size() > 0) { - fprintf(pp, " -23306=%zu", anchors_scale.size()); - for (int i = 0; i < (int)anchors_scale.size(); i++) { - fprintf(pp, ",%e", anchors_scale[i]); - } - } - } else { - // TODO op specific param - } - - fprintf(pp, "\n"); - for (int j = 0; j < output_size; j++) { - const std::string& output_name = node.output(j); - if (node_reference.find(output_name) != node_reference.end()) { - int refcount = node_reference[output_name]; - if (refcount > 1) { - char splitname[256]; - sprintf(splitname, "splitncnn_%d", internal_split); - fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); - - fprintf(pp, " %s", output_name.c_str()); - - for (int k = 0; k < refcount; k++) { - fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k); - } - fprintf(pp, "\n"); - - internal_split++; } - } } - } - fclose(pp); - fclose(bp); - fprintf(stderr, "onnx2ncnn finish\n"); - return 0; + fclose(pp); + fclose(bp); + fprintf(stderr, "onnx2ncnn finish\n"); + return 0; } diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp index dd1fe2c4f6..efecdcd199 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp @@ -13,158 +13,179 @@ * @param context * @return std::tuple> */ -std::tuple> query_shape( - onnx::GraphProto* mutable_graph, onnx::NodeProto* target, - const std::map& weights, - std::map>& context) { - // emplace all input nodes - const int input_count = mutable_graph->input_size(); - for (int i = 0; i < input_count; i++) { - auto inp = mutable_graph->input(i); - onnx::TypeProto inp_type = inp.type(); - onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape(); - - auto dim_size = shape_proto.dim_size(); - std::vector shape(dim_size); - for (int index = 0; index < dim_size; ++index) { - shape[index] = shape_proto.dim(index).dim_value(); - } - - context.emplace(inp.name(), shape); - } - - // BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes - std::vector serial = {target}; - { - std::set mark_as_appended = {}; - while (true) { - int start = 0, end = serial.size(); - for (int i = start; i < end; ++i) { - auto node_ptr = serial[i]; - auto len = node_ptr->input_size(); - - for (int j = 0; j < len; ++j) { - std::string name = 
node_ptr->input(j); - if (context.find(name) != context.end()) { - // if input founded, skip - continue; - } - - if (weights.find(name) != weights.end()) { - // if founded in weights, extract shape to context - auto weight = weights.at(name); - std::vector shape; - for (auto index = 0; index < weight.dims_size(); ++index) { - shape.emplace_back(weight.dims(index)); - } - context.emplace(name, shape); - continue; - } - - if (mark_as_appended.find(name) != mark_as_appended.end()) { - // if mark as appended, skip - continue; - } - // else append it to serialization list - auto depend_ptr = find_node_by_output_name(mutable_graph, name); - if (depend_ptr == nullptr) { - fprintf(stderr, "cannot find %s from graph !\n", name.c_str()); - return std::make_tuple(false, std::vector{}); - } - mark_as_appended.insert(name); - serial.emplace_back(depend_ptr); +std::tuple> query_shape(onnx::GraphProto* mutable_graph, + onnx::NodeProto* target, + const std::map& weights, + std::map>& context) +{ + // emplace all input nodes + const int input_count = mutable_graph->input_size(); + for (int i = 0; i < input_count; i++) + { + auto inp = mutable_graph->input(i); + onnx::TypeProto inp_type = inp.type(); + onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape(); + + auto dim_size = shape_proto.dim_size(); + std::vector shape(dim_size); + for (int index = 0; index < dim_size; ++index) + { + shape[index] = shape_proto.dim(index).dim_value(); } - } - if (serial.size() <= end) { - // if not new node added, quit - break; - } - - // update start and end position, continue BFS the tree - start = end; - end = serial.size(); + context.emplace(inp.name(), shape); } - } - - // for each node in serialization list, calculate the output shape - { - std::reverse(serial.begin(), serial.end()); - for (auto node : serial) { - if (node->op_type() == "Conv") { - auto inp = context[node->input(0)]; - auto weight = context[node->input(1)]; - assert(inp.size() == 4 and weight.size() == 4); - - int group = get_node_attr_i(*node, "group", 1); - assert(group == 1); - - // treat multiple spatial attr as single one -#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \ - int ATTR = DEFAULT; \ - { \ - std::vector _vec = get_node_attr_ai(*node, NAME); \ - if (not _vec.empty()) { \ - ATTR = _vec[0]; \ - } \ - } - - EXTRACT_REPEATED_PARAM("dilations", dilation, 1); - EXTRACT_REPEATED_PARAM("pads", pad, 0); - EXTRACT_REPEATED_PARAM("strides", stride, 1); - -#undef EXTRACT_REPEATED_PARAM - int on = inp[0]; - int oc = weight[0]; - int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1; - int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1; - context.emplace(node->output(0), std::vector{on, oc, oh, ow}); - - } else if (node->op_type() == "Shape") { - auto inp = context[node->input(0)]; - context.emplace(node->output(0), std::vector{1, inp[1], inp[2], inp[3]}); - - } else if (node->op_type() == "Slice") { - assert(node->input_size() >= 4); + // BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes + std::vector serial = {target}; + { + std::set mark_as_appended = {}; + while (true) + { + int start = 0, end = serial.size(); + for (int i = start; i < end; ++i) + { + auto node_ptr = serial[i]; + auto len = node_ptr->input_size(); + + for (int j = 0; j < len; ++j) + { + std::string name = node_ptr->input(j); + if (context.find(name) != context.end()) + { + // if input founded, skip + continue; + } + + if (weights.find(name) != weights.end()) + { + // if founded in weights, extract shape to context + auto weight = 
weights.at(name); + std::vector shape; + for (auto index = 0; index < weight.dims_size(); ++index) + { + shape.emplace_back(weight.dims(index)); + } + context.emplace(name, shape); + continue; + } + + if (mark_as_appended.find(name) != mark_as_appended.end()) + { + // if mark as appended, skip + continue; + } + // else append it to serialization list + auto depend_ptr = find_node_by_output_name(mutable_graph, name); + if (depend_ptr == nullptr) + { + fprintf(stderr, "cannot find %s from graph !\n", name.c_str()); + return std::make_tuple(false, std::vector{}); + } + mark_as_appended.insert(name); + serial.emplace_back(depend_ptr); + } + } - auto inp = context[node->input(0)]; - int start = get_node_attr_from_input(weights.at(node->input(1))); - int end = get_node_attr_from_input(weights.at(node->input(2))); - int axes = get_node_attr_from_input(weights.at(node->input(3))); + if (serial.size() <= end) + { + // if not new node added, quit + break; + } - if (axes != 0) { - fprintf(stderr, "Not support axes=%d !\n", axes); - return std::make_tuple(false, std::vector{}); + // update start and end position, continue BFS the tree + start = end; + end = serial.size(); } + } - assert(inp.size() >= end - start); - context.emplace(node->output(0), std::vector{inp.begin() + start, inp.begin() + end}); - - } else if (node->op_type() == "Concat") { - assert(node->input_size() >= 2); - - auto axis = get_node_attr_i(*node, "axis", 0); - if (axis != 0) { - fprintf(stderr, "Not support axes=%d !\n", axis); - return std::make_tuple(false, std::vector{}); - } + // for each node in serialization list, calculate the output shape + { + std::reverse(serial.begin(), serial.end()); + for (auto node : serial) + { + if (node->op_type() == "Conv") + { + auto inp = context[node->input(0)]; + auto weight = context[node->input(1)]; + assert(inp.size() == 4 and weight.size() == 4); + + int group = get_node_attr_i(*node, "group", 1); + assert(group == 1); + + // treat multiple spatial attr as single one +#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \ + int ATTR = DEFAULT; \ + { \ + std::vector _vec = get_node_attr_ai(*node, NAME); \ + if (not _vec.empty()) \ + { \ + ATTR = _vec[0]; \ + } \ + } - std::vector inp = context[node->input(0)]; - std::vector w_data = get_node_attr_from_input_ai(weights.at(node->input(1))); + EXTRACT_REPEATED_PARAM("dilations", dilation, 1); + EXTRACT_REPEATED_PARAM("pads", pad, 0); + EXTRACT_REPEATED_PARAM("strides", stride, 1); - // concat data on axis 0 - inp.insert(inp.end(), w_data.begin(), w_data.end()); - context.emplace(node->output(0), inp); +#undef EXTRACT_REPEATED_PARAM - } else { - fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str()); - return std::make_tuple(false, std::vector{}); - } + int on = inp[0]; + int oc = weight[0]; + int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1; + int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1; + context.emplace(node->output(0), std::vector{on, oc, oh, ow}); + } + else if (node->op_type() == "Shape") + { + auto inp = context[node->input(0)]; + context.emplace(node->output(0), std::vector{1, inp[1], inp[2], inp[3]}); + } + else if (node->op_type() == "Slice") + { + assert(node->input_size() >= 4); + + auto inp = context[node->input(0)]; + int start = get_node_attr_from_input(weights.at(node->input(1))); + int end = get_node_attr_from_input(weights.at(node->input(2))); + int axes = get_node_attr_from_input(weights.at(node->input(3))); + + if (axes != 0) + { + fprintf(stderr, "Not support axes=%d !\n", axes); + 
return std::make_tuple(false, std::vector{}); + } + + assert(inp.size() >= end - start); + context.emplace(node->output(0), std::vector{inp.begin() + start, inp.begin() + end}); + } + else if (node->op_type() == "Concat") + { + assert(node->input_size() >= 2); + + auto axis = get_node_attr_i(*node, "axis", 0); + if (axis != 0) + { + fprintf(stderr, "Not support axes=%d !\n", axis); + return std::make_tuple(false, std::vector{}); + } + + std::vector inp = context[node->input(0)]; + std::vector w_data = get_node_attr_from_input_ai(weights.at(node->input(1))); + + // concat data on axis 0 + inp.insert(inp.end(), w_data.begin(), w_data.end()); + context.emplace(node->output(0), inp); + } + else + { + fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str()); + return std::make_tuple(false, std::vector{}); + } + } } - } - assert(context.find(target->output(0)) != context.end()); - auto target_shape = context[target->output(0)]; - return std::make_tuple(true, target_shape); + assert(context.find(target->output(0)) != context.end()); + auto target_shape = context[target->output(0)]; + return std::make_tuple(true, target_shape); } diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h index fa62ffe9de..55d966ae83 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h @@ -13,7 +13,7 @@ * @param context * @return std::tuple> */ -std::tuple> query_shape( - onnx::GraphProto* mutable_graph, onnx::NodeProto* target, - const std::map& weights, - std::map>& context); +std::tuple> query_shape(onnx::GraphProto* mutable_graph, + onnx::NodeProto* target, + const std::map& weights, + std::map>& context); diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h index 792db0ed34..ab991a52f9 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h @@ -21,381 +21,496 @@ * @param name * @return onnx::NodeProto* */ -static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph, - const std::string& name) { - const int input_count = mutable_graph->node_size(); - for (int i = 0; i < input_count; ++i) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - for (int j = 0; j < node->output_size(); ++j) { - auto output = node->output(j); - if (output == name) { - return node; - } +static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph, + const std::string& name) +{ + const int input_count = mutable_graph->node_size(); + for (int i = 0; i < input_count; ++i) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + for (int j = 0; j < node->output_size(); ++j) + { + auto output = node->output(j); + if (output == name) + { + return node; + } + } } - } - return nullptr; + return nullptr; } -static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) { - std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); - if (!fs.is_open()) { - fprintf(stderr, "open failed %s\n", filepath); - return false; - } +static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) +{ + std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); + if (!fs.is_open()) + { + fprintf(stderr, "open failed %s\n", filepath); + return false; + } - google::protobuf::io::IstreamInputStream input(&fs); - 
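A note on query_shape above: the angle-bracket contents of its types were lost in extraction; reconstructed from the surrounding usage, the signature reads std::tuple<bool, std::vector<int>> query_shape(onnx::GraphProto* mutable_graph, onnx::NodeProto* target, const std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, std::vector<int>>& context). It gathers the target's transitive dependencies breadth-first down to graph inputs and initializers, reverses that list into topological order, and replays it, so only the ops it models are supported: Conv with group == 1, Shape, Slice along axis 0, and Concat along axis 0. A usage sketch under those assumptions (the helper name and call site are hypothetical):

    #include <cstdio>
    #include <map>
    #include <tuple>
    #include <vector>

    #include "shape_inference.h"  // declares query_shape

    // Resolve one node's output shape, memoizing results in a shared context.
    static bool resolve_shape(onnx::GraphProto* graph, onnx::NodeProto* node,
                              const std::map<std::string, onnx::TensorProto>& weights,
                              std::map<std::string, std::vector<int>>& context,
                              std::vector<int>& shape)
    {
        bool ok = false;
        std::tie(ok, shape) = query_shape(graph, node, weights, context);
        if (!ok)
            fprintf(stderr, "query_shape failed for %s\n", node->name().c_str());
        return ok;
    }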
google::protobuf::io::CodedInputStream codedstr(&input); + google::protobuf::io::IstreamInputStream input(&fs); + google::protobuf::io::CodedInputStream codedstr(&input); #if GOOGLE_PROTOBUF_VERSION >= 3011000 - codedstr.SetTotalBytesLimit(INT_MAX); + codedstr.SetTotalBytesLimit(INT_MAX); #else - codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); + codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); #endif - bool success = message->ParseFromCodedStream(&codedstr); + bool success = message->ParseFromCodedStream(&codedstr); - fs.close(); + fs.close(); - return success; + return success; } -static std::vector get_node_attr_ai(const onnx::NodeProto& node, const char* key) { - std::vector v; +static std::vector get_node_attr_ai(const onnx::NodeProto& node, const char* key) +{ + std::vector v; + + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + v.resize(attr.ints_size()); + for (int j = 0; j < attr.ints_size(); j++) + { + v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), + (::google::protobuf::int64)INT_MIN); + } + + break; + } + } - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - v.resize(attr.ints_size()); - for (int j = 0; j < attr.ints_size(); j++) { - v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); - } + return v; +} - break; +static void set_node_attr_ai(onnx::NodeProto& node, const char* key, const std::vector& value) +{ + onnx::AttributeProto* attr_group = node.add_attribute(); + attr_group->set_name(key); + for (auto v : value) + { + attr_group->add_ints(v); } - } - return v; + return; } -static void set_node_attr_ai(onnx::NodeProto& node, const char* key, - const std::vector& value) { - onnx::AttributeProto* attr_group = node.add_attribute(); - attr_group->set_name(key); - for (auto v : value) { - attr_group->add_ints(v); - } +static std::vector get_node_attr_af(const onnx::NodeProto& node, const char* key) +{ + std::vector v; + + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + v.resize(attr.floats_size()); + for (int j = 0; j < attr.floats_size(); j++) + { + v[j] = attr.floats(j); + } + + break; + } + } - return; + return v; } -static std::vector get_node_attr_af(const onnx::NodeProto& node, const char* key) { - std::vector v; +static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), + (::google::protobuf::int64)INT_MIN); + } + } - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - v.resize(attr.floats_size()); - for (int j = 0; j < attr.floats_size(); j++) { - v[j] = attr.floats(j); - } + return def; +} - break; +static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return attr.f(); + } } - } - return v; + return def; } -static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) { - for (int i = 
0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); +static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string()) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return attr.s(); + } } - } - return def; + return def; } -static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) { - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - return attr.f(); +static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return attr.t(); + } } - } - return def; + return onnx::TensorProto(); } -static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, - const std::string& def = std::string()) { - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - return attr.s(); +template +static T get_node_attr_from_input(const onnx::TensorProto& tp) +{ + T v = 0.f; + + // float + if (tp.data_type() == 1) + { + const float* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const float*)tp.raw_data().data(); + } + else + { + shape_data = tp.float_data().data(); + } + v = shape_data[0]; + } + // double + else if (tp.data_type() == 11) + { + const double* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const double*)tp.raw_data().data(); + } + else + { + shape_data = tp.double_data().data(); + } + v = shape_data[0]; + } + // int64 + else if (tp.data_type() == 7) + { + const int64_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int64_t*)tp.raw_data().data(); + } + else + { + shape_data = tp.int64_data().data(); + } + v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), + (::google::protobuf::int64)INT_MIN); + } + // int32 + else if (tp.data_type() == 6) + { + const int32_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int32_t*)tp.raw_data().data(); + } + else + { + shape_data = tp.int32_data().data(); + } + v = shape_data[0]; + } + else + { + // fprintf(stderr, "tp.name: %s\n", tp.name().c_str()); + fprintf(stderr, "Unknown data type %d\n", tp.data_type()); + fprintf(stderr, "get_node_attr_from_input\n"); + abort(); } - } - return def; + return v; } -static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) { - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - return attr.t(); +static std::vector get_node_attr_from_input_ai(const onnx::TensorProto& tp) +{ + int size = 0; + + std::vector v; + + // int64 + if (tp.data_type() == 7) + { + const int64_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int64_t*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 8); + } + else + { + shape_data = tp.int64_data().data(); + size = tp.int64_data_size(); + } + for (int j = 0; j < size; j++) + { + int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), 
+ (::google::protobuf::int64)INT_MIN); + v.push_back(vi); + } + } + // int32 + else if (tp.data_type() == 6) + { + const int32_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int32_t*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 4); + } + else + { + shape_data = tp.int32_data().data(); + size = tp.int32_data_size(); + } + for (int j = 0; j < size; j++) + { + v.push_back(shape_data[j]); + } + } + else + { + fprintf(stderr, "Unknown data type %d\n", tp.data_type()); } - } - return onnx::TensorProto(); + return v; } -template -static T get_node_attr_from_input(const onnx::TensorProto& tp) { - T v = 0.f; - - // float - if (tp.data_type() == 1) { - const float* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const float*)tp.raw_data().data(); - } else { - shape_data = tp.float_data().data(); - } - v = shape_data[0]; - } - // double - else if (tp.data_type() == 11) { - const double* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const double*)tp.raw_data().data(); - } else { - shape_data = tp.double_data().data(); - } - v = shape_data[0]; - } - // int64 - else if (tp.data_type() == 7) { - const int64_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int64_t*)tp.raw_data().data(); - } else { - shape_data = tp.int64_data().data(); - } - v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); - } - // int32 - else if (tp.data_type() == 6) { - const int32_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int32_t*)tp.raw_data().data(); - } else { - shape_data = tp.int32_data().data(); - } - v = shape_data[0]; - } else { - // fprintf(stderr, "tp.name: %s\n", tp.name().c_str()); - fprintf(stderr, "Unknown data type %d\n", tp.data_type()); - fprintf(stderr, "get_node_attr_from_input\n"); - abort(); - } - - return v; -} +static std::vector get_node_attr_from_input_af(const onnx::TensorProto& tp) +{ + int size = 0; + + std::vector v; + + // float + if (tp.data_type() == 1) + { + const float* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const float*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 4); + } + else + { + shape_data = tp.float_data().data(); + size = tp.float_data_size(); + } + for (int j = 0; j < size; j++) + { + v.push_back(shape_data[j]); + } + } + // double + else if (tp.data_type() == 11) + { + const double* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const double*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 8); + } + else + { + shape_data = tp.double_data().data(); + size = tp.double_data_size(); + } + for (int j = 0; j < size; j++) + { + v.push_back((float)shape_data[j]); + } + } + else + { + fprintf(stderr, "Unknown data type %d\n", tp.data_type()); + } -static std::vector get_node_attr_from_input_ai(const onnx::TensorProto& tp) { - int size = 0; - - std::vector v; - - // int64 - if (tp.data_type() == 7) { - const int64_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int64_t*)tp.raw_data().data(); - size = (int)(tp.raw_data().size() / 8); - } else { - shape_data = tp.int64_data().data(); - size = tp.int64_data_size(); - } - for (int j = 0; j < size; j++) { - int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); - v.push_back(vi); - } - } - // int32 - else if (tp.data_type() == 6) { - const int32_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int32_t*)tp.raw_data().data(); - 
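For readers tracing the data_type() comparisons in these helpers: the integer literals are onnx::TensorProto_DataType values from the generated onnx.pb.h, namely FLOAT = 1, INT32 = 6, INT64 = 7, BOOL = 9, DOUBLE = 11. A small readability sketch (a suggestion, not part of the patch) using the generated enum names instead of magic numbers:

    #include "onnx.pb.h"

    // Equivalent to the numeric checks in get_node_attr_from_input_ai.
    static bool is_index_tensor(const onnx::TensorProto& tp)
    {
        return tp.data_type() == onnx::TensorProto::INT64   // 7
            || tp.data_type() == onnx::TensorProto::INT32;  // 6
    }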
size = (int)(tp.raw_data().size() / 4); - } else { - shape_data = tp.int32_data().data(); - size = tp.int32_data_size(); - } - for (int j = 0; j < size; j++) { - v.push_back(shape_data[j]); - } - } else { - fprintf(stderr, "Unknown data type %d\n", tp.data_type()); - } - - return v; + return v; } -static std::vector get_node_attr_from_input_af(const onnx::TensorProto& tp) { - int size = 0; - - std::vector v; - - // float - if (tp.data_type() == 1) { - const float* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const float*)tp.raw_data().data(); - size = (int)(tp.raw_data().size() / 4); - } else { - shape_data = tp.float_data().data(); - size = tp.float_data_size(); - } - for (int j = 0; j < size; j++) { - v.push_back(shape_data[j]); - } - } - // double - else if (tp.data_type() == 11) { - const double* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const double*)tp.raw_data().data(); - size = (int)(tp.raw_data().size() / 8); - } else { - shape_data = tp.double_data().data(); - size = tp.double_data_size(); - } - for (int j = 0; j < size; j++) { - v.push_back((float)shape_data[j]); - } - } else { - fprintf(stderr, "Unknown data type %d\n", tp.data_type()); - } - - return v; -} +static int get_tensor_proto_data_size(const onnx::TensorProto& tp) +{ + if (tp.has_raw_data()) + { + if (tp.data_type() == 1 || tp.data_type() == 6) + { + const std::string& raw_data = tp.raw_data(); + int size = (int)raw_data.size() / 4; + return size; + } + else if (tp.data_type() == 7 || tp.data_type() == 11) + { + const std::string& raw_data = tp.raw_data(); + int size = (int)raw_data.size() / 8; + return size; + } + else if (tp.data_type() == 9) + { + const std::string& raw_data = tp.raw_data(); + return 0; + } + } + else if (tp.data_type() == 1) + { + return tp.float_data_size(); + } + else if (tp.data_type() == 7) + { + return tp.int64_data_size(); + } + else if (tp.data_type() == 6) + { + return tp.int32_data_size(); + } + else if (tp.data_type() == 11) + { + return tp.double_data_size(); + } -static int get_tensor_proto_data_size(const onnx::TensorProto& tp) { - if (tp.has_raw_data()) { - if (tp.data_type() == 1 || tp.data_type() == 6) { - const std::string& raw_data = tp.raw_data(); - int size = (int)raw_data.size() / 4; - return size; - } else if (tp.data_type() == 7 || tp.data_type() == 11) { - const std::string& raw_data = tp.raw_data(); - int size = (int)raw_data.size() / 8; - return size; - } else if (tp.data_type() == 9) { - const std::string& raw_data = tp.raw_data(); - return 0; - } - } else if (tp.data_type() == 1) { - return tp.float_data_size(); - } else if (tp.data_type() == 7) { - return tp.int64_data_size(); - } else if (tp.data_type() == 6) { - return tp.int32_data_size(); - } else if (tp.data_type() == 11) { - return tp.double_data_size(); - } - - return 0; + return 0; } -static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) { - int size = get_tensor_proto_data_size(tp); +static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) +{ + int size = get_tensor_proto_data_size(tp); - if (tp.has_raw_data()) { - const std::string& raw_data = tp.raw_data(); - fwrite(raw_data.data(), sizeof(float), size, bp); - } else if (tp.data_type() == 1) { - fwrite(tp.float_data().data(), sizeof(float), size, bp); - } + if (tp.has_raw_data()) + { + const std::string& raw_data = tp.raw_data(); + fwrite(raw_data.data(), sizeof(float), size, bp); + } + else if (tp.data_type() == 1) + { + fwrite(tp.float_data().data(), sizeof(float), size, bp); + } } 
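Two conventions in the helpers above are worth spelling out. First, for raw_data tensors the element count is the byte size divided by the element width: 4 bytes for FLOAT and INT32, 8 for INT64 and DOUBLE, while BOOL (type 9) is reported as size 0. Second, fwrite_tensor_proto_data writes size four-byte elements straight out of raw_data, which is byte-accurate only for four-byte payloads; that holds for the float weights the converter feeds it, and the fwrite_tensor_proto_data_to_float variant in the next hunk handles the other types by converting element-wise. A sketch of the width table, assuming the same numeric type codes:

    // Element width assumed by get_tensor_proto_data_size for raw_data tensors.
    static int element_width(int data_type)
    {
        switch (data_type)
        {
            case 1:   // FLOAT
            case 6:   // INT32
                return 4;
            case 7:   // INT64
            case 11:  // DOUBLE
                return 8;
            default:  // e.g. 9 (BOOL): reported as size 0
                return 0;
        }
    }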
-static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) { - int size = get_tensor_proto_data_size(tp); - size_t written_size; - if (tp.has_raw_data()) { - const std::string& raw_data = tp.raw_data(); - if (tp.data_type() == 6) { - int* intdataptr = (int*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 7) { - int64_t* intdataptr = (int64_t*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 9) { - bool* intdataptr = (bool*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 11) { - double* doubledataptr = (double*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)doubledataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } - } else if (tp.data_type() == 6) { - int* intdataptr = (int*)tp.int32_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 7) { - int64_t* intdataptr = (int64_t*)tp.int64_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 9) { - int* intdataptr = (int*)tp.int64_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 11) { - double* doubledataptr = (double*)tp.double_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)doubledataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } +static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) +{ + int size = get_tensor_proto_data_size(tp); + size_t written_size; + if (tp.has_raw_data()) + { + const std::string& raw_data = tp.raw_data(); + if (tp.data_type() == 6) + { + int* intdataptr = (int*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 7) + { + int64_t* intdataptr = (int64_t*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); 
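// Editorial note, not part of the patch: in both the removed and the added
// version of fwrite_tensor_proto_data_to_float, written_size is assigned on
// every branch but never checked, the std::malloc results are used without a
// null test, and the non-raw bool branch (data_type 9) reads tp.int64_data()
// through an int*, which looks inherited rather than intentional; a hardened
// version would validate all three.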
+ for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 9) + { + bool* intdataptr = (bool*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 11) + { + double* doubledataptr = (double*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)doubledataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + } + else if (tp.data_type() == 6) + { + int* intdataptr = (int*)tp.int32_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 7) + { + int64_t* intdataptr = (int64_t*)tp.int64_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 9) + { + int* intdataptr = (int*)tp.int64_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 11) + { + double* doubledataptr = (double*)tp.double_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)doubledataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } } diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/CMakeLists.txt b/csrc/mmdeploy/backend_ops/ncnn/ops/CMakeLists.txt index abfff8e3f2..755561c379 100755 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/CMakeLists.txt @@ -6,19 +6,17 @@ project(mmdeploy_ncnn_ops) file(GLOB_RECURSE NCNN_OPS_SRCS *.cpp) add_library(${PROJECT_NAME}_obj OBJECT "${NCNN_OPS_SRCS}") target_compile_definitions(${PROJECT_NAME}_obj PRIVATE -DMMDEPLOY_API_EXPORTS=1) -set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE 1) +set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE + 1) target_link_libraries(${PROJECT_NAME}_obj PRIVATE ncnn) -set(_COMMON_INCLUDE_DIRS - $ - $) -target_include_directories(${PROJECT_NAME}_obj - PUBLIC ${_COMMON_INCLUDE_DIRS}) +set(_COMMON_INCLUDE_DIRS $ + $) +target_include_directories(${PROJECT_NAME}_obj PUBLIC ${_COMMON_INCLUDE_DIRS}) mmdeploy_export(${PROJECT_NAME}_obj) mmdeploy_add_library(${PROJECT_NAME} SHARED EXCLUDE "") target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_NAME}_obj) -target_include_directories(${PROJECT_NAME} - PUBLIC ${_COMMON_INCLUDE_DIRS}) +target_include_directories(${PROJECT_NAME} PUBLIC ${_COMMON_INCLUDE_DIRS}) add_library(mmdeploy::ncnn_ops ALIAS ${PROJECT_NAME}) diff --git 
a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp old mode 100755 new mode 100644 index b865db7b25..32ae99669b --- a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp @@ -3,51 +3,63 @@ #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(ConstantOfShape) -DEFINE_NCNN_OPS(ConstantOfShape, ConstantOfShape) -ConstantOfShape::ConstantOfShape() { - one_blob_only = true; - support_inplace = false; -} +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(ConstantOfShape) + DEFINE_NCNN_OPS(ConstantOfShape, ConstantOfShape) -int ConstantOfShape::load_param(const ParamDict& pd) { - val = pd.get(0, 0.f); - return 0; -} + ConstantOfShape::ConstantOfShape() + { + one_blob_only = true; + support_inplace = false; + } -int ConstantOfShape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { - int dims = bottom_blob.w - 1; - const float* bottom_ptr = bottom_blob; - const float* shape_ptr = bottom_ptr + 1; + int ConstantOfShape::load_param(const ParamDict& pd) + { + val = pd.get(0, 0.f); + return 0; + } - if (dims == 1) { - int w = (int)(shape_ptr[0] + 0.5); - size_t elemsize = sizeof(val); - top_blob.create(w, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - top_blob.fill(val); - return 0; - } else if (dims == 2) { - int h = (int)(shape_ptr[0] + 0.5); - int w = (int)(shape_ptr[1] + 0.5); - size_t elemsize = sizeof(val); - top_blob.create(w, h, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - top_blob.fill(val); - return 0; - } else if (dims == 3) { - int channels = (int)(shape_ptr[0] + 0.5); - int h = (int)(shape_ptr[1] + 0.5); - int w = (int)(shape_ptr[2] + 0.5); - size_t elemsize = sizeof(val); - top_blob.create(w, h, channels, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - top_blob.fill(val); - return 0; - } - return -1; -} + int ConstantOfShape::forward(const Mat& bottom_blob, + Mat& top_blob, + const Option& opt) const + { + int dims = bottom_blob.w - 1; + const float* bottom_ptr = bottom_blob; + const float* shape_ptr = bottom_ptr + 1; + + if (dims == 1) + { + int w = (int)(shape_ptr[0] + 0.5); + size_t elemsize = sizeof(val); + top_blob.create(w, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + top_blob.fill(val); + return 0; + } + else if (dims == 2) + { + int h = (int)(shape_ptr[0] + 0.5); + int w = (int)(shape_ptr[1] + 0.5); + size_t elemsize = sizeof(val); + top_blob.create(w, h, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + top_blob.fill(val); + return 0; + } + else if (dims == 3) + { + int channels = (int)(shape_ptr[0] + 0.5); + int h = (int)(shape_ptr[1] + 0.5); + int w = (int)(shape_ptr[2] + 0.5); + size_t elemsize = sizeof(val); + top_blob.create(w, h, channels, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + top_blob.fill(val); + return 0; + } + return -1; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h old mode 100755 new mode 100644 index b61fb62c09..85317ba559 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h @@ -4,20 +4,23 @@ #include "layer.h" 
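A note on the ConstantOfShape::forward just above: the converter passes shape tensors as 1-D float blobs whose first element is the rank, so the layer reads dims = bottom_blob.w - 1, skips that leading element, and rounds each float-encoded extent with (int)(x + 0.5). DEFINE_NCNN_OPS registers the creator with mmdeploy's own op registry; for a standalone ncnn::Net the layer would be registered by hand, roughly as below (file names are hypothetical):

    #include "net.h"

    namespace mmdeploy
    {
        // generated by DEFINE_LAYER_CREATOR(ConstantOfShape)
        ncnn::Layer* ConstantOfShape_layer_creator(void*);
    }  // namespace mmdeploy

    int load_converted(const char* param_path, const char* bin_path)
    {
        ncnn::Net net;
        net.register_custom_layer("ConstantOfShape", mmdeploy::ConstantOfShape_layer_creator);
        if (net.load_param(param_path) != 0 || net.load_model(bin_path) != 0)
            return -1;
        return 0;
    }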
-namespace mmdeploy { +namespace mmdeploy +{ -class ConstantOfShape : public ncnn::Layer { - public: - ConstantOfShape(); + class ConstantOfShape : public ncnn::Layer + { + public: + ConstantOfShape(); - virtual int load_param(const ncnn::ParamDict& pd); + virtual int load_param(const ncnn::ParamDict& pd); - virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, - const ncnn::Option& opt) const; + virtual int forward(const ncnn::Mat& bottom_blob, + ncnn::Mat& top_blob, + const ncnn::Option& opt) const; - public: - float val; -}; + public: + float val; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp old mode 100755 new mode 100644 index be3d75a248..ca8120f228 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp @@ -4,330 +4,454 @@ #include "expand.h" #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(Expand) -DEFINE_NCNN_OPS(Expand, Expand) -Expand::Expand() { - one_blob_only = false; - support_inplace = false; -} - -int Expand::forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const Option& opt) const { - const Mat& bottom_blob = bottom_blobs[0]; - size_t elemsize = bottom_blob.elemsize; - const Mat& old_shape_blob = bottom_blobs[1]; - const int shape_width = old_shape_blob.w - 1; - Mat shape_blob(shape_width, elemsize, opt.workspace_allocator); - memcpy(shape_blob.row(0), old_shape_blob.row(0) + 1, shape_width * elemsize); - Mat& top_blob = top_blobs[0]; - - if (bottom_blob.dims == 1 && shape_blob.w == 1) { - int shape_0 = (int)(shape_blob[0] + 0.5); - if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d) vs (%d)\n", bottom_blob.w, shape_0); - } else if (bottom_blob.w == shape_0 || shape_0 == 1) { - top_blob.create(bottom_blob.w, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - - for (int i = 0; i < bottom_blob.w; i++) { - top_blob[i] = bottom_blob[i]; - } - } else if (bottom_blob.w == 1) { - top_blob.create(shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - - for (int i = 0; i < shape_0; i++) { - top_blob[i] = bottom_blob[0]; - } - } else { - fprintf(stderr, "error case\n"); - return -100; +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(Expand) + DEFINE_NCNN_OPS(Expand, Expand) + Expand::Expand() + { + one_blob_only = false; + support_inplace = false; } - return 0; - } else if (bottom_blob.dims == 1 && shape_blob.w == 2) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (1, %d) vs (%d, %d)\n", bottom_blob.w, shape_0, - shape_1); - } else if (bottom_blob.w == shape_1 || shape_1 == 1) { - top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.row(j)[i] = bottom_blob[i]; - } - } + int Expand::forward(const std::vector& bottom_blobs, + std::vector& top_blobs, + const Option& opt) const + { + const Mat& bottom_blob = bottom_blobs[0]; + size_t elemsize = bottom_blob.elemsize; + const Mat& old_shape_blob = bottom_blobs[1]; + const int shape_width = old_shape_blob.w - 1; + Mat shape_blob(shape_width, elemsize, 
opt.workspace_allocator); + memcpy(shape_blob.row(0), old_shape_blob.row(0) + 1, shape_width * elemsize); + Mat& top_blob = top_blobs[0]; - } else if (bottom_blob.w == 1) { - top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; + if (bottom_blob.dims == 1 && shape_blob.w == 1) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d) vs (%d)\n", bottom_blob.w, shape_0); + } + else if (bottom_blob.w == shape_0 || shape_0 == 1) + { + top_blob.create(bottom_blob.w, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < shape_1; i++) { - top_blob.row(j)[i] = bottom_blob[0]; - } - } + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob[i] = bottom_blob[i]; + } + } + else if (bottom_blob.w == 1) + { + top_blob.create(shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - } else { - fprintf(stderr, "error case\n"); - return -100; - } - return 0; - } else if (bottom_blob.dims == 1 && shape_blob.w == 3) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - int shape_2 = (int)(shape_blob[2] + 0.5); - - if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (1, 1, %d) vs (%d, %d, %d)\n", bottom_blob.w, - shape_0, shape_1, shape_2); - } else if (bottom_blob.w == shape_2 || shape_2 == 1) { - top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob[i]; - } - } - } - } else if (bottom_blob.w == 1) { - top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob[0]; - } - } - } - } else { - fprintf(stderr, "error case\n"); - return -100; - } - return 0; - } else if (bottom_blob.dims == 2 && shape_blob.w == 2) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1); - } else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1); - } else if ((bottom_blob.w == shape_1 || shape_1 == 1) && - (bottom_blob.h == shape_0 || shape_0 == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.row(j)[i] = bottom_blob.row(j)[i]; - } - } - } else if ((bottom_blob.w == shape_1 || shape_1 == 1) && (bottom_blob.h == 1)) { - top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.row(j)[i] = bottom_blob.row(0)[i]; - } - } - } else if ((bottom_blob.w == 1) && 
(bottom_blob.h == shape_0 || shape_0 == 1)) { - top_blob.create(shape_1, bottom_blob.h, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_1; i++) { - top_blob.row(j)[i] = bottom_blob.row(j)[0]; + for (int i = 0; i < shape_0; i++) + { + top_blob[i] = bottom_blob[0]; + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.h == 1 && bottom_blob.w == 1) { - top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < shape_1; i++) { - top_blob.row(j)[i] = bottom_blob.row(0)[0]; - } - } - } else { - fprintf(stderr, "error case\n"); - return -100; - } - return 0; - } else if (bottom_blob.dims == 2 && shape_blob.w == 3) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - int shape_2 = (int)(shape_blob[2] + 0.5); - if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1, shape_2); - } else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1, shape_2); - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && - (bottom_blob.h == shape_1 || shape_1 == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[i]; - } - } - } - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1)) { - top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[i]; - } - } - } - - } else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_1 || shape_1 == 1)) { - top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[0]; - } - } - } + else if (bottom_blob.dims == 1 && shape_blob.w == 2) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (1, %d) vs (%d, %d)\n", bottom_blob.w, shape_0, shape_1); + } + else if (bottom_blob.w == shape_1 || shape_1 == 1) + { + top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - } else if (bottom_blob.h == 1 && bottom_blob.w == 1) { - top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[0]; - } - } - } - } else { - fprintf(stderr, "error case\n"); - return -100; 
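// [Editorial aside -- illustrative sketch, not part of the original diff.]
// Every branch of Expand::forward repeats the same ONNX-style broadcast
// check: two extents are compatible exactly when they are equal or one of
// them is 1. A minimal standalone rendering of that rule (the helper name
// and the main() test are assumptions for illustration, not mmdeploy API):
//
//     #include <cassert>
//
//     // True iff extents `a` and `b` can be broadcast together: the output
//     // extent is then max(a, b) and the size-1 side is repeated.
//     static inline bool broadcast_compatible(int a, int b)
//     {
//         return a == b || a == 1 || b == 1;
//     }
//
//     int main()
//     {
//         assert(broadcast_compatible(4, 4));   // equal extents copy through
//         assert(broadcast_compatible(1, 7));   // the size-1 side is tiled
//         assert(!broadcast_compatible(3, 7));  // the "error case" branches
//         return 0;
//     }
//
// The `(int)(shape_blob[i] + 0.5)` conversions exist because ncnn carries
// the target shape as floats; for the non-negative extents involved, adding
// 0.5 before truncation rounds to the nearest integer.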
- } - return 0; - } else if (bottom_blob.dims == 3 && shape_blob.w == 3) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - int shape_2 = (int)(shape_blob[2] + 0.5); - if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, - bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); - } else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, - bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); - } else if (bottom_blob.c != shape_0 && bottom_blob.c != 1 && shape_0 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, - bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && - (bottom_blob.h == shape_1 || shape_1 == 1) && - (bottom_blob.c == shape_0 || shape_0 == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i]; - } - } - } - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && - (bottom_blob.h == shape_1 || shape_1 == 1) && (bottom_blob.c == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[i]; - } - } - } - - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && - (bottom_blob.c == shape_0 || shape_0 == 1)) { - top_blob.create(bottom_blob.w, shape_1, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i]; - } - } - } + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.row(j)[i] = bottom_blob[i]; + } + } + } + else if (bottom_blob.w == 1) + { + top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && - (bottom_blob.c == 1)) { - top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[i]; - } + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < shape_1; i++) + { + top_blob.row(j)[i] = bottom_blob[0]; + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } + else if (bottom_blob.dims == 1 && shape_blob.w == 3) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + int shape_2 = (int)(shape_blob[2] + 0.5); - } else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && - (bottom_blob.c == shape_0 || shape_0 == 
1)) { - top_blob.create(shape_2, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0]; - } + if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (1, 1, %d) vs (%d, %d, %d)\n", bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.w == shape_2 || shape_2 == 1) + { + top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob[i]; + } + } + } + } + else if (bottom_blob.w == 1) + { + top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob[0]; + } + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && - (bottom_blob.c == 1)) { - top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[0]; - } + else if (bottom_blob.dims == 2 && shape_blob.w == 2) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1); + } + else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1); + } + else if ((bottom_blob.w == shape_1 || shape_1 == 1) && + (bottom_blob.h == shape_0 || shape_0 == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.row(j)[i] = bottom_blob.row(j)[i]; + } + } + } + else if ((bottom_blob.w == shape_1 || shape_1 == 1) && (bottom_blob.h == 1)) + { + top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.row(j)[i] = bottom_blob.row(0)[i]; + } + } + } + else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_0 || shape_0 == 1)) + { + top_blob.create(shape_1, bottom_blob.h, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_1; i++) + { + top_blob.row(j)[i] = bottom_blob.row(j)[0]; + } + } + } + else if (bottom_blob.h == 1 && bottom_blob.w == 1) + { + top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < 
shape_1; i++) + { + top_blob.row(j)[i] = bottom_blob.row(0)[0]; + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.w == 1 && bottom_blob.h == 1 && - (bottom_blob.c == shape_0 || shape_0 == 1)) { - top_blob.create(shape_2, shape_1, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0]; - } + else if (bottom_blob.dims == 2 && shape_blob.w == 3) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + int shape_2 = (int)(shape_blob[2] + 0.5); + if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && + (bottom_blob.h == shape_1 || shape_1 == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1)) + { + top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[i]; + } + } + } + } + else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_1 || shape_1 == 1)) + { + top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[0]; + } + } + } + } + else if (bottom_blob.h == 1 && bottom_blob.w == 1) + { + top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[0]; + } + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.w == 1 && bottom_blob.h == 1 && bottom_blob.c == 1) { - top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[0]; - } + else if (bottom_blob.dims == 3 && shape_blob.w == 3) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + int shape_2 = (int)(shape_blob[2] + 0.5); + if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) + { + fprintf(stderr, "The 
broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.c != shape_0 && bottom_blob.c != 1 && shape_0 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && + (bottom_blob.h == shape_1 || shape_1 == 1) && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && + (bottom_blob.h == shape_1 || shape_1 == 1) && (bottom_blob.c == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(bottom_blob.w, shape_1, bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && + (bottom_blob.c == 1)) + { + top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[i]; + } + } + } + } + else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(shape_2, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0]; + } + } + } + } + else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && + (bottom_blob.c == 1)) + { + top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[0]; + } + } + } + } + else if (bottom_blob.w == 1 && bottom_blob.h == 1 && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(shape_2, shape_1, 
bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0]; + } + } + } + } + else if (bottom_blob.w == 1 && bottom_blob.h == 1 && bottom_blob.c == 1) + { + top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[0]; + } + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else { - fprintf(stderr, "error case\n"); - return -100; + fprintf(stderr, "Layer: Expand, bottom_blob.dims: %d, shape_blob.w: %d\n", bottom_blob.dims, shape_blob.w); + return -1; } - return 0; - } - fprintf(stderr, "Layer: Expand, bottom_blob.dims: %d, shape_blob.w: %d\n", bottom_blob.dims, - shape_blob.w); - return -1; -} } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h old mode 100755 new mode 100644 index 3dca54fb0f..5b280100a4 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h @@ -4,15 +4,18 @@ #include "layer.h" -namespace mmdeploy { +namespace mmdeploy +{ -class Expand : public ncnn::Layer { - public: - Expand(); + class Expand : public ncnn::Layer + { + public: + Expand(); - virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const ncnn::Option& opt) const; -}; + virtual int forward(const std::vector& bottom_blobs, + std::vector& top_blobs, + const ncnn::Option& opt) const; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp index 4b6bd34630..15950bdbfa 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp @@ -4,157 +4,183 @@ #include "../ncnn_ops_definer.h" #include "assert.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(Gather) -DEFINE_NCNN_OPS(Gather, Gather) -Gather::Gather() { - one_blob_only = false; - support_inplace = false; -} - -int Gather::load_param(const ParamDict &pd) { - axis = pd.get(0, 0); - - return 0; -} - -// Gather only support 1-dim of indices, because the data and indices all has -// implicit batch in ncnn, this will lead to wrong shape to match onnx result. -// When indices dim equals to 1, after eliminating implicit batch, the indices -// dim still be 1. So there is only 1 implicit batch in data, this will make -// the shape match onnx result. -int Gather::forward(const std::vector &bottom_blobs, std::vector &top_blobs, - const Option &opt) const { - const Mat &bottom_blob = bottom_blobs[0]; - const Mat &indices = bottom_blobs[1]; - int dims = bottom_blob.dims; - int indices_dims = indices.dims; - size_t elemsize = bottom_blob.elemsize; - int positive_axis = axis < 0 ? 
dims + axis : axis; - Mat &top_blob = top_blobs[0]; - assert(indices.dims == 1); - const float *indices_ptr = indices; - - if (dims == 1 && indices_dims == 1) // positive_axis == 0 - { - int w = indices.w; - top_blob.create(w, elemsize, opt.blob_allocator); - if (top_blob.empty()) { - return -100; - } - const float *ptr = bottom_blob; - float *outptr = top_blob; - for (int i = 0; i < w; i++) { - float indice = indices_ptr[i]; - outptr[i] = ptr[(int)(indice + 0.5)]; +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(Gather) + DEFINE_NCNN_OPS(Gather, Gather) + + Gather::Gather() + { + one_blob_only = false; + support_inplace = false; } - return 0; - } - - if (dims == 2 && positive_axis == 0 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - top_blob.create(w, indices.w, elemsize, opt.blob_allocator); - // w -> w - // h -> indices.w - // h * w -> indices.w * w - if (top_blob.empty()) { - return -100; - } - const float *ptr = bottom_blob; - float *outptr = top_blob; - for (int i = 0; i < indices.w; i++) { - const int selected = (int)(indices_ptr[i] + 0.5); - memcpy(top_blob.row(i), bottom_blob.row(selected), w * elemsize); - } + int Gather::load_param(const ParamDict& pd) + { + axis = pd.get(0, 0); - return 0; - } - - if (dims == 2 && positive_axis == 1 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - top_blob.create(indices.w, h, elemsize, opt.blob_allocator); - // w -> h - // h -> indices.w - // h * w -> indices.w * h - if (top_blob.empty()) { - return -100; - } - const float *ptr = bottom_blob; - float *outptr = top_blob; - for (int j = 0; j < h; j++) { - for (int i = 0; i < indices.w; i++) { - int selected = (int)(indices_ptr[i] + 0.5); - outptr[j * indices.w + i] = ptr[j * w + selected]; - } + return 0; } - return 0; - } - if (dims == 3 && positive_axis == 0 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - int channels = bottom_blob.c; - top_blob.create(w, h, indices.w, elemsize, opt.blob_allocator); + // Gather only support 1-dim of indices, because the data and indices all has + // implicit batch in ncnn, this will lead to wrong shape to match onnx result. + // When indices dim equals to 1, after eliminating implicit batch, the indices + // dim still be 1. So there is only 1 implicit batch in data, this will make + // the shape match onnx result. + int Gather::forward(const std::vector& bottom_blobs, + std::vector& top_blobs, + const Option& opt) const + { + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& indices = bottom_blobs[1]; + int dims = bottom_blob.dims; + int indices_dims = indices.dims; + size_t elemsize = bottom_blob.elemsize; + int positive_axis = axis < 0 ? 
dims + axis : axis; + Mat& top_blob = top_blobs[0]; + assert(indices.dims == 1); + const float* indices_ptr = indices; + + if (dims == 1 && indices_dims == 1) // positive_axis == 0 + { + int w = indices.w; + top_blob.create(w, elemsize, opt.blob_allocator); + if (top_blob.empty()) + { + return -100; + } + const float* ptr = bottom_blob; + float* outptr = top_blob; + for (int i = 0; i < w; i++) + { + float indice = indices_ptr[i]; + outptr[i] = ptr[(int)(indice + 0.5)]; + } + + return 0; + } - if (top_blob.empty()) { - return -100; - } - for (int i = 0; i < indices.w; i++) { - int selected = (int)(indices_ptr[i] + 0.5); - const unsigned char *ptr = bottom_blob.channel(selected); - unsigned char *outptr = top_blob.channel(i); + if (dims == 2 && positive_axis == 0 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + top_blob.create(w, indices.w, elemsize, opt.blob_allocator); + // w -> w + // h -> indices.w + // h * w -> indices.w * w + if (top_blob.empty()) + { + return -100; + } + const float* ptr = bottom_blob; + float* outptr = top_blob; + for (int i = 0; i < indices.w; i++) + { + const int selected = (int)(indices_ptr[i] + 0.5); + memcpy(top_blob.row(i), bottom_blob.row(selected), w * elemsize); + } + + return 0; + } - memcpy(outptr, ptr, w * h * elemsize); - } - return 0; - } - - if (dims == 3 && positive_axis == 1 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - int channels = bottom_blob.c; - top_blob.create(w, indices.w, channels, elemsize, opt.blob_allocator); -#pragma omp parallel for num_threads(opt.num_threads) - // use parallel programming - for (int i = 0; i < channels; i++) { - float *outptr = top_blob.channel(i); - const float *ptr = bottom_blob.channel(i); - for (int j = 0; j < indices.w; j++) { - int selected = (int)(indices_ptr[j] + 0.5); - for (int k = 0; k < w; k++) { - outptr[j * w + k] = ptr[selected * w + k]; + if (dims == 2 && positive_axis == 1 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + top_blob.create(indices.w, h, elemsize, opt.blob_allocator); + // w -> h + // h -> indices.w + // h * w -> indices.w * h + if (top_blob.empty()) + { + return -100; + } + const float* ptr = bottom_blob; + float* outptr = top_blob; + for (int j = 0; j < h; j++) + { + for (int i = 0; i < indices.w; i++) + { + int selected = (int)(indices_ptr[i] + 0.5); + outptr[j * indices.w + i] = ptr[j * w + selected]; + } + } + return 0; } - } - } - return 0; - } + if (dims == 3 && positive_axis == 0 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + top_blob.create(w, h, indices.w, elemsize, opt.blob_allocator); + + if (top_blob.empty()) + { + return -100; + } + for (int i = 0; i < indices.w; i++) + { + int selected = (int)(indices_ptr[i] + 0.5); + const unsigned char* ptr = bottom_blob.channel(selected); + unsigned char* outptr = top_blob.channel(i); + + memcpy(outptr, ptr, w * h * elemsize); + } + return 0; + } - if (dims == 3 && positive_axis == 2 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - int channels = bottom_blob.c; - top_blob.create(indices.w, h, channels, elemsize, opt.blob_allocator); + if (dims == 3 && positive_axis == 1 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + top_blob.create(w, indices.w, channels, elemsize, opt.blob_allocator); #pragma omp parallel for num_threads(opt.num_threads) - // use parallel programming - for (int i = 0; i < channels; 
i++) { - float *outptr = top_blob.channel(i); - const float *ptr = bottom_blob.channel(i); - for (int j = 0; j < h; j++) { - for (int k = 0; k < indices.w; k++) { - int selected = (int)(indices_ptr[k] + 0.5); - outptr[j * indices.w + k] = ptr[j * w + selected]; + // use parallel programming + for (int i = 0; i < channels; i++) + { + float* outptr = top_blob.channel(i); + const float* ptr = bottom_blob.channel(i); + for (int j = 0; j < indices.w; j++) + { + int selected = (int)(indices_ptr[j] + 0.5); + for (int k = 0; k < w; k++) + { + outptr[j * w + k] = ptr[selected * w + k]; + } + } + } + + return 0; } - } - } - return 0; - } - return 0; -} + if (dims == 3 && positive_axis == 2 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + top_blob.create(indices.w, h, channels, elemsize, opt.blob_allocator); +#pragma omp parallel for num_threads(opt.num_threads) + // use parallel programming + for (int i = 0; i < channels; i++) + { + float* outptr = top_blob.channel(i); + const float* ptr = bottom_blob.channel(i); + for (int j = 0; j < h; j++) + { + for (int k = 0; k < indices.w; k++) + { + int selected = (int)(indices_ptr[k] + 0.5); + outptr[j * indices.w + k] = ptr[j * w + selected]; + } + } + } + return 0; + } + + return 0; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h old mode 100755 new mode 100644 index af6eb6365e..e7bfb717c8 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h @@ -4,20 +4,23 @@ #include "layer.h" -namespace mmdeploy { +namespace mmdeploy +{ -class Gather : public ncnn::Layer { - public: - Gather(); + class Gather : public ncnn::Layer + { + public: + Gather(); - virtual int load_param(const ncnn::ParamDict& pd); + virtual int load_param(const ncnn::ParamDict& pd); - virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const ncnn::Option& opt) const; + virtual int forward(const std::vector& bottom_blobs, + std::vector& top_blobs, + const ncnn::Option& opt) const; - public: - int axis; -}; + public: + int axis; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h old mode 100755 new mode 100644 index 509c8c0ce0..bd5d9ca23e --- a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h @@ -7,22 +7,24 @@ #include "layer.h" #include "ncnn_ops_register.h" -namespace mmdeploy { - -class NCNNOpsDefiner { - public: - NCNNOpsDefiner(const std::string& ops_name, const ncnn::layer_creator_func& creator_func = 0, - const ncnn::layer_destroyer_func& destroyer_func = 0) - : _ops_name(ops_name) { - get_mmdeploy_layer_creator()[_ops_name.c_str()] = creator_func; - } - - private: - const std::string _ops_name; -}; +namespace mmdeploy +{ + + class NCNNOpsDefiner + { + public: + NCNNOpsDefiner(const std::string& ops_name, const ncnn::layer_creator_func& creator_func = 0, const ncnn::layer_destroyer_func& destroyer_func = 0) + : _ops_name(ops_name) + { + get_mmdeploy_layer_creator()[_ops_name.c_str()] = creator_func; + } + + private: + const std::string _ops_name; + }; #define DEFINE_NCNN_OPS(ops_name, OpsLayer) \ - static mmdeploy::NCNNOpsDefiner NCNNOpsDefiner##ops_name{#ops_name, OpsLayer##_layer_creator}; + static mmdeploy::NCNNOpsDefiner NCNNOpsDefiner##ops_name{#ops_name, 
OpsLayer##_layer_creator};
 
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp
old mode 100755
new mode 100644
index 42bc050a1c..85d4f66d04
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp
@@ -3,32 +3,38 @@
 #include <map>
 
-std::map<const char*, ncnn::layer_creator_func> &get_mmdeploy_layer_creator() {
-  static std::map<const char*, ncnn::layer_creator_func> _layer_creator_map;
-  return _layer_creator_map;
+std::map<const char*, ncnn::layer_creator_func>& get_mmdeploy_layer_creator()
+{
+    static std::map<const char*, ncnn::layer_creator_func> _layer_creator_map;
+    return _layer_creator_map;
 }
 
-std::map<const char*, ncnn::layer_destroyer_func> &get_mmdeploy_layer_destroyer() {
-  static std::map<const char*, ncnn::layer_destroyer_func> _layer_destroyer_map;
-  return _layer_destroyer_map;
+std::map<const char*, ncnn::layer_destroyer_func>& get_mmdeploy_layer_destroyer()
+{
+    static std::map<const char*, ncnn::layer_destroyer_func> _layer_destroyer_map;
+    return _layer_destroyer_map;
 }
 
-int register_mmdeploy_custom_layers(ncnn::Net &net) {
-  auto &layer_creator_map = get_mmdeploy_layer_creator();
-  auto &layer_destroyer_map = get_mmdeploy_layer_destroyer();
+int register_mmdeploy_custom_layers(ncnn::Net& net)
+{
+    auto& layer_creator_map = get_mmdeploy_layer_creator();
+    auto& layer_destroyer_map = get_mmdeploy_layer_destroyer();
 
-  for (auto const &creator_pair : layer_creator_map) {
-    auto creator_name = creator_pair.first;
-    auto creator_func = creator_pair.second;
+    for (auto const& creator_pair : layer_creator_map)
+    {
+        auto creator_name = creator_pair.first;
+        auto creator_func = creator_pair.second;
 
-    ncnn::layer_destroyer_func destroyer_func = 0;
-    if (layer_destroyer_map.find(creator_name) != layer_destroyer_map.end()) {
-      destroyer_func = layer_destroyer_map[creator_name];
+        ncnn::layer_destroyer_func destroyer_func = 0;
+        if (layer_destroyer_map.find(creator_name) != layer_destroyer_map.end())
+        {
+            destroyer_func = layer_destroyer_map[creator_name];
+        }
+        int ret = net.register_custom_layer(creator_name, creator_func, destroyer_func);
+        if (0 != ret)
+        {
+            return ret;
+        }
     }
-    int ret = net.register_custom_layer(creator_name, creator_func, destroyer_func);
-    if (0 != ret) {
-      return ret;
-    }
-  }
-  return 0;
+    return 0;
 }
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h
old mode 100755
new mode 100644
index 0d9974f783..32c918156c
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h
@@ -9,8 +9,9 @@
 #include "net.h"
 
 MMDEPLOY_API std::map<const char*, ncnn::layer_creator_func>& get_mmdeploy_layer_creator();
+
 MMDEPLOY_API std::map<const char*, ncnn::layer_destroyer_func>& get_mmdeploy_layer_destroyer();
 
-MMDEPLOY_API int register_mmdeploy_custom_layers(ncnn::Net& net);
+MMDEPLOY_API int                                                register_mmdeploy_custom_layers(ncnn::Net& net);
 
 #endif
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp
old mode 100755
new mode 100644
index f538eabbac..cce2935ba1
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp
@@ -3,45 +3,59 @@
 #include "../ncnn_ops_definer.h"
 
-namespace mmdeploy {
-using namespace ncnn;
-DEFINE_LAYER_CREATOR(Shape)
-DEFINE_NCNN_OPS(Shape, Shape)
-Shape::Shape() {
-  one_blob_only = true;
-  support_inplace = false;
-}
+namespace mmdeploy
+{
+    using namespace ncnn;
+    DEFINE_LAYER_CREATOR(Shape)
+    DEFINE_NCNN_OPS(Shape, Shape)
 
-int Shape::forward(const Mat &bottom_blob, Mat &top_blob, const Option &opt) const {
-  int dims = bottom_blob.dims;
-  int w = bottom_blob.w;
-  size_t elemsize = sizeof(float);
-  top_blob.create(dims + 1, elemsize,
opt.blob_allocator); - if (top_blob.empty()) { - return -100; - } - float *outptr = top_blob; + Shape::Shape() + { + one_blob_only = true; + support_inplace = false; + } - if (dims == 1) { - outptr[0] = 1.0f; - outptr[1] = w; - } else if (dims == 2) { - int h = bottom_blob.h; - outptr[0] = 1.0f; - outptr[1] = h; - outptr[2] = w; - } else if (dims == 3) { - int h = bottom_blob.h; - int channels = bottom_blob.c; - outptr[0] = 1.0f; - outptr[1] = channels; - outptr[2] = h; - outptr[3] = w; - } else { - fprintf(stdout, "Unsupported dims=%d\n", dims); - } + int Shape::forward(const Mat& bottom_blob, + Mat& top_blob, + const Option& opt) const + { + int dims = bottom_blob.dims; + int w = bottom_blob.w; + size_t elemsize = sizeof(float); + top_blob.create(dims + 1, elemsize, opt.blob_allocator); + if (top_blob.empty()) + { + return -100; + } + float* outptr = top_blob; - return 0; -} + if (dims == 1) + { + outptr[0] = 1.0f; + outptr[1] = w; + } + else if (dims == 2) + { + int h = bottom_blob.h; + outptr[0] = 1.0f; + outptr[1] = h; + outptr[2] = w; + } + else if (dims == 3) + { + int h = bottom_blob.h; + int channels = bottom_blob.c; + outptr[0] = 1.0f; + outptr[1] = channels; + outptr[2] = h; + outptr[3] = w; + } + else + { + fprintf(stdout, "Unsupported dims=%d\n", dims); + } + + return 0; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h old mode 100755 new mode 100644 index 863dc77c1d..2c1e4573bf --- a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h @@ -4,15 +4,18 @@ #include "layer.h" -namespace mmdeploy { +namespace mmdeploy +{ -class Shape : public ncnn::Layer { - public: - Shape(); + class Shape : public ncnn::Layer + { + public: + Shape(); - virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, - const ncnn::Option& opt) const; -}; + virtual int forward(const ncnn::Mat& bottom_blob, + ncnn::Mat& top_blob, + const ncnn::Option& opt) const; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp index 9f2ced1992..8b1e35ae66 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp @@ -5,202 +5,253 @@ #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(TensorSlice) -DEFINE_NCNN_OPS(TensorSlice, TensorSlice) -TensorSlice::TensorSlice() { - one_blob_only = true; - support_inplace = false; -} +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(TensorSlice) + DEFINE_NCNN_OPS(TensorSlice, TensorSlice) -int TensorSlice::load_param(const ParamDict& pd) { - starts = pd.get(0, Mat()); - ends = pd.get(1, Mat()); - axes = pd.get(2, Mat()); - steps = pd.get(3, Mat()); - if (axes.w == 0) { - axes.create(starts.w); - int* axes_ptr = axes; - for (int i = 0; i < starts.w; i++) { - axes_ptr[i] = i; + TensorSlice::TensorSlice() + { + one_blob_only = true; + support_inplace = false; } - } - if (steps.w == 0) { - steps.create(axes.w); - steps.fill(1); - } - return 0; -} -static inline int get_shape_by_axes(const Mat& blob, int axes, int dims) { - switch (dims - axes) { - case 0: - return blob.w; - case 1: - return blob.h; - case 2: - return blob.c; - default: - fprintf(stderr, "wrong axes %d!\n", axes); - return -1; - } - return 0; -} - -int TensorSlice::forward(const Mat& bottom_blob, 
Mat& top_blob, const Option& opt) const {
-  int dims = bottom_blob.dims;
-  size_t elemsize = bottom_blob.elemsize;
-  const int* start_ptr = starts;
-  const int* end_ptr = ends;
-  const int* axes_ptr = axes;
-  const int* step_ptr = steps;
-  if (starts.w > dims || ends.w > dims) {
-    fprintf(stderr, "start/end attributes shape error!\n");
-    return -100;
-  }
-  if (axes.w != 1) {
-    fprintf(stderr,
-            "axes.w must be 1 because any of multiaxes slice is regarded as "
-            "multi-staged onnx slice in pytorch2onnx.");
-  }
-  if (dims == 1) {
-    for (int i = 0; i < axes.w; i++) {
-      int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i];
-      int step = step_ptr[i];
-      std::vector<float> temp_val;
-      int start = start_ptr[i];
-      int end = end_ptr[i];
-      int cur = start;
-      if (step > 0) {
-        while (cur < end && cur < bottom_blob.w) {
-          temp_val.push_back(bottom_blob[cur]);
-          cur += step;
+    int TensorSlice::load_param(const ParamDict& pd)
+    {
+        starts = pd.get(0, Mat());
+        ends = pd.get(1, Mat());
+        axes = pd.get(2, Mat());
+        steps = pd.get(3, Mat());
+        if (axes.w == 0)
+        {
+            axes.create(starts.w);
+            int* axes_ptr = axes;
+            for (int i = 0; i < starts.w; i++)
+            {
+                axes_ptr[i] = i;
+            }
         }
-      } else if (step < 0) {
-        while (cur > end && cur > 0) {
-          temp_val.push_back(bottom_blob[cur]);
-          cur += step;
+        if (steps.w == 0)
+        {
+            steps.create(axes.w);
+            steps.fill(1);
         }
-      } else {
-        fprintf(stderr, "step should not be 0!\n");
-        return -100;
-      }
-      top_blob.create(temp_val.size(), elemsize, opt.blob_allocator);
-      for (int i = 0; i < temp_val.size(); i++) {
-        top_blob[i] = temp_val[i];
-      }
-    }
-    return 0;
-  }
-  if (dims == 2) {
-    std::vector<std::vector<int> > active_indice;
-    std::vector<int> indices;
-    for (int i = 0; i < bottom_blob.h; i++) {
-      indices.push_back(i);
-    }
-    active_indice.push_back(indices);
-    indices.clear();
-    for (int i = 0; i < bottom_blob.w; i++) {
-      indices.push_back(i);
+        return 0;
     }
-    active_indice.push_back(indices);
-    for (int i = 0; i < axes.w; i++) {
-      int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i];
-      int step = step_ptr[i];
-      int start = start_ptr[i];
-      int end = end_ptr[i];
-      int dim_shape = get_shape_by_axes(bottom_blob, positive_axis, dims);
-      int dim_shape_test = get_shape_by_axes(bottom_blob, positive_axis, dims - 1);
-      if (dim_shape < 0) {
-        return -1;
-      }
-      end = end < dim_shape ?
end : dim_shape; - int cur = start; - std::vector temp_indice; - if (step > 0) { - while (cur < end && cur < dim_shape) { - temp_indice.push_back(cur); - cur += step; - } - } else if (step < 0) { - while (cur > end && cur > 0) { - temp_indice.push_back(cur); - cur += step; - } - } else { - fprintf(stderr, "step should not be 0!\n"); - return -100; - } - active_indice[positive_axis - 1] = temp_indice; - active_indice[positive_axis - 1].resize(temp_indice.size()); - } - top_blob.create((int)active_indice[1].size(), (int)active_indice[0].size(), elemsize, - opt.blob_allocator); - for (int i = 0; i < active_indice[0].size(); i++) { - for (int j = 0; j < active_indice[1].size(); j++) { - top_blob.row(i)[j] = bottom_blob.row(active_indice[0][i])[active_indice[1][j]]; - } - } - return 0; - } - if (dims == 3) { - std::vector > active_indice; - std::vector indices; - for (int i = 0; i < bottom_blob.c; i++) { - indices.push_back(i); - } - active_indice.push_back(indices); - indices.clear(); - for (int i = 0; i < bottom_blob.h; i++) { - indices.push_back(i); - } - active_indice.push_back(indices); - indices.clear(); - for (int i = 0; i < bottom_blob.w; i++) { - indices.push_back(i); + static inline int get_shape_by_axes(const Mat& blob, int axes, int dims) + { + switch (dims - axes) + { + case 0: + return blob.w; + case 1: + return blob.h; + case 2: + return blob.c; + default: + fprintf(stderr, "wrong axes %d!\n", axes); + return -1; + } + return 0; } - active_indice.push_back(indices); - for (int i = 0; i < axes.w; i++) { - int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i]; - int step = step_ptr[i]; - int start = start_ptr[i]; - int end = end_ptr[i]; - int cur = start; - std::vector temp_indice; - if (step > 0) { - while (cur < end && cur < bottom_blob.w) { - temp_indice.push_back(cur); - cur += step; + int TensorSlice::forward(const Mat& bottom_blob, + Mat& top_blob, + const Option& opt) const + { + int dims = bottom_blob.dims; + size_t elemsize = bottom_blob.elemsize; + const int* start_ptr = starts; + const int* end_ptr = ends; + const int* axes_ptr = axes; + const int* step_ptr = steps; + if (starts.w > dims || ends.w > dims) + { + fprintf(stderr, "start/end attributes shape error!\n"); + return -100; } - } else if (step < 0) { - while (cur > end && cur > 0) { - temp_indice.push_back(cur); - cur += step; + if (axes.w != 1) + { + fprintf(stderr, + "axes.w must be 1 because any of multiaxes slice is regarded as " + "multi-staged onnx slice in pytorch2onnx."); } - } else { - fprintf(stderr, "step should not be 0!\n"); - return -100; - } - active_indice[positive_axis - 1] = temp_indice; - active_indice[positive_axis - 1].resize(temp_indice.size()); - } - top_blob.create((int)active_indice[2].size(), (int)active_indice[1].size(), - (int)active_indice[0].size(), elemsize, opt.blob_allocator); - for (int i = 0; i < active_indice[0].size(); i++) { - for (int j = 0; j < active_indice[1].size(); j++) { - for (int k = 0; k < active_indice[2].size(); k++) { - top_blob.channel(i).row(j)[k] = bottom_blob.channel(active_indice[0][i]) - .row(active_indice[1][j])[active_indice[2][k]]; + if (dims == 1) + { + for (int i = 0; i < axes.w; i++) + { + int positive_axis = axes_ptr[i] < 0 ? 
dims + axes_ptr[i] : axes_ptr[i]; + int step = step_ptr[i]; + std::vector temp_val; + int start = start_ptr[i]; + int end = end_ptr[i]; + int cur = start; + if (step > 0) + { + while (cur < end && cur < bottom_blob.w) + { + temp_val.push_back(bottom_blob[cur]); + cur += step; + } + } + else if (step < 0) + { + while (cur > end && cur > 0) + { + temp_val.push_back(bottom_blob[cur]); + cur += step; + } + } + else + { + fprintf(stderr, "step should not be 0!\n"); + return -100; + } + top_blob.create(temp_val.size(), elemsize, opt.blob_allocator); + for (int i = 0; i < temp_val.size(); i++) + { + top_blob[i] = temp_val[i]; + } + } + return 0; + } + if (dims == 2) + { + std::vector> active_indice; + std::vector indices; + for (int i = 0; i < bottom_blob.h; i++) + { + indices.push_back(i); + } + active_indice.push_back(indices); + indices.clear(); + for (int i = 0; i < bottom_blob.w; i++) + { + indices.push_back(i); + } + active_indice.push_back(indices); + for (int i = 0; i < axes.w; i++) + { + int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i]; + int step = step_ptr[i]; + int start = start_ptr[i]; + int end = end_ptr[i]; + int dim_shape = get_shape_by_axes(bottom_blob, positive_axis, dims); + int dim_shape_test = get_shape_by_axes(bottom_blob, positive_axis, dims - 1); + if (dim_shape < 0) + { + return -1; + } + end = end < dim_shape ? end : dim_shape; + int cur = start; + std::vector temp_indice; + if (step > 0) + { + while (cur < end && cur < dim_shape) + { + temp_indice.push_back(cur); + cur += step; + } + } + else if (step < 0) + { + while (cur > end && cur > 0) + { + temp_indice.push_back(cur); + cur += step; + } + } + else + { + fprintf(stderr, "step should not be 0!\n"); + return -100; + } + active_indice[positive_axis - 1] = temp_indice; + active_indice[positive_axis - 1].resize(temp_indice.size()); + } + top_blob.create((int)active_indice[1].size(), (int)active_indice[0].size(), elemsize, opt.blob_allocator); + for (int i = 0; i < active_indice[0].size(); i++) + { + for (int j = 0; j < active_indice[1].size(); j++) + { + top_blob.row(i)[j] = bottom_blob.row(active_indice[0][i])[active_indice[1][j]]; + } + } + return 0; } - } - } - return 0; - } - return 0; -} + if (dims == 3) + { + std::vector> active_indice; + std::vector indices; + for (int i = 0; i < bottom_blob.c; i++) + { + indices.push_back(i); + } + active_indice.push_back(indices); + indices.clear(); + for (int i = 0; i < bottom_blob.h; i++) + { + indices.push_back(i); + } + active_indice.push_back(indices); + indices.clear(); + for (int i = 0; i < bottom_blob.w; i++) + { + indices.push_back(i); + } + active_indice.push_back(indices); + for (int i = 0; i < axes.w; i++) + { + int positive_axis = axes_ptr[i] < 0 ? 
dims + axes_ptr[i] : axes_ptr[i]; + int step = step_ptr[i]; + + int start = start_ptr[i]; + int end = end_ptr[i]; + int cur = start; + std::vector temp_indice; + if (step > 0) + { + while (cur < end && cur < bottom_blob.w) + { + temp_indice.push_back(cur); + cur += step; + } + } + else if (step < 0) + { + while (cur > end && cur > 0) + { + temp_indice.push_back(cur); + cur += step; + } + } + else + { + fprintf(stderr, "step should not be 0!\n"); + return -100; + } + active_indice[positive_axis - 1] = temp_indice; + active_indice[positive_axis - 1].resize(temp_indice.size()); + } + top_blob.create((int)active_indice[2].size(), (int)active_indice[1].size(), (int)active_indice[0].size(), elemsize, opt.blob_allocator); + for (int i = 0; i < active_indice[0].size(); i++) + { + for (int j = 0; j < active_indice[1].size(); j++) + { + for (int k = 0; k < active_indice[2].size(); k++) + { + top_blob.channel(i).row(j)[k] = bottom_blob.channel(active_indice[0][i]) + .row(active_indice[1][j])[active_indice[2][k]]; + } + } + } + return 0; + } + + return 0; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h old mode 100755 new mode 100644 index 9164d43335..fbffdcb843 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h @@ -4,23 +4,26 @@ #include "layer.h" -namespace mmdeploy { - -class TensorSlice : public ncnn::Layer { - public: - TensorSlice(); - - virtual int load_param(const ncnn::ParamDict& pd); - - virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, - const ncnn::Option& opt) const; - - public: - ncnn::Mat starts; - ncnn::Mat ends; - ncnn::Mat axes; - ncnn::Mat steps; -}; +namespace mmdeploy +{ + + class TensorSlice : public ncnn::Layer + { + public: + TensorSlice(); + + virtual int load_param(const ncnn::ParamDict& pd); + + virtual int forward(const ncnn::Mat& bottom_blob, + ncnn::Mat& top_blob, + const ncnn::Option& opt) const; + + public: + ncnn::Mat starts; + ncnn::Mat ends; + ncnn::Mat axes; + ncnn::Mat steps; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp index f618831568..cfa55d1f8e 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp @@ -6,872 +6,1122 @@ #include #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(TopK) -DEFINE_NCNN_OPS(TopK, TopK) - -TopK::TopK() { - one_blob_only = false; - support_inplace = false; -} -int TopK::load_param(const ParamDict& pd) { - axis = pd.get(0, -1); - largest = pd.get(1, 1); - sorted = pd.get(2, 1); - keep_dims = pd.get(3, 1); - - return 0; -} -int TopK::forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const Option& opt) const { - int dims = bottom_blobs[0].dims; - int positive_axis = axis < 0 ? dims + axis : axis; - int topk; - if (bottom_blobs.size() == 2) { - const Mat& topk_blob = bottom_blobs[1]; - topk = (int)(topk_blob[0] + 0.5); - } else if (bottom_blobs.size() == 1) { - topk = 1; - } else { - fprintf(stderr, "topk input blobs should be 1 or 2, but not %ld\n", bottom_blobs.size()); - return -103; - } - - // To do: Cut the top_val_blob after unit test. And we should change them in - // param files. - // Adaptive outputs. For onnx TopK, we output 2 blobs, for ArgMax, we output - // 1 blob. 
- Mat& top_val_blob = top_blobs[0]; - Mat& top_ind_blob = top_blobs.size() == 2 ? top_blobs[1] : top_val_blob; - - if (topk > 1) { - // real topk - if (keep_dims == 0) { - fprintf(stderr, "real topk should not reduce dims!\n"); - return -102; +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(TopK) + DEFINE_NCNN_OPS(TopK, TopK) + + TopK::TopK() + { + one_blob_only = false; + support_inplace = false; } - if (dims == 1 && positive_axis == 0) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - const float* ptr = bottom_blobs[0]; - std::vector > vec; - vec.resize(bottom_blobs[0].w); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(ptr[i], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(ptr[i], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - float* valptr = top_val_blob; - float* indptr = top_ind_blob; - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - valptr[i] = vec[i].first; - indptr[i] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - - // pair comparison - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0][i] > valtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0][i] < valtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } - } - } - } + + int TopK::load_param(const ParamDict& pd) + { + axis = pd.get(0, -1); + largest = pd.get(1, 1); + sorted = pd.get(2, 1); + keep_dims = pd.get(3, 1); + + return 0; } - if (dims == 2 && positive_axis == 0) { - if (topk > bottom_blobs[0].h) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int col = 0; col < bottom_blobs[0].w; col++) { - std::vector > vec; - vec.resize(bottom_blobs[0].h); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest 
attribute should be 0 or 1, but not %d\n", largest); - return -100; + + int TopK::forward(const std::vector& bottom_blobs, + std::vector& top_blobs, + const Option& opt) const + { + int dims = bottom_blobs[0].dims; + int positive_axis = axis < 0 ? dims + axis : axis; + int topk; + if (bottom_blobs.size() == 2) + { + const Mat& topk_blob = bottom_blobs[1]; + topk = (int)(topk_blob[0] + 0.5); } - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.row(i)[col] = vec[i].first; - top_ind_blob.row(i)[col] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(i)[col] > valtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(i)[col] < valtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } - } - } - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; + else if (bottom_blobs.size() == 1) + { + topk = 1; } - } - } - if (dims == 2 && positive_axis == 1) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - std::vector > vec; - vec.resize(bottom_blobs[0].w); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; + else + { + fprintf(stderr, "topk input blobs should be 1 or 2, but not %ld\n", bottom_blobs.size()); + return -103; } - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.row(r)[i] = vec[i].first; - top_ind_blob.row(r)[i] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(r)[i] > valtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - 
top_ind_blob.row(r)[cur] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(r)[i] < valtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; - cur++; - } - } - } + // To do: Cut the top_val_blob after unit test. And we should change them in + // param files. + // Adaptive outputs. For onnx TopK, we output 2 blobs, for ArgMax, we output + // 1 blob. + Mat& top_val_blob = top_blobs[0]; + Mat& top_ind_blob = top_blobs.size() == 2 ? top_blobs[1] : top_val_blob; - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } - } - } - if (dims == 3 && positive_axis == 0) { - if (topk > bottom_blobs[0].c) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int col = 0; col < bottom_blobs[0].w; col++) { - std::vector > vec; - vec.resize(bottom_blobs[0].c); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], -i); + if (topk > 1) + { + // real topk + if (keep_dims == 0) + { + fprintf(stderr, "real topk should not reduce dims!\n"); + return -102; } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.channel(i).row(r)[col] = vec[i].first; - top_ind_blob.channel(i).row(r)[col] = abs(vec[i].second); + if (dims == 1 && positive_axis == 0) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + const float* ptr = bottom_blobs[0]; + std::vector> vec; + vec.resize(bottom_blobs[0].w); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(ptr[i], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(ptr[i], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + float* valptr = top_val_blob; + float* indptr = top_ind_blob; + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + valptr[i] = vec[i].first; + indptr[i] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float 
valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + + // pair comparison + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0][i] > valtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + else if (bottom_blobs[0][i] == valtarget && i <= indtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0][i] < valtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + else if (bottom_blobs[0][i] == valtarget && i <= indtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + } + } + } } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(i).row(r)[col] > valtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].c; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(i).row(r)[col] < valtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } - } + if (dims == 2 && positive_axis == 0) + { + if (topk > bottom_blobs[0].h) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int col = 0; col < bottom_blobs[0].w; col++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].h); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.row(i)[col] = vec[i].first; + top_ind_blob.row(i)[col] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].row(i)[col] > valtarget) + { + top_val_blob.row(cur)[col] = 
bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) + { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].row(i)[col] < valtarget) + { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) + { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } } + if (dims == 2 && positive_axis == 1) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } - } - } - } - if (dims == 3 && positive_axis == 1) { - if (topk > bottom_blobs[0].h) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int page = 0; page < bottom_blobs[0].c; page++) { - for (int col = 0; col < bottom_blobs[0].w; col++) { - std::vector > vec; - vec.resize(bottom_blobs[0].h); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], -i); + top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int r = 0; r < bottom_blobs[0].h; r++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].w); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.row(r)[i] = vec[i].first; + top_ind_blob.row(r)[i] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].row(r)[i] > valtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if 
(bottom_blobs[0].row(r)[i] < valtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], i); + if (dims == 3 && positive_axis == 0) + { + if (topk > bottom_blobs[0].c) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int col = 0; col < bottom_blobs[0].w; col++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].c); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.channel(i).row(r)[col] = vec[i].first; + top_ind_blob.channel(i).row(r)[col] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(i).row(r)[col] > valtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(i).row(r)[col] < valtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } + } } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - 
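Both the removed and the reformatted sorted == 0 branches around here implement the same order-preserving selection: partial_sort (keyed on value, with a negated index so that ties prefer the smaller original index) is only used to find the k-th value and its original index, and a second in-order pass then emits every element strictly better than that threshold, plus ties at indices up to it. A minimal standalone sketch of the idea, outside the patch and with illustrative names only:

    #include <algorithm>
    #include <functional>
    #include <utility>
    #include <vector>

    // Indices of the k largest elements of x, kept in input order with ties
    // broken by smaller index, mirroring the layer's largest == 1 / sorted == 0
    // path. Assumes 1 <= k <= x.size().
    static std::vector<int> topk_indices_unsorted(const std::vector<float>& x, int k)
    {
        std::vector<std::pair<float, int>> vec(x.size());
        for (int i = 0; i < (int)x.size(); i++)
            vec[i] = std::make_pair(x[i], -i); // -i: greater<> prefers smaller i on ties
        std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), std::greater<std::pair<float, int>>());
        const float valtarget = vec[k - 1].first;   // k-th largest value
        const int   indtarget = -vec[k - 1].second; // its original index
        std::vector<int> out;
        for (int i = 0; i < (int)x.size() && (int)out.size() < k; i++)
            if (x[i] > valtarget || (x[i] == valtarget && i <= indtarget))
                out.push_back(i);
        return out;
    }

For x = {3, 1, 3, 2} and k = 2 this returns {0, 2}: both threes qualify and stay in input order, which is the behaviour the per-row and per-column loops above reproduce for each slice.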
- if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.channel(page).row(i)[col] = vec[i].first; - top_ind_blob.channel(page).row(i)[col] = abs(vec[i].second); + if (dims == 3 && positive_axis == 1) + { + if (topk > bottom_blobs[0].h) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int page = 0; page < bottom_blobs[0].c; page++) + { + for (int col = 0; col < bottom_blobs[0].w; col++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].h); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.channel(page).row(i)[col] = vec[i].first; + top_ind_blob.channel(page).row(i)[col] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + for (int i = 0; i < bottom_blobs[0].h; i++) + { + if (cur >= topk) break; + if (largest == 1) + { + if (bottom_blobs[0].channel(page).row(i)[col] > valtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && + i <= indtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + } + else + { + if (bottom_blobs[0].channel(page).row(i)[col] < valtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && + i <= indtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } + } } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (largest == 1) { - if (bottom_blobs[0].channel(page).row(i)[col] > valtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && - i <= indtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } - } 
else { - if (bottom_blobs[0].channel(page).row(i)[col] < valtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && - i <= indtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } - } + if (dims == 3 && positive_axis == 2) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int page = 0; page < bottom_blobs[0].c; page++) + { + for (int r = 0; r < bottom_blobs[0].h; r++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].w); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.channel(page).row(r)[i] = vec[i].first; + top_ind_blob.channel(page).row(r)[i] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(page).row(r)[i] > valtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(page).row(r)[i] < valtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } + } } - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } } - } - } - if (dims == 3 && positive_axis == 2) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if 
(top_val_blob.empty()) return -100; - - top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int page = 0; page < bottom_blobs[0].c; page++) { - for (int r = 0; r < bottom_blobs[0].h; r++) { - std::vector > vec; - vec.resize(bottom_blobs[0].w); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], -i); + else + { + if (topk <= 0) + { + fprintf(stderr, "topk should not <= 0!\n"); + return -102; } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], i); + if (dims == 1 && positive_axis == 0) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + if (top_blobs.size() == 2) + { + top_ind_blob.create(topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = ptr[i]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[0] = *index_iter; + if (top_blobs.size() == 2) + indptr[0] = std::distance(vec.begin(), index_iter); + else + valptr[0] = std::distance(vec.begin(), index_iter); // replace with index + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[0] = *index_iter; + if (top_blobs.size() == 2) + indptr[0] = std::distance(vec.begin(), index_iter); + else + valptr[0] = std::distance(vec.begin(), index_iter); // replace with index + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.channel(page).row(r)[i] = vec[i].first; - top_ind_blob.channel(page).row(r)[i] = abs(vec[i].second); + if (dims == 2 && positive_axis == 0) + { + if (keep_dims == 1) + { + top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + else + { + top_val_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].h); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + for (int col = 0; col < bottom_blobs[0].w; col++) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + 
valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(page).row(r)[i] > valtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(page).row(r)[i] < valtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } - } + if (dims == 2 && positive_axis == 1) + { + if (keep_dims == 1) + { + top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + else + { + top_val_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = ptr[r * bottom_blobs[0].w + i]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } } + if (dims == 3 && positive_axis == 0) + { + if (keep_dims == 1) + { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if 
(top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + else + { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].c); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } - } - } - } - } else { - if (topk <= 0) { - fprintf(stderr, "topk should not <= 0!\n"); - return -102; - } - if (dims == 1 && positive_axis == 0) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - if (top_blobs.size() == 2) { - top_ind_blob.create(topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].w); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = ptr[i]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[0] = *index_iter; - if (top_blobs.size() == 2) - indptr[0] = std::distance(vec.begin(), index_iter); - else - valptr[0] = std::distance(vec.begin(), index_iter); // replace with index - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[0] = *index_iter; - if (top_blobs.size() == 2) - indptr[0] = std::distance(vec.begin(), index_iter); - else - valptr[0] = std::distance(vec.begin(), index_iter); // replace with index - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - if (dims == 2 && positive_axis == 0) { - if (keep_dims == 1) { - top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int col = 0; col < bottom_blobs[0].w; col++) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + ptr = bottom_blobs[0].channel(i); + vec[i] = ptr[r * bottom_blobs[0].w + col]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r * top_val_blob.w + col] = *index_iter; + if (top_blobs.size() == 2) + indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + else + valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r * top_val_blob.w + col] = *index_iter; - } else { - top_val_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + else + valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + } 
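+                        // With a single output blob, the index just overwrote the
+                        // value slot above (ArgMax semantics); with two blobs the
+                        // h x w output planes are addressed row-major as r * w + col.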
+ else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } + } + } + if (dims == 3 && positive_axis == 1) + { + if (keep_dims == 1) + { + top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - } - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].h); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - for (int col = 0; col < bottom_blobs[0].w; col++) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = ptr[i * bottom_blobs[0].w + col]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[col] = *index_iter; - if (top_blobs.size() == 2) - indptr[col] = std::distance(vec.begin(), index_iter); - else - valptr[col] = std::distance(vec.begin(), index_iter); - - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[col] = *index_iter; - if (top_blobs.size() == 2) - indptr[col] = std::distance(vec.begin(), index_iter); - else - valptr[col] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - if (dims == 2 && positive_axis == 1) { - if (keep_dims == 1) { - top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + std::vector vec; + vec.resize(bottom_blobs[0].h); - } else { - top_val_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - } - - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].w); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = ptr[r * bottom_blobs[0].w + i]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[r] = *index_iter; - if (top_blobs.size() == 2) - indptr[r] = std::distance(vec.begin(), index_iter); - else - valptr[r] = std::distance(vec.begin(), index_iter); - - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[r] = *index_iter; - if (top_blobs.size() == 2) - indptr[r] = std::distance(vec.begin(), index_iter); - else - valptr[r] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - if (dims == 3 && positive_axis == 0) { - if (keep_dims == 1) { - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, 
opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + for (int page = 0; page < bottom_blobs[0].c; page++) + { + const float* ptr = bottom_blobs[0].channel(page); + float* valptr = top_val_blob.channel(page); + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page); + for (int col = 0; col < bottom_blobs[0].w; col++) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } + } + } + else + { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } - } else { - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - } - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].c); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int col = 0; col < bottom_blobs[0].w; col++) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - ptr = bottom_blobs[0].channel(i); - vec[i] = ptr[r * bottom_blobs[0].w + col]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[r * top_val_blob.w + col] = *index_iter; - if (top_blobs.size() == 2) - indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - else - valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[r * top_val_blob.w + col] = *index_iter; - - if (top_blobs.size() == 2) - indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - else - valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - } - if (dims == 3 && positive_axis == 1) { - if (keep_dims == 1) { - top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + std::vector vec; + vec.resize(bottom_blobs[0].h); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; - std::vector vec; - vec.resize(bottom_blobs[0].h); - - for (int page = 0; page < bottom_blobs[0].c; 
page++) { - const float* ptr = bottom_blobs[0].channel(page); - float* valptr = top_val_blob.channel(page); - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page); - for (int col = 0; col < bottom_blobs[0].w; col++) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = ptr[i * bottom_blobs[0].w + col]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[col] = *index_iter; - if (top_blobs.size() == 2) - indptr[col] = std::distance(vec.begin(), index_iter); - else - valptr[col] = std::distance(vec.begin(), index_iter); - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[col] = *index_iter; - if (top_blobs.size() == 2) - indptr[col] = std::distance(vec.begin(), index_iter); - else - valptr[col] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; + for (int page = 0; page < bottom_blobs[0].c; page++) + { + const float* ptr = bottom_blobs[0].channel(page); + for (int col = 0; col < bottom_blobs[0].w; col++) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + col] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + col] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } + } + } } - } - } - } else { - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + if (dims == 3 && positive_axis == 2) + { + if (keep_dims == 1) + { + top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } - std::vector vec; - vec.resize(bottom_blobs[0].h); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int page = 0; page < bottom_blobs[0].c; page++) { - const float* ptr = bottom_blobs[0].channel(page); - for (int col = 0; col < bottom_blobs[0].w; col++) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = ptr[i * bottom_blobs[0].w + col]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[page * top_val_blob.w + col] = *index_iter; - if (top_blobs.size() == 2) - indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - else - valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), 
vec.end()); - valptr[page * top_val_blob.w + col] = *index_iter; - if (top_blobs.size() == 2) - indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - else - valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - } - } - if (dims == 3 && positive_axis == 2) { - if (keep_dims == 1) { - top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + std::vector vec; + vec.resize(bottom_blobs[0].w); - std::vector vec; - vec.resize(bottom_blobs[0].w); - - for (int page = 0; page < bottom_blobs[0].c; page++) { - const float* ptr = bottom_blobs[0].channel(page); - float* valptr = top_val_blob.channel(page); - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page); - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = ptr[r * bottom_blobs[0].w + i]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[r] = *index_iter; - if (top_blobs.size() == 2) - indptr[r] = std::distance(vec.begin(), index_iter); - else - valptr[r] = std::distance(vec.begin(), index_iter); - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[r] = *index_iter; - if (top_blobs.size() == 2) - indptr[r] = std::distance(vec.begin(), index_iter); - else - valptr[r] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - } else { - top_val_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + for (int page = 0; page < bottom_blobs[0].c; page++) + { + const float* ptr = bottom_blobs[0].channel(page); + float* valptr = top_val_blob.channel(page); + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page); + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = ptr[r * bottom_blobs[0].w + i]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } + } + } + else + { + top_val_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } - std::vector vec; - 
vec.resize(bottom_blobs[0].w); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int page = 0; page < bottom_blobs[0].c; page++) { - const float* ptr = bottom_blobs[0].channel(page); - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = ptr[r * bottom_blobs[0].w + i]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[page * top_val_blob.w + r] = *index_iter; - if (top_blobs.size() == 2) - indptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter); - else - valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter); - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[page * top_val_blob.w + r] = *index_iter; - if (top_blobs.size() == 2) - indptr[page * top_val_blob.w + r] = std::distance(vec.begin(), index_iter); - else - valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int page = 0; page < bottom_blobs[0].c; page++) + { + const float* ptr = bottom_blobs[0].channel(page); + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = ptr[r * bottom_blobs[0].w + i]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + r] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + r] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_val_blob.w + r] = std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } + } + } } - } } - } + return 0; } - } - return 0; -} } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h index d390fbafcd..45e7968b79 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h @@ -4,21 +4,26 @@ #include "layer.h" -namespace mmdeploy { - -class TopK : public ncnn::Layer { - public: - TopK(); - virtual int load_param(const ncnn::ParamDict& pd); - virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const ncnn::Option& opt) const; - - public: - int axis; - int largest; - int sorted; - int keep_dims; -}; +namespace mmdeploy +{ + + class TopK : public ncnn::Layer + { + public: + TopK(); + + virtual int load_param(const ncnn::ParamDict& pd); + + virtual int forward(const std::vector& bottom_blobs, + std::vector& top_blobs, + const ncnn::Option& opt) const; + + public: + int axis; + int largest; + int sorted; + int keep_dims; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/CMakeLists.txt b/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/CMakeLists.txt index 652f841f7a..1d2a381837 100755 
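The TopK class declared above plugs into ncnn through the standard custom-layer hook that the pyncnn_ext module below ultimately invokes. A minimal standalone sketch of that wiring, not part of the patch (the creator follows ncnn's DEFINE_LAYER_CREATOR convention, and the "TopK" type string is an assumption here):

    #include "net.h"

    namespace mmdeploy
    {
        // DEFINE_LAYER_CREATOR(TopK) expands to a creator of this shape,
        // defined alongside the layer implementation.
        ncnn::Layer* TopK_layer_creator(void* userdata);
    } // namespace mmdeploy

    // Bind the type string used in .param files to the creator so that
    // Net::load_param() can instantiate the layer; register_mmdeploy_custom_layers()
    // does the same for every op in the registry.
    static int register_topk(ncnn::Net& net)
    {
        return net.register_custom_layer("TopK", mmdeploy::TopK_layer_creator);
    }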
--- a/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/CMakeLists.txt @@ -3,15 +3,16 @@ project(ncnn_ext) # pybind11 -if (NOT TARGET pybind11) - add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) -endif () +if(NOT TARGET pybind11) + add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) +endif() pybind11_add_module(ncnn_ext ncnn_ext.cpp) target_link_libraries(ncnn_ext PUBLIC mmdeploy_ncnn_ops ncnn) set(_NCNN_EXT_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/backend/ncnn) -set_target_properties(ncnn_ext PROPERTIES - LIBRARY_OUTPUT_DIRECTORY ${_NCNN_EXT_DIR} - LIBRARY_OUTPUT_DIRECTORY_DEBUG ${_NCNN_EXT_DIR} - LIBRARY_OUTPUT_DIRECTORY_RELEASE ${_NCNN_EXT_DIR}) +set_target_properties( + ncnn_ext + PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${_NCNN_EXT_DIR} + LIBRARY_OUTPUT_DIRECTORY_DEBUG ${_NCNN_EXT_DIR} + LIBRARY_OUTPUT_DIRECTORY_RELEASE ${_NCNN_EXT_DIR}) diff --git a/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp b/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp old mode 100755 new mode 100644 index ac158b9edb..1c8ad70cc7 --- a/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp @@ -4,9 +4,11 @@ #include "ncnn_ops_register.h" #include "net.h" -PYBIND11_MODULE(ncnn_ext, m) { - m.def( - "register_mmdeploy_custom_layers", - [](ncnn::Net &net) { return register_mmdeploy_custom_layers(net); }, - "register mmdeploy custom ncnn layers."); +PYBIND11_MODULE(ncnn_ext, m) +{ + m.def( + "register_mmdeploy_custom_layers", + [](ncnn::Net& net) + { return register_mmdeploy_custom_layers(net); }, + "register mmdeploy custom ncnn layers."); } diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/CMakeLists.txt b/csrc/mmdeploy/backend_ops/onnxruntime/CMakeLists.txt index 9548110be6..f8f7e35f77 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/onnxruntime/CMakeLists.txt @@ -9,16 +9,18 @@ include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake) file(GLOB_RECURSE ORT_OPS_SRCS *.cpp) add_library(${PROJECT_NAME}_obj OBJECT "${ORT_OPS_SRCS}") target_compile_definitions(${PROJECT_NAME}_obj PRIVATE -DMMDEPLOY_API_EXPORTS=1) -target_compile_options(${PROJECT_NAME}_obj PRIVATE - $<$:-fvisibility=hidden>) -set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE 1) +target_compile_options(${PROJECT_NAME}_obj + PRIVATE $<$:-fvisibility=hidden>) +set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE + 1) mmdeploy_export(${PROJECT_NAME}_obj) -target_include_directories(${PROJECT_NAME}_obj PUBLIC - $ - $ - $ - $) +target_include_directories( + ${PROJECT_NAME}_obj + PUBLIC $ + $ + $ + $) target_link_libraries(${PROJECT_NAME}_obj PUBLIC onnxruntime) mmdeploy_add_library(${PROJECT_NAME} SHARED EXCLUDE "") diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h b/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h index 28d2a2b782..1095c28bae 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h @@ -6,11 +6,12 @@ #include "mmdeploy/core/macro.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -MMDEPLOY_API OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, - const OrtApiBase *api); + MMDEPLOY_API OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, + const OrtApiBase* api); #ifdef __cplusplus } diff --git 
a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp index c604e4b650..da959ec37e 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp @@ -1,10 +1,12 @@ // Copyright (c) OpenMMLab. All rights reserved. #include "ort_utils.h" -namespace mmdeploy { +namespace mmdeploy +{ -CustomOpsTable& get_mmdeploy_custom_ops() { - static CustomOpsTable _custom_ops; - return _custom_ops; -} + CustomOpsTable& get_mmdeploy_custom_ops() + { + static CustomOpsTable _custom_ops; + return _custom_ops; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h index e19c984f86..14d2da3457 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h @@ -6,32 +6,39 @@ #include #include -namespace mmdeploy { - -typedef std::unordered_map> CustomOpsTable; - -struct OrtTensorDimensions : std::vector { - OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) { - OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value); - std::vector::operator=(ort.GetTensorShape(info)); - ort.ReleaseTensorTypeAndShapeInfo(info); - } -}; - -CustomOpsTable& get_mmdeploy_custom_ops(); - -template -class OrtOpsRegistry { - public: - OrtOpsRegistry() { get_mmdeploy_custom_ops()[domain].push_back(&instance); } - - private: - T instance{}; -}; - -#define REGISTER_ONNXRUNTIME_OPS(domain, name) \ - static char __domain_##domain##name[] = #domain; \ - static OrtOpsRegistry<__domain_##domain##name, name> ort_ops_registry_##domain##name {} +namespace mmdeploy +{ + + typedef std::unordered_map> CustomOpsTable; + + struct OrtTensorDimensions : std::vector + { + OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) + { + OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value); + std::vector::operator=(ort.GetTensorShape(info)); + ort.ReleaseTensorTypeAndShapeInfo(info); + } + }; + + CustomOpsTable& get_mmdeploy_custom_ops(); + + template + class OrtOpsRegistry + { + public: + OrtOpsRegistry() + { + get_mmdeploy_custom_ops()[domain].push_back(&instance); + } + + private: + T instance{}; + }; + +#define REGISTER_ONNXRUNTIME_OPS(domain, name) \ + static char __domain_##domain##name[] = #domain; \ + static OrtOpsRegistry<__domain_##domain##name, name> ort_ops_registry_##domain##name {} } // namespace mmdeploy #endif // ORT_MMCV_UTILS_H diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp index c7fed37d23..27eb677394 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp @@ -8,287 +8,335 @@ #include "ort_utils.h" -namespace mmdeploy { +namespace mmdeploy +{ #define MIN(a, b) (((a) < (b)) ? (a) : (b)) #define MAX(a, b) (((a) < (b)) ? 
(b) : (a)) #define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0)) -GridSampleKernel::GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info) - : ort_(api), info_(info) { - align_corners_ = ort_.KernelInfoGetAttribute(info, "align_corners"); - interpolation_mode_ = ort_.KernelInfoGetAttribute(info, "interpolation_mode"); - padding_mode_ = ort_.KernelInfoGetAttribute(info, "padding_mode"); - - allocator_ = Ort::AllocatorWithDefaultOptions(); -} - -enum GridSamplerInterpolation { Bilinear = 0, Nearest = 1, Bicubic = 2 }; -enum GridSamplerPadding { Zeros = 0, Border = 1, Reflection = 2 }; - -template -static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size, bool align_corners) { - if (align_corners) { - return ((coord + 1) / 2) * (size - 1); - } else { - return ((coord + 1) * size - 1) / 2; - } -} - -// Clips coordinates to between 0 and clip_limit - 1 -template -static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) { - return std::min(static_cast(clip_limit - 1), std::max(in, static_cast(0))); -} - -// Reflects coordinates until they fall between low and high (inclusive). -// The bounds are passed as twice their value so that half-integer values -// can be represented as ints. -template -static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low, int64_t twice_high) { - if (twice_low == twice_high) { - return static_cast(0); - } - scalar_t min = static_cast(twice_low) / 2; - scalar_t span = static_cast(twice_high - twice_low) / 2; - in = std::fabs(in - min); - // `fmod` returns same sign as `in`, which is positive after the `fabs` above. - scalar_t extra = std::fmod(in, span); - int flips = static_cast(std::floor(in / span)); - if (flips % 2 == 0) { - return extra + min; - } else { - return span - extra + min; - } -} - -template -static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, int64_t padding_mode, - bool align_corners) { - if (padding_mode == GridSamplerPadding::Border) { - coord = clip_coordinates(coord, size); - } else if (padding_mode == GridSamplerPadding::Reflection) { - if (align_corners) { - coord = reflect_coordinates(coord, 0, 2 * (size - 1)); - } else { - coord = reflect_coordinates(coord, -1, 2 * size - 1); + GridSampleKernel::GridSampleKernel(const OrtApi& api, const OrtKernelInfo* info) + : ort_(api) + , info_(info) + { + align_corners_ = ort_.KernelInfoGetAttribute(info, "align_corners"); + interpolation_mode_ = ort_.KernelInfoGetAttribute(info, "interpolation_mode"); + padding_mode_ = ort_.KernelInfoGetAttribute(info, "padding_mode"); + + allocator_ = Ort::AllocatorWithDefaultOptions(); } - coord = clip_coordinates(coord, size); - } - return coord; -} - -// Computes the pixel source index value for a grid coordinate -template -static inline scalar_t grid_sampler_compute_source_index(scalar_t coord, int64_t size, - int64_t padding_mode, bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); - coord = compute_coordinates(coord, size, padding_mode, align_corners); - return coord; -} - -static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) { - return h >= 0 && h < H && w >= 0 && w < W; -} - -template -static inline scalar_t get_value_bounded(const scalar_t *data, scalar_t x, scalar_t y, int64_t W, - int64_t H, int64_t sW, int64_t sH, int64_t padding_mode, - bool align_corners) { - x = compute_coordinates(x, W, padding_mode, align_corners); - y = compute_coordinates(y, H, padding_mode, align_corners); - - 
int64_t ix = static_cast(x); - int64_t iy = static_cast(y); - - if (within_bounds_2d(iy, ix, H, W)) { - return data[iy * sH + ix * sW]; - } - return static_cast(0); -} - -template -static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) { - return ((A + 2) * x - (A + 3)) * x * x + 1; -} - -template -static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) { - return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; -} - -template -static inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) { - scalar_t A = -0.75; - - scalar_t x1 = t; - coeffs[0] = cubic_convolution2(x1 + 1.0, A); - coeffs[1] = cubic_convolution1(x1, A); - - // opposite coefficients - scalar_t x2 = 1.0 - t; - coeffs[2] = cubic_convolution1(x2, A); - coeffs[3] = cubic_convolution2(x2 + 1.0, A); -} - -template -static inline scalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, - scalar_t t) { - scalar_t coeffs[4]; - get_cubic_upsample_coefficients(coeffs, t); - - return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; -} - -void GridSampleKernel::Compute(OrtKernelContext *context) { - const bool align_corners = align_corners_; - const int64_t padding_mode = padding_mode_; - const int64_t interpolation_mode = interpolation_mode_; - - const OrtValue *input = ort_.KernelContext_GetInput(context, 0); - const float *input_data = reinterpret_cast(ort_.GetTensorData(input)); - - const OrtValue *grid = ort_.KernelContext_GetInput(context, 1); - const float *grid_data = reinterpret_cast(ort_.GetTensorData(grid)); - - OrtTensorDimensions input_dims(ort_, input); - OrtTensorDimensions grid_dims(ort_, grid); - int64_t N = input_dims[0]; - int64_t C = input_dims[1]; - int64_t inp_H = input_dims[2]; - int64_t inp_W = input_dims[3]; - int64_t out_H = grid_dims[1]; - int64_t out_W = grid_dims[2]; - - std::vector output_dims = {N, C, out_H, out_W}; - OrtValue *output = - ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size()); - float *out_ptr = ort_.GetTensorMutableData(output); - - int64_t inp_sN = input_dims[1] * input_dims[2] * input_dims[3]; - int64_t inp_sC = input_dims[2] * input_dims[3]; - int64_t inp_sH = input_dims[3]; - int64_t inp_sW = 1; - int64_t grid_sN = grid_dims[1] * grid_dims[2] * grid_dims[3]; - int64_t grid_sH = grid_dims[2] * grid_dims[3]; - int64_t grid_sW = grid_dims[3]; - int64_t grid_sCoor = 1; - int64_t out_sN = output_dims[1] * output_dims[2] * output_dims[3]; - int64_t out_sC = output_dims[2] * output_dims[3]; - int64_t out_sH = output_dims[3]; - int64_t out_sW = 1; - - // loop over each output pixel - for (int64_t n = 0; n < N; ++n) { - const float *grid_ptr_N = grid_data + n * grid_sN; - const float *inp_ptr_N = input_data + n * inp_sN; - for (int64_t h = 0; h < out_H; ++h) { - for (int64_t w = 0; w < out_W; ++w) { - const float *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; - float x = *grid_ptr_NHW; - float y = grid_ptr_NHW[grid_sCoor]; - - float ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); - float iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); - - if (interpolation_mode == GridSamplerInterpolation::Bilinear) { - // get corner pixel values from (x, y) - // for 4d, we use north-east-south-west - int64_t ix_nw = static_cast(std::floor(ix)); - int64_t iy_nw = static_cast(std::floor(iy)); - - int64_t ix_ne = ix_nw + 1; - int64_t iy_ne = iy_nw; - - int64_t ix_sw = ix_nw; - int64_t iy_sw = iy_nw + 1; - - int64_t ix_se = ix_nw + 1; - int64_t 
iy_se = iy_nw + 1; - - // get surfaces to each neighbor: - float nw = (ix_se - ix) * (iy_se - iy); - float ne = (ix - ix_sw) * (iy_sw - iy); - float sw = (ix_ne - ix) * (iy - iy_ne); - float se = (ix - ix_nw) * (iy - iy_nw); - - // calculate bilinear weighted pixel value and set output pixel - const float *inp_ptr_NC = inp_ptr_N; - float *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; - for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { - auto res = static_cast(0); - if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { - res += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; - } - if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { - res += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; - } - if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { - res += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; - } - if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { - res += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; - } - *out_ptr_NCHW = res; - } - } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { - int64_t ix_nearest = static_cast(std::nearbyint(ix)); - int64_t iy_nearest = static_cast(std::nearbyint(iy)); - - // assign nearest neighbor pixel value to output pixel - float *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; - const float *inp_ptr_NC = inp_ptr_N; - for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { - if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { - *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCHW = static_cast(0); + + enum GridSamplerInterpolation + { + Bilinear = 0, + Nearest = 1, + Bicubic = 2 + }; + enum GridSamplerPadding + { + Zeros = 0, + Border = 1, + Reflection = 2 + }; + + template + static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size, bool align_corners) + { + if (align_corners) + { + return ((coord + 1) / 2) * (size - 1); + } + else + { + return ((coord + 1) * size - 1) / 2; + } + } + + // Clips coordinates to between 0 and clip_limit - 1 + template + static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) + { + return std::min(static_cast(clip_limit - 1), std::max(in, static_cast(0))); + } + + // Reflects coordinates until they fall between low and high (inclusive). + // The bounds are passed as twice their value so that half-integer values + // can be represented as ints. + template + static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low, int64_t twice_high) + { + if (twice_low == twice_high) + { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = std::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. 
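A detail of the Nearest branch: std::nearbyint rounds under the current floating-point rounding mode, which defaults to round-to-nearest-even, so exact .5 ties go to the even neighbor rather than away from zero as std::round would. A quick standalone check:

#include <cmath>
#include <cstdio>

int main()
{
    // Default FE_TONEAREST mode: ties round to the even integer.
    std::printf("%.0f %.0f %.0f\n", std::nearbyint(0.5), std::nearbyint(1.5), std::nearbyint(2.5));  // 0 2 2
    // std::round breaks ties away from zero instead.
    std::printf("%.0f %.0f %.0f\n", std::round(0.5), std::round(1.5), std::round(2.5));  // 1 2 3
}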
+ scalar_t extra = std::fmod(in, span); + int flips = static_cast(std::floor(in / span)); + if (flips % 2 == 0) + { + return extra + min; + } + else + { + return span - extra + min; + } + } + + template + static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, int64_t padding_mode, bool align_corners) + { + if (padding_mode == GridSamplerPadding::Border) + { + coord = clip_coordinates(coord, size); + } + else if (padding_mode == GridSamplerPadding::Reflection) + { + if (align_corners) + { + coord = reflect_coordinates(coord, 0, 2 * (size - 1)); } - } - } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) { - // grid_sampler_compute_source_index will "clip the value" of idx - // depends on the padding, - // which would cause calculation to be wrong, - // for example x = -0.1 -> ix = 0 for zero padding, but in bicubic ix - // = floor(x) = -1 - // There would be more problem in reflection padding, since the -1 and - // +1 direction is not fixed in boundary condition - ix = grid_sampler_unnormalize(x, inp_W, align_corners); - iy = grid_sampler_unnormalize(y, inp_H, align_corners); - - float ix_nw = std::floor(ix); - float iy_nw = std::floor(iy); - - const float tx = ix - ix_nw; - const float ty = iy - iy_nw; - - const float *inp_ptr_NC = inp_ptr_N; - float *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; - for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) { - float coefficients[4]; - - // Interpolate 4 values in the x direction - for (int64_t i = 0; i < 4; ++i) { - coefficients[i] = cubic_interp1d( - get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, - inp_sW, inp_sH, padding_mode, align_corners), - tx); + else + { + coord = reflect_coordinates(coord, -1, 2 * size - 1); } + coord = clip_coordinates(coord, size); + } + return coord; + } - // Interpolate in the y direction - *out_ptr_NCHW = cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], - coefficients[3], ty); - } + // Computes the pixel source index value for a grid coordinate + template + static inline scalar_t grid_sampler_compute_source_index(scalar_t coord, int64_t size, int64_t padding_mode, bool align_corners) + { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; + } + + static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) + { + return h >= 0 && h < H && w >= 0 && w < W; + } + + template + static inline scalar_t get_value_bounded(const scalar_t* data, scalar_t x, scalar_t y, int64_t W, int64_t H, int64_t sW, int64_t sH, int64_t padding_mode, bool align_corners) + { + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) + { + return data[iy * sH + ix * sW]; + } + return static_cast(0); + } + + template + static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) + { + return ((A + 2) * x - (A + 3)) * x * x + 1; + } + + template + static inline scalar_t 
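To see how reflect_coordinates behaves: the bounds arrive doubled so that the half-integer limits of the align_corners=false case stay integral, and the value is folded back and forth across the span until it lands inside. Worked example for size = 5 with align_corners=true (reflect over [0, 4], passed as 0 and 8): 5.5 bounces off 4 to 2.5, and -1 bounces off 0 to 1. A standalone sketch (not part of the patch):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirrors reflect_coordinates, specialized to float.
static float reflect(float in, int64_t twice_low, int64_t twice_high)
{
    if (twice_low == twice_high) return 0.f;
    float lo = twice_low / 2.f;
    float span = (twice_high - twice_low) / 2.f;
    in = std::fabs(in - lo);
    float extra = std::fmod(in, span);
    int flips = static_cast<int>(std::floor(in / span));
    return (flips % 2 == 0) ? extra + lo : span - extra + lo;
}

int main()
{
    std::printf("%.1f\n", reflect(5.5f, 0, 8));   // 2.5
    std::printf("%.1f\n", reflect(-1.0f, 0, 8));  // 1.0
}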
cubic_convolution2(scalar_t x, scalar_t A) + { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; + } + + template <typename scalar_t> + static inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) + { + scalar_t A = -0.75; + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + scalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); + } + + template <typename scalar_t> + static inline scalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) + { + scalar_t coeffs[4]; + get_cubic_upsample_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; + } + + void GridSampleKernel::Compute(OrtKernelContext* context) + { + const bool align_corners = align_corners_; + const int64_t padding_mode = padding_mode_; + const int64_t interpolation_mode = interpolation_mode_; + + const OrtValue* input = ort_.KernelContext_GetInput(context, 0); + const float* input_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(input)); + + const OrtValue* grid = ort_.KernelContext_GetInput(context, 1); + const float* grid_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(grid)); + + OrtTensorDimensions input_dims(ort_, input); + OrtTensorDimensions grid_dims(ort_, grid); + int64_t N = input_dims[0]; + int64_t C = input_dims[1]; + int64_t inp_H = input_dims[2]; + int64_t inp_W = input_dims[3]; + int64_t out_H = grid_dims[1]; + int64_t out_W = grid_dims[2]; + + std::vector<int64_t> output_dims = {N, C, out_H, out_W}; + OrtValue* output = + ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size()); + float* out_ptr = ort_.GetTensorMutableData<float>(output); + + int64_t inp_sN = input_dims[1] * input_dims[2] * input_dims[3]; + int64_t inp_sC = input_dims[2] * input_dims[3]; + int64_t inp_sH = input_dims[3]; + int64_t inp_sW = 1; + int64_t grid_sN = grid_dims[1] * grid_dims[2] * grid_dims[3]; + int64_t grid_sH = grid_dims[2] * grid_dims[3]; + int64_t grid_sW = grid_dims[3]; + int64_t grid_sCoor = 1; + int64_t out_sN = output_dims[1] * output_dims[2] * output_dims[3]; + int64_t out_sC = output_dims[2] * output_dims[3]; + int64_t out_sH = output_dims[3]; + int64_t out_sW = 1; + + // loop over each output pixel + for (int64_t n = 0; n < N; ++n) + { + const float* grid_ptr_N = grid_data + n * grid_sN; + const float* inp_ptr_N = input_data + n * inp_sN; + for (int64_t h = 0; h < out_H; ++h) + { + for (int64_t w = 0; w < out_W; ++w) + { + const float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; + float x = *grid_ptr_NHW; + float y = grid_ptr_NHW[grid_sCoor]; + + float ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + float iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get corner pixel values from (x, y) + // for 4d, we use north-east-south-west + int64_t ix_nw = static_cast<int64_t>(std::floor(ix)); + int64_t iy_nw = static_cast<int64_t>(std::floor(iy)); + + int64_t ix_ne = ix_nw + 1; + int64_t iy_ne = iy_nw; + + int64_t ix_sw = ix_nw; + int64_t iy_sw = iy_nw + 1; + + int64_t ix_se = ix_nw + 1; + int64_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + float nw = (ix_se - ix) * (iy_se - iy); + float ne = (ix - ix_sw) * (iy_sw - iy); + float sw = (ix_ne - ix) * (iy - iy_ne); + float se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + const float*
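The bicubic weights use the Keys convolution kernel with A = -0.75, matching get_cubic_upsample_coefficients. Two properties worth checking: at t = 0 the weights collapse to {0, 1, 0, 0} (cubic_interp1d returns x1 exactly, so on-grid points are preserved), and for any t the four weights sum to 1. A standalone check (illustrative reimplementation):

#include <cstdio>

static float conv1(float x, float A) { return ((A + 2) * x - (A + 3)) * x * x + 1; }
static float conv2(float x, float A) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; }

int main()
{
    const float A = -0.75f;
    const float ts[3] = {0.0f, 0.25f, 0.5f};
    for (float t : ts)
    {
        // Same layout as get_cubic_upsample_coefficients: taps at x-1, x, x+1, x+2.
        float c[4] = {conv2(t + 1, A), conv1(t, A), conv1(1 - t, A), conv2(2 - t, A)};
        std::printf("t=%.2f w={%.4f %.4f %.4f %.4f} sum=%.4f\n",
                    t, c[0], c[1], c[2], c[3], c[0] + c[1] + c[2] + c[3]);
    }
}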
inp_ptr_NC = inp_ptr_N; + float* out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) + { + auto res = static_cast(0); + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) + { + res += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) + { + res += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) + { + res += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) + { + res += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + *out_ptr_NCHW = res; + } + } + else if (interpolation_mode == GridSamplerInterpolation::Nearest) + { + int64_t ix_nearest = static_cast(std::nearbyint(ix)); + int64_t iy_nearest = static_cast(std::nearbyint(iy)); + + // assign nearest neighbor pixel value to output pixel + float* out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + const float* inp_ptr_NC = inp_ptr_N; + for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) + { + if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) + { + *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; + } + else + { + *out_ptr_NCHW = static_cast(0); + } + } + } + else if (interpolation_mode == GridSamplerInterpolation::Bicubic) + { + // grid_sampler_compute_source_index will "clip the value" of idx + // depends on the padding, + // which would cause calculation to be wrong, + // for example x = -0.1 -> ix = 0 for zero padding, but in bicubic ix + // = floor(x) = -1 + // There would be more problem in reflection padding, since the -1 and + // +1 direction is not fixed in boundary condition + ix = grid_sampler_unnormalize(x, inp_W, align_corners); + iy = grid_sampler_unnormalize(y, inp_H, align_corners); + + float ix_nw = std::floor(ix); + float iy_nw = std::floor(iy); + + const float tx = ix - ix_nw; + const float ty = iy - iy_nw; + + const float* inp_ptr_NC = inp_ptr_N; + float* out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; + for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) + { + float coefficients[4]; + + // Interpolate 4 values in the x direction + for (int64_t i = 0; i < 4; ++i) + { + coefficients[i] = cubic_interp1d( + get_value_bounded(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + get_value_bounded(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners), + tx); + } + + // Interpolate in the y direction + *out_ptr_NCHW = cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty); + } + } + } + } } - } } - } -} -REGISTER_ONNXRUNTIME_OPS(mmdeploy, GridSampleOp); + REGISTER_ONNXRUNTIME_OPS(mmdeploy, GridSampleOp); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h index 2581b7833e..e6c9fa280f 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h @@ -4,41 +4,59 @@ #include -namespace mmdeploy { - -struct 
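In the Bilinear branch, each corner weight is the area of the cell rectangle opposite that corner, so the four weights always sum to 1, and corners that fall outside the input simply contribute nothing (zero padding). Worked example at (ix, iy) = (2.3, 4.6): nw = 0.7 * 0.4 = 0.28, ne = 0.3 * 0.4 = 0.12, sw = 0.7 * 0.6 = 0.42, se = 0.3 * 0.6 = 0.18. A standalone check:

#include <cmath>
#include <cstdio>

int main()
{
    float ix = 2.3f, iy = 4.6f;
    float x0 = std::floor(ix), y0 = std::floor(iy);  // north-west corner
    float nw = (x0 + 1 - ix) * (y0 + 1 - iy);
    float ne = (ix - x0) * (y0 + 1 - iy);
    float sw = (x0 + 1 - ix) * (iy - y0);
    float se = (ix - x0) * (iy - y0);
    std::printf("%.2f %.2f %.2f %.2f sum=%.2f\n", nw, ne, sw, se, nw + ne + sw + se);  // sum = 1.00
}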
GridSampleKernel { - GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info); - - void Compute(OrtKernelContext *context); - - protected: - Ort::CustomOpApi ort_; - const OrtKernelInfo *info_; - Ort::AllocatorWithDefaultOptions allocator_; - - int64_t align_corners_; - int64_t interpolation_mode_; - int64_t padding_mode_; -}; - -struct GridSampleOp : Ort::CustomOpBase { - void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const { - return new GridSampleKernel(api, info); - }; - - const char *GetName() const { return "grid_sampler"; }; - - size_t GetInputTypeCount() const { return 2; }; - ONNXTensorElementDataType GetInputType(size_t /*index*/) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - }; - - size_t GetOutputTypeCount() const { return 1; }; - ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - }; - - const char *GetExecutionProviderType() const { return "CPUExecutionProvider"; }; -}; +namespace mmdeploy +{ + + struct GridSampleKernel + { + GridSampleKernel(const OrtApi& api, const OrtKernelInfo* info); + + void Compute(OrtKernelContext* context); + + protected: + Ort::CustomOpApi ort_; + const OrtKernelInfo* info_; + Ort::AllocatorWithDefaultOptions allocator_; + + int64_t align_corners_; + int64_t interpolation_mode_; + int64_t padding_mode_; + }; + + struct GridSampleOp : Ort::CustomOpBase + { + void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const + { + return new GridSampleKernel(api, info); + }; + + const char* GetName() const + { + return "grid_sampler"; + }; + + size_t GetInputTypeCount() const + { + return 2; + }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + size_t GetOutputTypeCount() const + { + return 1; + }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + const char* GetExecutionProviderType() const + { + return "CPUExecutionProvider"; + }; + }; } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp index 075c3277bc..320fa8dd45 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp @@ -8,191 +8,218 @@ #include "modulated_deform_conv/modulated_deform_conv_cpu.h" #include "ort_utils.h" -namespace mmdeploy { - -void parallel_unroll_gemm(const float *A, const float *B, const float *V, const float *H, - const int32_t M, const int32_t N, const int32_t K, const float alpha, - const float beta, float *Y, const int32_t start_row, - const int32_t end_row) { - std::vector tmp(N); - for (int32_t m = start_row; m < end_row; ++m) { - for (int32_t n = 0; n < N; n++) { - tmp[n] = 0; - } +namespace mmdeploy +{ + + void parallel_unroll_gemm(const float* A, const float* B, const float* V, const float* H, const int32_t M, const int32_t N, const int32_t K, const float alpha, const float beta, float* Y, const int32_t start_row, const int32_t end_row) { - int32_t remainder = K % 8; // unroll - for (int32_t k = 0; k < K; k += 8) { - for (int32_t n = 0; n < N; n++) { - tmp[n] += A[m * K + k] * B[k * N + n]; - tmp[n] += A[m * K + k + 1] * B[k * N + N + n]; - tmp[n] += A[m * K + k + 2] * B[k * N + 2 * N + n]; - tmp[n] += A[m * K + k + 3] * B[k * N 
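The contract of parallel_unroll_gemm: compute rows [start_row, end_row) of Y = alpha * A * B, then add beta * V (a row vector broadcast down the rows, for bias) and/or beta * H (a full matrix). In deformable_conv2d_ref_fp32 below it is invoked with V = nullptr and H = Y = dst_ptr, so the product accumulates onto the bias-initialized output; the per-row tmp buffer makes that aliasing safe. A compact reference implementation of the same contract (sketch only, without the unrolling, the threading, or the unused M parameter):

#include <cstdint>

// Y[m][n] = alpha * sum_k A[m][k] * B[k][n] (+ beta * V[n]) (+ beta * H[m][n])
void gemm_ref(const float* A, const float* B, const float* V, const float* H,
              int32_t N, int32_t K, float alpha, float beta, float* Y,
              int32_t start_row, int32_t end_row)
{
    for (int32_t m = start_row; m < end_row; ++m)
    {
        for (int32_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for (int32_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
            acc *= alpha;
            if (V) acc += beta * V[n];          // broadcast bias row
            if (H) acc += beta * H[m * N + n];  // elementwise accumulate
            Y[m * N + n] = acc;
        }
    }
}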
+ 3 * N + n]; - tmp[n] += A[m * K + k + 4] * B[k * N + 4 * N + n]; - tmp[n] += A[m * K + k + 5] * B[k * N + 5 * N + n]; - tmp[n] += A[m * K + k + 6] * B[k * N + 6 * N + n]; - tmp[n] += A[m * K + k + 7] * B[k * N + 7 * N + n]; + std::vector tmp(N); + for (int32_t m = start_row; m < end_row; ++m) + { + for (int32_t n = 0; n < N; n++) + { + tmp[n] = 0; + } + { + int32_t remainder = K % 8; // unroll + for (int32_t k = 0; k < K; k += 8) + { + for (int32_t n = 0; n < N; n++) + { + tmp[n] += A[m * K + k] * B[k * N + n]; + tmp[n] += A[m * K + k + 1] * B[k * N + N + n]; + tmp[n] += A[m * K + k + 2] * B[k * N + 2 * N + n]; + tmp[n] += A[m * K + k + 3] * B[k * N + 3 * N + n]; + tmp[n] += A[m * K + k + 4] * B[k * N + 4 * N + n]; + tmp[n] += A[m * K + k + 5] * B[k * N + 5 * N + n]; + tmp[n] += A[m * K + k + 6] * B[k * N + 6 * N + n]; + tmp[n] += A[m * K + k + 7] * B[k * N + 7 * N + n]; + } + } + for (int32_t k = K - remainder; k < K; k++) + { + for (int32_t n = 0; n < N; n++) + { + tmp[n] += A[m * K + k] * B[k * N + n]; + } + } + } + for (int32_t n = 0; n < N; n++) + { + tmp[n] *= alpha; + if (V) tmp[n] += beta * V[n]; + if (H) tmp[n] += beta * H[m * N + n]; + Y[m * N + n] = tmp[n]; + } } - } - for (int32_t k = K - remainder; k < K; k++) { - for (int32_t n = 0; n < N; n++) { - tmp[n] += A[m * K + k] * B[k * N + n]; + } + + void deformable_conv2d_ref_fp32(const float* src, const float* offset, const float* mask, const float* filter, const float* bias, const int64_t batch, const int64_t src_c, const int64_t src_h, const int64_t src_w, const int64_t dst_c, const int64_t dst_h, const int64_t dst_w, const int64_t group, const int64_t offset_group, const int64_t channels, const int64_t num_output, const int64_t kernel_h, const int64_t kernel_w, const int64_t stride_h, const int64_t stride_w, const int64_t pad_h, const int64_t pad_w, const int64_t dilation_h, const int64_t dilation_w, float* columns, float* dst) + { + const int64_t ic_per_gp = channels / group; + const int64_t oc_per_gp = num_output / group; + // Set up for launching threads + std::size_t num_threads = std::thread::hardware_concurrency(); + std::vector threads; + threads.reserve(num_threads); + + for (int64_t b = 0; b < batch; ++b) + { + for (int64_t g = 0; g < group; ++g) + { + deformable_im2col_2d( + src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w, + offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w, + mask + b * offset_group * kernel_h * kernel_w * dst_h * dst_w, + src_h, + src_w, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + ic_per_gp, + offset_group, + dst_h, + dst_w, + mask != nullptr, + columns); + float* dst_ptr = dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w; + if (bias != nullptr) + { + const float* bias_ptr = bias + g * oc_per_gp; + for (int64_t oc = 0; oc < oc_per_gp; ++oc) + { + for (int64_t hw = 0; hw < dst_h * dst_w; ++hw) + { + dst_ptr[oc * dst_h * dst_w + hw] = bias_ptr[oc]; + } + } + } + else + { + memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w); + } + if (num_threads > 1) + { + // Calculate values to pass to threads + int32_t n_rows = (oc_per_gp + num_threads - 1) / num_threads; + int32_t end_row = 0; + for (int32_t i = 0; i < num_threads; i++) + { + auto start_row = i * n_rows; + end_row = start_row + n_rows; + if (end_row > oc_per_gp) end_row = oc_per_gp; + std::thread t(parallel_unroll_gemm, + filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, + columns, + nullptr, + dst_ptr, + oc_per_gp, + dst_h 
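One thing to watch in the unrolled inner loop as transcribed here: the condition `k < K` lets the eight-wide block run past the end when K is not a multiple of 8, and the scalar tail loop then adds the last K % 8 products a second time; the conventional shape bounds the unrolled loop at K - remainder. A minimal sketch of that structure on a plain dot product (illustrative, assuming that was the intent):

#include <cstdint>
#include <cstdio>

static float dot_unrolled(const float* a, const float* b, int32_t K)
{
    int32_t tail = K % 8;
    float acc = 0.f;
    for (int32_t k = 0; k + 8 <= K; k += 8)  // stops at K - tail
        for (int32_t u = 0; u < 8; ++u) acc += a[k + u] * b[k + u];
    for (int32_t k = K - tail; k < K; ++k)   // scalar remainder
        acc += a[k] * b[k];
    return acc;
}

int main()
{
    float a[11], b[11];
    for (int i = 0; i < 11; ++i) { a[i] = 1.f; b[i] = 2.f; }
    std::printf("%.1f\n", dot_unrolled(a, b, 11));  // 22.0; K = 11 exercises the tail
}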
* dst_w, + ic_per_gp * kernel_h * kernel_w, + 1.0f, + 1.0f, + dst_ptr, + start_row, + end_row); + threads.emplace_back(std::move(t)); + } + // Wait for all threads to complete + for (auto& t : threads) t.join(); + threads.clear(); + } + else + { // parallel gemm degrade to serial gemm with start_row=0 and end_row= oc_per_gp + parallel_unroll_gemm(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr, dst_ptr, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr, 0, oc_per_gp); + } + } } - } } - for (int32_t n = 0; n < N; n++) { - tmp[n] *= alpha; - if (V) tmp[n] += beta * V[n]; - if (H) tmp[n] += beta * H[m * N + n]; - Y[m * N + n] = tmp[n]; + + MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi& api, + const OrtKernelInfo* info) + : ort_(api) + , info_(info) + { + std::vector stride = ort_.KernelInfoGetAttribute>(info, "stride"); + stride_height_ = stride[0]; + stride_width_ = stride[1]; + std::vector padding = ort_.KernelInfoGetAttribute>(info, "padding"); + padding_height_ = padding[0]; + padding_width_ = padding[1]; + std::vector dilation = + ort_.KernelInfoGetAttribute>(info, "dilation"); + dilation_height_ = dilation[0]; + dilation_width_ = dilation[1]; + deformable_group_ = ort_.KernelInfoGetAttribute(info, "deform_groups"); + group_ = ort_.KernelInfoGetAttribute(info, "groups"); + + // create allocator + allocator_ = Ort::AllocatorWithDefaultOptions(); } - } -} - -void deformable_conv2d_ref_fp32(const float *src, const float *offset, const float *mask, - const float *filter, const float *bias, const int64_t batch, - const int64_t src_c, const int64_t src_h, const int64_t src_w, - const int64_t dst_c, const int64_t dst_h, const int64_t dst_w, - const int64_t group, const int64_t offset_group, - const int64_t channels, const int64_t num_output, - const int64_t kernel_h, const int64_t kernel_w, - const int64_t stride_h, const int64_t stride_w, const int64_t pad_h, - const int64_t pad_w, const int64_t dilation_h, - const int64_t dilation_w, float *columns, float *dst) { - const int64_t ic_per_gp = channels / group; - const int64_t oc_per_gp = num_output / group; - // Set up for launching threads - std::size_t num_threads = std::thread::hardware_concurrency(); - std::vector threads; - threads.reserve(num_threads); - - for (int64_t b = 0; b < batch; ++b) { - for (int64_t g = 0; g < group; ++g) { - deformable_im2col_2d( - src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w, - offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w, - mask + b * offset_group * kernel_h * kernel_w * dst_h * dst_w, src_h, src_w, kernel_h, - kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, ic_per_gp, - offset_group, dst_h, dst_w, mask != nullptr, columns); - float *dst_ptr = dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w; - if (bias != nullptr) { - const float *bias_ptr = bias + g * oc_per_gp; - for (int64_t oc = 0; oc < oc_per_gp; ++oc) { - for (int64_t hw = 0; hw < dst_h * dst_w; ++hw) { - dst_ptr[oc * dst_h * dst_w + hw] = bias_ptr[oc]; - } - } - } else { - memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w); - } - if (num_threads > 1) { - // Calculate values to pass to threads - int32_t n_rows = (oc_per_gp + num_threads - 1) / num_threads; - int32_t end_row = 0; - for (int32_t i = 0; i < num_threads; i++) { - auto start_row = i * n_rows; - end_row = start_row + n_rows; - if (end_row > oc_per_gp) end_row = oc_per_gp; - std::thread t(parallel_unroll_gemm, - filter + 
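The thread partition above uses a ceiling division, n_rows = (oc_per_gp + num_threads - 1) / num_threads, with the last chunk clamped to oc_per_gp. For example, 10 output channels over 4 threads gives chunks [0,3), [3,6), [6,9), [9,10). A standalone sketch of the same split:

#include <cstdint>
#include <cstdio>

int main()
{
    int32_t rows = 10, num_threads = 4;
    int32_t n_rows = (rows + num_threads - 1) / num_threads;  // ceil(10 / 4) = 3
    for (int32_t i = 0; i < num_threads; i++)
    {
        int32_t start_row = i * n_rows;
        int32_t end_row = start_row + n_rows;
        if (end_row > rows) end_row = rows;
        if (start_row >= rows) break;  // added guard: skips empty chunks when threads > rows
        std::printf("thread %d: rows [%d, %d)\n", i, start_row, end_row);
    }
}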
g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr, - dst_ptr, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f, - 1.0f, dst_ptr, start_row, end_row); - threads.emplace_back(std::move(t)); - } - // Wait for all threads to complete - for (auto &t : threads) t.join(); - threads.clear(); - } else { // parallel gemm degrade to serial gemm with start_row=0 and end_row= oc_per_gp - parallel_unroll_gemm(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, - nullptr, dst_ptr, oc_per_gp, dst_h * dst_w, - ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr, 0, oc_per_gp); - } + + void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext* context) + { + const int64_t stride_height = stride_height_; + const int64_t stride_width = stride_width_; + const int64_t padding_height = padding_height_; + const int64_t padding_width = padding_width_; + const int64_t dilation_height = dilation_height_; + const int64_t dilation_width = dilation_width_; + const int64_t deformable_group = deformable_group_; + const int64_t group = group_; + + const OrtValue* input = ort_.KernelContext_GetInput(context, 0); + const float* input_data = reinterpret_cast(ort_.GetTensorData(input)); + + const OrtValue* offset = ort_.KernelContext_GetInput(context, 1); + const float* offset_data = reinterpret_cast(ort_.GetTensorData(offset)); + + const OrtValue* mask = ort_.KernelContext_GetInput(context, 2); + const float* mask_data = reinterpret_cast(ort_.GetTensorData(mask)); + + const OrtValue* filter = ort_.KernelContext_GetInput(context, 3); + const float* filter_data = reinterpret_cast(ort_.GetTensorData(filter)); + + const OrtValue* bias = ort_.KernelContext_GetInput(context, 4); + const float* bias_data = (bias != nullptr) ? reinterpret_cast(ort_.GetTensorData(bias)) : nullptr; + // const float *bias_data = nullptr; + + OrtTensorDimensions input_dims(ort_, input); + OrtTensorDimensions filter_dims(ort_, filter); + + int64_t batch = input_dims[0]; + int64_t channels = input_dims[1]; + int64_t in_height = input_dims[2]; + int64_t in_width = input_dims[3]; + int64_t num_output = filter_dims[0]; + int64_t kernel_height = filter_dims[2]; + int64_t kernel_width = filter_dims[3]; + + // get output memory + int64_t out_height = floor( + (in_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) / stride_height + + 1); + int64_t out_width = floor( + (in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width + 1); + + std::vector output_dims = {batch, num_output, out_height, out_width}; + OrtValue* output = + ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size()); + float* out_ptr = ort_.GetTensorMutableData(output); + + // allocate tmp memory + int64_t column_len = (channels / group) * kernel_height * kernel_width * out_height * out_width; + float* columns = (float*)allocator_.Alloc(sizeof(float) * column_len); + + deformable_conv2d_ref_fp32(input_data, offset_data, mask_data, filter_data, bias_data, batch, channels, in_height, in_width, num_output, out_height, out_width, group, deformable_group, channels, num_output, kernel_height, kernel_width, stride_height, stride_width, padding_height, padding_width, dilation_height, dilation_width, columns, out_ptr); + + allocator_.Free(columns); } - } -} - -MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi &api, - const OrtKernelInfo *info) - : ort_(api), info_(info) { - std::vector stride = ort_.KernelInfoGetAttribute>(info, "stride"); - stride_height_ = 
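The output size in Compute is the standard dilated-convolution formula, out = floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1. Worked example: in = 64, kernel = 3, pad = 1, stride = 2, dilation = 1 gives floor(63 / 2) + 1 = 32. A one-line sanity check:

#include <cstdint>
#include <cstdio>

static int64_t conv_out(int64_t in, int64_t k, int64_t pad, int64_t stride, int64_t dil)
{
    return (in + 2 * pad - dil * (k - 1) - 1) / stride + 1;  // integer division floors for positive sizes
}

int main()
{
    std::printf("%lld\n", (long long)conv_out(64, 3, 1, 2, 1));  // 32
}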
stride[0]; - stride_width_ = stride[1]; - std::vector padding = ort_.KernelInfoGetAttribute>(info, "padding"); - padding_height_ = padding[0]; - padding_width_ = padding[1]; - std::vector dilation = - ort_.KernelInfoGetAttribute>(info, "dilation"); - dilation_height_ = dilation[0]; - dilation_width_ = dilation[1]; - deformable_group_ = ort_.KernelInfoGetAttribute(info, "deform_groups"); - group_ = ort_.KernelInfoGetAttribute(info, "groups"); - - // create allocator - allocator_ = Ort::AllocatorWithDefaultOptions(); -} - -void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext *context) { - const int64_t stride_height = stride_height_; - const int64_t stride_width = stride_width_; - const int64_t padding_height = padding_height_; - const int64_t padding_width = padding_width_; - const int64_t dilation_height = dilation_height_; - const int64_t dilation_width = dilation_width_; - const int64_t deformable_group = deformable_group_; - const int64_t group = group_; - - const OrtValue *input = ort_.KernelContext_GetInput(context, 0); - const float *input_data = reinterpret_cast(ort_.GetTensorData(input)); - - const OrtValue *offset = ort_.KernelContext_GetInput(context, 1); - const float *offset_data = reinterpret_cast(ort_.GetTensorData(offset)); - - const OrtValue *mask = ort_.KernelContext_GetInput(context, 2); - const float *mask_data = reinterpret_cast(ort_.GetTensorData(mask)); - - const OrtValue *filter = ort_.KernelContext_GetInput(context, 3); - const float *filter_data = reinterpret_cast(ort_.GetTensorData(filter)); - - const OrtValue *bias = ort_.KernelContext_GetInput(context, 4); - const float *bias_data = (bias != nullptr) - ? reinterpret_cast(ort_.GetTensorData(bias)) - : nullptr; - // const float *bias_data = nullptr; - - OrtTensorDimensions input_dims(ort_, input); - OrtTensorDimensions filter_dims(ort_, filter); - - int64_t batch = input_dims[0]; - int64_t channels = input_dims[1]; - int64_t in_height = input_dims[2]; - int64_t in_width = input_dims[3]; - int64_t num_output = filter_dims[0]; - int64_t kernel_height = filter_dims[2]; - int64_t kernel_width = filter_dims[3]; - - // get output memory - int64_t out_height = floor( - (in_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) / stride_height + - 1); - int64_t out_width = floor( - (in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width + 1); - - std::vector output_dims = {batch, num_output, out_height, out_width}; - OrtValue *output = - ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size()); - float *out_ptr = ort_.GetTensorMutableData(output); - - // allocate tmp memory - int64_t column_len = (channels / group) * kernel_height * kernel_width * out_height * out_width; - float *columns = (float *)allocator_.Alloc(sizeof(float) * column_len); - - deformable_conv2d_ref_fp32(input_data, offset_data, mask_data, filter_data, bias_data, batch, - channels, in_height, in_width, num_output, out_height, out_width, - group, deformable_group, channels, num_output, kernel_height, - kernel_width, stride_height, stride_width, padding_height, - padding_width, dilation_height, dilation_width, columns, out_ptr); - - allocator_.Free(columns); -} -REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVModulatedDeformConvOp); -REGISTER_ONNXRUNTIME_OPS(mmcv, MMCVModulatedDeformConvOp); + REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVModulatedDeformConvOp); + REGISTER_ONNXRUNTIME_OPS(mmcv, MMCVModulatedDeformConvOp); } // namespace mmdeploy diff --git 
a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h index 772a9c4a88..7ffeb702d3 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h @@ -4,55 +4,74 @@ #include -namespace mmdeploy { - -struct MMCVModulatedDeformConvKernel { - MMCVModulatedDeformConvKernel(const OrtApi &api, const OrtKernelInfo *info); - - void Compute(OrtKernelContext *context); - - protected: - Ort::CustomOpApi ort_; - const OrtKernelInfo *info_; - Ort::AllocatorWithDefaultOptions allocator_; - - int64_t stride_height_; - int64_t stride_width_; - int64_t padding_height_; - int64_t padding_width_; - int64_t dilation_height_; - int64_t dilation_width_; - int64_t deformable_group_; - int64_t group_; -}; - -struct MMCVModulatedDeformConvOp - : Ort::CustomOpBase { - void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const { - return new MMCVModulatedDeformConvKernel(api, info); - } - - const char *GetName() const { return "MMCVModulatedDeformConv2d"; }; - - size_t GetInputTypeCount() const { return 5; }; - ONNXTensorElementDataType GetInputType(size_t /*index*/) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - }; - - OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const { - // The last input (index == 4) is optional, which is bias - if (index == 4) return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL; - - return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; - } - - size_t GetOutputTypeCount() const { return 1; }; - ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - }; - - // force cpu - const char *GetExecutionProviderType() const { return "CPUExecutionProvider"; }; -}; +namespace mmdeploy +{ + + struct MMCVModulatedDeformConvKernel + { + MMCVModulatedDeformConvKernel(const OrtApi& api, const OrtKernelInfo* info); + + void Compute(OrtKernelContext* context); + + protected: + Ort::CustomOpApi ort_; + const OrtKernelInfo* info_; + Ort::AllocatorWithDefaultOptions allocator_; + + int64_t stride_height_; + int64_t stride_width_; + int64_t padding_height_; + int64_t padding_width_; + int64_t dilation_height_; + int64_t dilation_width_; + int64_t deformable_group_; + int64_t group_; + }; + + struct MMCVModulatedDeformConvOp + : Ort::CustomOpBase + { + void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const + { + return new MMCVModulatedDeformConvKernel(api, info); + } + + const char* GetName() const + { + return "MMCVModulatedDeformConv2d"; + }; + + size_t GetInputTypeCount() const + { + return 5; + }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const + { + // The last input (index == 4) is optional, which is bias + if (index == 4) return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL; + + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + size_t GetOutputTypeCount() const + { + return 1; + }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + // force cpu + const char* GetExecutionProviderType() const + { + return "CPUExecutionProvider"; + }; + }; } // namespace 
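All of these kernels follow the same Ort::CustomOpBase pattern and are wired up by the REGISTER_ONNXRUNTIME_OPS macro. For orientation, a hedged sketch of the standard ONNX Runtime C++ registration flow for such an op; the op type comes from this header, but the env name and model path are hypothetical, and in mmdeploy this wiring is generated by the macro rather than written by hand:

#include <onnxruntime_cxx_api.h>

int main()
{
    static MMCVModulatedDeformConvOp deform_conv_op;   // op instance must outlive the session
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");   // hypothetical env name
    Ort::SessionOptions options;
    Ort::CustomOpDomain domain("mmdeploy");            // domain the op is registered under
    domain.Add(&deform_conv_op);                       // exposes "MMCVModulatedDeformConv2d"
    options.Add(domain);                               // keep domain alive while the session exists
    Ort::Session session(env, "model.onnx", options);  // hypothetical model path
}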
mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp index 784be2c987..397bcbf92c 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp @@ -13,117 +13,132 @@ #include "ort_utils.h" -namespace mmdeploy { -struct Box { - float x1, y1, x2, y2; -}; - -float nms_match_iou(Box box1, Box box2) { - auto inter_x1 = std::max(box1.x1, box2.x1); - auto inter_y1 = std::max(box1.y1, box2.y1); - auto inter_x2 = std::min(box1.x2, box2.x2); - auto inter_y2 = std::min(box1.y2, box2.y2); - - auto eps = 1e-10; - - auto w = std::max(static_cast(0), inter_x2 - inter_x1); - auto h = std::max(static_cast(0), inter_y2 - inter_y1); - - auto area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1); - auto area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1); - auto inter = w * h; - auto ovr = inter / (area1 + area2 - inter + eps); - return ovr; -} -NMSMatchKernel::NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info) - : ort_(api), info_(info) { - // create allocator - allocator_ = Ort::AllocatorWithDefaultOptions(); -} - -void NMSMatchKernel::Compute(OrtKernelContext* context) { - const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); - const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); - const OrtValue* scores = ort_.KernelContext_GetInput(context, 1); - const float* scores_data = reinterpret_cast(ort_.GetTensorData(scores)); - const OrtValue* iou_threshold_ = ort_.KernelContext_GetInput(context, 2); - const float iou_threshold_data = ort_.GetTensorData(iou_threshold_)[0]; - const OrtValue* score_threshold_ = ort_.KernelContext_GetInput(context, 3); - const float score_threshold_data = ort_.GetTensorData(score_threshold_)[0]; - - OrtTensorDimensions boxes_dim(ort_, boxes); - OrtTensorDimensions scores_dim(ort_, scores); - // loop over batch - int64_t nbatch = boxes_dim[0]; - int64_t nboxes = boxes_dim[1]; - int64_t nclass = scores_dim[1]; - assert(boxes_dim[2] == 4); //(x1, x2, y1, y2) - // alloc some temp memory - bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes); - - std::vector res_order; - for (int64_t k = 0; k < nbatch; k++) { - for (int64_t g = 0; g < nclass; g++) { - for (int64_t i = 0; i < nboxes; i++) { - select[i] = true; - } - // scores - // k * nboxes * nclass means per batch - // g * nboxes means per class - // batch = 2 boxes = 3 classes = 4 - std::vector tmp_sc; - // get the class scores - for (int i = 0; i < nboxes; i++) { - tmp_sc.push_back(scores_data[k * nboxes * nclass + g * nboxes + i]); - } - - std::vector order(tmp_sc.size()); - std::iota(order.begin(), order.end(), 0); - std::sort(order.begin(), order.end(), - [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; }); - for (int64_t _i = 0; _i < nboxes; _i++) { - auto i = order[_i]; - if (select[i] == false) continue; - std::vector v_i; - for (int64_t _j = _i + 1; _j < nboxes; _j++) { - auto j = order[_j]; - if (select[j] == false) continue; - Box vbox1, vbox2; - vbox1.x1 = boxes_data[k * nboxes * 4 + i * 4]; - vbox1.y1 = boxes_data[k * nboxes * 4 + i * 4 + 1]; - vbox1.x2 = boxes_data[k * nboxes * 4 + i * 4 + 2]; - vbox1.y2 = boxes_data[k * nboxes * 4 + i * 4 + 3]; - - vbox2.x1 = boxes_data[k * nboxes * 4 + j * 4]; - vbox2.y1 = boxes_data[k * nboxes * 4 + j * 4 + 1]; - vbox2.x2 = boxes_data[k * nboxes * 4 + j * 4 + 2]; - vbox2.y2 = boxes_data[k * nboxes * 4 + j * 4 + 3]; - - auto ovr = 
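nms_match_iou is plain intersection-over-union with an epsilon guard against degenerate boxes. Worked example: boxes (0, 0, 2, 2) and (1, 1, 3, 3) overlap in a 1 x 1 square, so IoU = 1 / (4 + 4 - 1) ≈ 0.143. A standalone check (epsilon omitted):

#include <algorithm>
#include <cstdio>

static float iou(float ax1, float ay1, float ax2, float ay2,
                 float bx1, float by1, float bx2, float by2)
{
    float w = std::max(0.f, std::min(ax2, bx2) - std::max(ax1, bx1));
    float h = std::max(0.f, std::min(ay2, by2) - std::max(ay1, by1));
    float inter = w * h;
    float area_a = (ax2 - ax1) * (ay2 - ay1);
    float area_b = (bx2 - bx1) * (by2 - by1);
    return inter / (area_a + area_b - inter);
}

int main()
{
    std::printf("%.3f\n", iou(0, 0, 2, 2, 1, 1, 3, 3));  // 0.143 = 1 / 7
}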
nms_match_iou(vbox1, vbox2); - if (ovr >= iou_threshold_data) { - select[j] = false; - v_i.push_back(j); - } - } - if (tmp_sc[i] > score_threshold_data && v_i.size() != 0) { - for (int v_i_idx = 0; v_i_idx < v_i.size(); v_i_idx++) { - res_order.push_back(k); - res_order.push_back(g); - res_order.push_back(i); - res_order.push_back(v_i[v_i_idx]); - } - } - } +namespace mmdeploy +{ + struct Box + { + float x1, y1, x2, y2; + }; + + float nms_match_iou(Box box1, Box box2) + { + auto inter_x1 = std::max(box1.x1, box2.x1); + auto inter_y1 = std::max(box1.y1, box2.y1); + auto inter_x2 = std::min(box1.x2, box2.x2); + auto inter_y2 = std::min(box1.y2, box2.y2); + + auto eps = 1e-10; + + auto w = std::max(static_cast(0), inter_x2 - inter_x1); + auto h = std::max(static_cast(0), inter_y2 - inter_y1); + + auto area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1); + auto area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1); + auto inter = w * h; + auto ovr = inter / (area1 + area2 - inter + eps); + return ovr; + } + NMSMatchKernel::NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info) + : ort_(api) + , info_(info) + { + // create allocator + allocator_ = Ort::AllocatorWithDefaultOptions(); } - } - std::vector inds_dims({(int64_t)res_order.size() / 4, 4}); - OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); - int64_t* res_data = ort_.GetTensorMutableData(res); + void NMSMatchKernel::Compute(OrtKernelContext* context) + { + const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); + const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); + const OrtValue* scores = ort_.KernelContext_GetInput(context, 1); + const float* scores_data = reinterpret_cast(ort_.GetTensorData(scores)); + const OrtValue* iou_threshold_ = ort_.KernelContext_GetInput(context, 2); + const float iou_threshold_data = ort_.GetTensorData(iou_threshold_)[0]; + const OrtValue* score_threshold_ = ort_.KernelContext_GetInput(context, 3); + const float score_threshold_data = ort_.GetTensorData(score_threshold_)[0]; + + OrtTensorDimensions boxes_dim(ort_, boxes); + OrtTensorDimensions scores_dim(ort_, scores); + // loop over batch + int64_t nbatch = boxes_dim[0]; + int64_t nboxes = boxes_dim[1]; + int64_t nclass = scores_dim[1]; + assert(boxes_dim[2] == 4); //(x1, x2, y1, y2) + // alloc some temp memory + bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes); + + std::vector res_order; + for (int64_t k = 0; k < nbatch; k++) + { + for (int64_t g = 0; g < nclass; g++) + { + for (int64_t i = 0; i < nboxes; i++) + { + select[i] = true; + } + // scores + // k * nboxes * nclass means per batch + // g * nboxes means per class + // batch = 2 boxes = 3 classes = 4 + std::vector tmp_sc; + // get the class scores + for (int i = 0; i < nboxes; i++) + { + tmp_sc.push_back(scores_data[k * nboxes * nclass + g * nboxes + i]); + } + + std::vector order(tmp_sc.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&tmp_sc](int64_t id1, int64_t id2) + { return tmp_sc[id1] > tmp_sc[id2]; }); + for (int64_t _i = 0; _i < nboxes; _i++) + { + auto i = order[_i]; + if (select[i] == false) continue; + std::vector v_i; + for (int64_t _j = _i + 1; _j < nboxes; _j++) + { + auto j = order[_j]; + if (select[j] == false) continue; + Box vbox1, vbox2; + vbox1.x1 = boxes_data[k * nboxes * 4 + i * 4]; + vbox1.y1 = boxes_data[k * nboxes * 4 + i * 4 + 1]; + vbox1.x2 = boxes_data[k * nboxes * 4 + i * 4 + 2]; + vbox1.y2 = boxes_data[k * nboxes * 4 + i * 4 + 
3]; + + vbox2.x1 = boxes_data[k * nboxes * 4 + j * 4]; + vbox2.y1 = boxes_data[k * nboxes * 4 + j * 4 + 1]; + vbox2.x2 = boxes_data[k * nboxes * 4 + j * 4 + 2]; + vbox2.y2 = boxes_data[k * nboxes * 4 + j * 4 + 3]; + + auto ovr = nms_match_iou(vbox1, vbox2); + if (ovr >= iou_threshold_data) + { + select[j] = false; + v_i.push_back(j); + } + } + if (tmp_sc[i] > score_threshold_data && v_i.size() != 0) + { + for (int v_i_idx = 0; v_i_idx < v_i.size(); v_i_idx++) + { + res_order.push_back(k); + res_order.push_back(g); + res_order.push_back(i); + res_order.push_back(v_i[v_i_idx]); + } + } + } + } + } + std::vector inds_dims({(int64_t)res_order.size() / 4, 4}); + + OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); + int64_t* res_data = ort_.GetTensorMutableData(res); - memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size()); + memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size()); - allocator_.Free(select); -} -REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSMatchOp); + allocator_.Free(select); + } + REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSMatchOp); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h index 57aa94d964..48e0d0dbb0 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h @@ -10,37 +10,55 @@ #include #include -namespace mmdeploy { -struct NMSMatchKernel { - NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info); - - void Compute(OrtKernelContext* context); - - private: - Ort::CustomOpApi ort_; - const OrtKernelInfo* info_; - Ort::AllocatorWithDefaultOptions allocator_; -}; - -struct NMSMatchOp : Ort::CustomOpBase { - void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { - return new NMSMatchKernel(api, info); - } - const char* GetName() const { return "NMSMatch"; } - - size_t GetInputTypeCount() const { return 4; } - ONNXTensorElementDataType GetInputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - - size_t GetOutputTypeCount() const { return 1; } - ONNXTensorElementDataType GetOutputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } - - // force cpu - const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } -}; +namespace mmdeploy +{ + struct NMSMatchKernel + { + NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info); + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + const OrtKernelInfo* info_; + Ort::AllocatorWithDefaultOptions allocator_; + }; + + struct NMSMatchOp : Ort::CustomOpBase + { + void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const + { + return new NMSMatchKernel(api, info); + } + const char* GetName() const + { + return "NMSMatch"; + } + + size_t GetInputTypeCount() const + { + return 4; + } + ONNXTensorElementDataType GetInputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const + { + return 1; + } + ONNXTensorElementDataType GetOutputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } + + // force cpu + const char* GetExecutionProviderType() const + { + return "CPUExecutionProvider"; + } + }; } // namespace mmdeploy #endif // ONNXRUNTIME_NMS_MATCH_H diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp 
index 9d8cc4597e..73c508ce47 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp @@ -13,356 +13,418 @@ #include "ort_utils.h" -namespace mmdeploy { - -namespace { -struct RotatedBox { - float x_ctr, y_ctr, w, h, a; -}; -struct Point { - float x, y; - Point(const float& px = 0, const float& py = 0) : x(px), y(py) {} - Point operator+(const Point& p) const { return Point(x + p.x, y + p.y); } - Point& operator+=(const Point& p) { - x += p.x; - y += p.y; - return *this; - } - Point operator-(const Point& p) const { return Point(x - p.x, y - p.y); } - Point operator*(const float coeff) const { return Point(x * coeff, y * coeff); } -}; - -float dot_2d(const Point& A, const Point& B) { return A.x * B.x + A.y * B.y; } - -float cross_2d(const Point& A, const Point& B) { return A.x * B.y - B.x * A.y; } -} // namespace - -void get_rotated_vertices(const RotatedBox& box, Point (&pts)[4]) { - // M_PI / 180. == 0.01745329251 - // double theta = box.a * 0.01745329251; - // MODIFIED - double theta = box.a; - float cosTheta2 = (float)cos(theta) * 0.5f; - float sinTheta2 = (float)sin(theta) * 0.5f; - - // y: top --> down; x: left --> right - pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; - pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; - pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; - pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; - pts[2].x = 2 * box.x_ctr - pts[0].x; - pts[2].y = 2 * box.y_ctr - pts[0].y; - pts[3].x = 2 * box.x_ctr - pts[1].x; - pts[3].y = 2 * box.y_ctr - pts[1].y; -} - -int get_intersection_points(const Point (&pts1)[4], const Point (&pts2)[4], - Point (&intersections)[24]) { - // Line vector - // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] - Point vec1[4], vec2[4]; - for (int i = 0; i < 4; i++) { - vec1[i] = pts1[(i + 1) % 4] - pts1[i]; - vec2[i] = pts2[(i + 1) % 4] - pts2[i]; - } - - // Line test - test all line combos for intersection - int num = 0; // number of intersections - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // Solve for 2x2 Ax=b - float det = cross_2d(vec2[j], vec1[i]); - - // This takes care of parallel lines - if (fabs(det) <= 1e-14) { - continue; - } - - auto vec12 = pts2[j] - pts1[i]; - - float t1 = cross_2d(vec2[j], vec12) / det; - float t2 = cross_2d(vec1[i], vec12) / det; - - if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) { - intersections[num++] = pts1[i] + vec1[i] * t1; - } - } - } - - // Check for vertices of rect1 inside rect2 - { - const auto& AB = vec2[0]; - const auto& DA = vec2[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - // assume ABCD is the rectangle, and P is the point to be judged - // P is inside ABCD iff. 
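get_rotated_vertices expands a (x_ctr, y_ctr, w, h, a) box into four corners; per the MODIFIED comment, the angle is consumed in radians rather than degrees. Sanity check at a = 0 for the box (2, 3, 4, 2): the corners come out as (0, 4), (0, 2), (4, 2), (4, 4), the axis-aligned rectangle, with pts[2] and pts[3] produced by reflecting pts[0] and pts[1] through the center. A standalone sketch:

#include <cmath>
#include <cstdio>

int main()
{
    float cx = 2, cy = 3, w = 4, h = 2, a = 0;  // angle in radians
    float c2 = std::cos(a) * 0.5f, s2 = std::sin(a) * 0.5f;
    float px[4], py[4];
    px[0] = cx - s2 * h - c2 * w; py[0] = cy + c2 * h - s2 * w;
    px[1] = cx + s2 * h - c2 * w; py[1] = cy - c2 * h - s2 * w;
    px[2] = 2 * cx - px[0];       py[2] = 2 * cy - py[0];  // point reflection of pts[0]
    px[3] = 2 * cx - px[1];       py[3] = 2 * cy - py[1];  // point reflection of pts[1]
    for (int i = 0; i < 4; ++i) std::printf("(%g, %g)\n", px[i], py[i]);
}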
P's projection on AB lies within AB - // and P's projection on AD lies within AD - - auto AP = pts1[i] - pts2[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts1[i]; - } - } - } - - // Reverse the check - check for vertices of rect2 inside rect1 - { - const auto& AB = vec1[0]; - const auto& DA = vec1[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - auto AP = pts2[i] - pts1[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts2[i]; - } +namespace mmdeploy +{ + + namespace + { + struct RotatedBox + { + float x_ctr, y_ctr, w, h, a; + }; + struct Point + { + float x, y; + Point(const float& px = 0, const float& py = 0) + : x(px) + , y(py) + { + } + Point operator+(const Point& p) const + { + return Point(x + p.x, y + p.y); + } + Point& operator+=(const Point& p) + { + x += p.x; + y += p.y; + return *this; + } + Point operator-(const Point& p) const + { + return Point(x - p.x, y - p.y); + } + Point operator*(const float coeff) const + { + return Point(x * coeff, y * coeff); + } + }; + + float dot_2d(const Point& A, const Point& B) + { + return A.x * B.x + A.y * B.y; + } + + float cross_2d(const Point& A, const Point& B) + { + return A.x * B.y - B.x * A.y; + } + } // namespace + + void get_rotated_vertices(const RotatedBox& box, Point (&pts)[4]) + { + // M_PI / 180. == 0.01745329251 + // double theta = box.a * 0.01745329251; + // MODIFIED + double theta = box.a; + float cosTheta2 = (float)cos(theta) * 0.5f; + float sinTheta2 = (float)sin(theta) * 0.5f; + + // y: top --> down; x: left --> right + pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; + pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; + pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; + pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; + pts[2].x = 2 * box.x_ctr - pts[0].x; + pts[2].y = 2 * box.y_ctr - pts[0].y; + pts[3].x = 2 * box.x_ctr - pts[1].x; + pts[3].y = 2 * box.y_ctr - pts[1].y; } - } - - return num; -} - -int convex_hull_graham(const Point (&p)[24], const int& num_in, Point (&q)[24], - bool shift_to_zero = false) { - assert(num_in >= 2); - - // Step 1: - // Find point with minimum y - // if more than 1 points have the same minimum y, - // pick the one with the minimum x. 
- int t = 0; - for (int i = 1; i < num_in; i++) { - if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) { - t = i; + + int get_intersection_points(const Point (&pts1)[4], const Point (&pts2)[4], Point (&intersections)[24]) + { + // Line vector + // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] + Point vec1[4], vec2[4]; + for (int i = 0; i < 4; i++) + { + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; + } + + // Line test - test all line combos for intersection + int num = 0; // number of intersections + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + // Solve for 2x2 Ax=b + float det = cross_2d(vec2[j], vec1[i]); + + // This takes care of parallel lines + if (fabs(det) <= 1e-14) + { + continue; + } + + auto vec12 = pts2[j] - pts1[i]; + + float t1 = cross_2d(vec2[j], vec12) / det; + float t2 = cross_2d(vec1[i], vec12) / det; + + if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) + { + intersections[num++] = pts1[i] + vec1[i] * t1; + } + } + } + + // Check for vertices of rect1 inside rect2 + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. P's projection on AB lies within AB + // and P's projection on AD lies within AD + + auto AP = pts1[i] - pts2[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts1[i]; + } + } + } + + // Reverse the check - check for vertices of rect2 inside rect1 + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + auto AP = pts2[i] - pts1[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts2[i]; + } + } + } + + return num; } - } - auto& start = p[t]; // starting point - - // Step 2: - // Subtract starting point from every points (for sorting in the next step) - for (int i = 0; i < num_in; i++) { - q[i] = p[i] - start; - } - - // Swap the starting point to position 0 - auto tmp = q[0]; - q[0] = q[t]; - q[t] = tmp; - - // Step 3: - // Sort point 1 ~ num_in according to their relative cross-product values - // (essentially sorting according to angles) - // If the angles are the same, sort according to their distance to origin - float dist[24]; - for (int i = 0; i < num_in; i++) { - dist[i] = dot_2d(q[i], q[i]); - } - - // CPU version - std::sort(q + 1, q + num_in, [](const Point& A, const Point& B) -> bool { + + int convex_hull_graham(const Point (&p)[24], const int& num_in, Point (&q)[24], bool shift_to_zero = false) + { + assert(num_in >= 2); + + // Step 1: + // Find point with minimum y + // if more than 1 points have the same minimum y, + // pick the one with the minimum x. 
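The angular sort in convex_hull_graham relies on the sign of cross_2d: cross_2d(A, B) > 0 means the turn from A to B is counter-clockwise, so A has the smaller polar angle around the pivot and sorts first; near-zero cross products fall back to distance from the pivot. A minimal illustration:

#include <cstdio>

struct P { float x, y; };
static float cross_2d(const P& A, const P& B) { return A.x * B.y - B.x * A.y; }

int main()
{
    P A{1, 0}, B{0, 1};  // A at 0 degrees, B at 90 degrees
    std::printf("%g\n", cross_2d(A, B));  // 1: positive, so A precedes B in the sort
    std::printf("%g\n", cross_2d(B, A));  // -1: order reversed
}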
+ int t = 0; + for (int i = 1; i < num_in; i++) + { + if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) + { + t = i; + } + } + auto& start = p[t]; // starting point + + // Step 2: + // Subtract starting point from every points (for sorting in the next step) + for (int i = 0; i < num_in; i++) + { + q[i] = p[i] - start; + } + + // Swap the starting point to position 0 + auto tmp = q[0]; + q[0] = q[t]; + q[t] = tmp; + + // Step 3: + // Sort point 1 ~ num_in according to their relative cross-product values + // (essentially sorting according to angles) + // If the angles are the same, sort according to their distance to origin + float dist[24]; + for (int i = 0; i < num_in; i++) + { + dist[i] = dot_2d(q[i], q[i]); + } + + // CPU version + std::sort(q + 1, q + num_in, [](const Point& A, const Point& B) -> bool + { float temp = cross_2d(A, B); if (fabs(temp) < 1e-6) { return dot_2d(A, A) < dot_2d(B, B); } else { return temp > 0; + } }); + // compute distance to origin after sort, since the points are now different. + for (int i = 0; i < num_in; i++) + { + dist[i] = dot_2d(q[i], q[i]); + } + + // Step 4: + // Make sure there are at least 2 points (that don't overlap with each other) + // in the stack + int k; // index of the non-overlapped second point + for (k = 1; k < num_in; k++) + { + if (dist[k] > 1e-8) + { + break; + } + } + if (k == num_in) + { + // We reach the end, which means the convex hull is just one point + q[0] = p[t]; + return 1; + } + q[1] = q[k]; + int m = 2; // 2 points in the stack + // Step 5: + // Finally we can start the scanning process. + // When a non-convex relationship between the 3 points is found + // (either concave shape or duplicated points), + // we pop the previous point from the stack + // until the 3-point relationship is convex again, or + // until the stack only contains two points + for (int i = k + 1; i < num_in; i++) + { + while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) + { + m--; + } + q[m++] = q[i]; + } + + // Step 6 (Optional): + // In general sense we need the original coordinates, so we + // need to shift the points back (reverting Step 2) + // But if we're only interested in getting the area/perimeter of the shape + // We can simply return. + if (!shift_to_zero) + { + for (int i = 0; i < m; i++) + { + q[i] += start; + } + } + + return m; } - }); - // compute distance to origin after sort, since the points are now different. - for (int i = 0; i < num_in; i++) { - dist[i] = dot_2d(q[i], q[i]); - } - - // Step 4: - // Make sure there are at least 2 points (that don't overlap with each other) - // in the stack - int k; // index of the non-overlapped second point - for (k = 1; k < num_in; k++) { - if (dist[k] > 1e-8) { - break; - } - } - if (k == num_in) { - // We reach the end, which means the convex hull is just one point - q[0] = p[t]; - return 1; - } - q[1] = q[k]; - int m = 2; // 2 points in the stack - // Step 5: - // Finally we can start the scanning process. 
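polygon_area below fans triangles out of q[0] and sums |cross| / 2 per triangle, which is exact for the convex, angularly ordered points that convex_hull_graham produces. Worked example on the unit square: two triangles of area 1/2 each, total 1. A standalone check:

#include <cmath>
#include <cstdio>

struct P { float x, y; };
static float cross_2d(const P& A, const P& B) { return A.x * B.y - B.x * A.y; }

int main()
{
    P q[4] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}};  // convex, ordered
    float area = 0;
    for (int i = 1; i < 3; ++i)
    {
        P u{q[i].x - q[0].x, q[i].y - q[0].y};
        P v{q[i + 1].x - q[0].x, q[i + 1].y - q[0].y};
        area += std::fabs(cross_2d(u, v));
    }
    std::printf("%g\n", area / 2);  // 1: the unit square
}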
- // When a non-convex relationship between the 3 points is found - // (either concave shape or duplicated points), - // we pop the previous point from the stack - // until the 3-point relationship is convex again, or - // until the stack only contains two points - for (int i = k + 1; i < num_in; i++) { - while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) { - m--; - } - q[m++] = q[i]; - } - - // Step 6 (Optional): - // In general sense we need the original coordinates, so we - // need to shift the points back (reverting Step 2) - // But if we're only interested in getting the area/perimeter of the shape - // We can simply return. - if (!shift_to_zero) { - for (int i = 0; i < m; i++) { - q[i] += start; - } - } - - return m; -} - -float polygon_area(const Point (&q)[24], const int& m) { - if (m <= 2) { - return 0; - } - - float area = 0; - for (int i = 1; i < m - 1; i++) { - area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); - } - - return area / 2.0; -} - -float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2) { - // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned - // from rotated_rect_intersection_pts - Point intersectPts[24], orderedPts[24]; - - Point pts1[4]; - Point pts2[4]; - get_rotated_vertices(box1, pts1); - get_rotated_vertices(box2, pts2); - - int num = get_intersection_points(pts1, pts2, intersectPts); - - if (num <= 2) { - return 0.0; - } - - // Convex Hull to order the intersection points in clockwise order and find - // the contour area. - int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); - return polygon_area(orderedPts, num_convex); -} - -NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info) - : ort_(api), info_(info) { - iou_threshold_ = ort_.KernelInfoGetAttribute(info, "iou_threshold"); - score_threshold_ = ort_.KernelInfoGetAttribute(info, "score_threshold"); - - // create allocator - allocator_ = Ort::AllocatorWithDefaultOptions(); -} - -void NMSRotatedKernel::Compute(OrtKernelContext* context) { - const float iou_threshold = iou_threshold_; - const float score_threshold = score_threshold_; - - const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); - const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); - const OrtValue* scores = ort_.KernelContext_GetInput(context, 1); - const float* scores_data = reinterpret_cast(ort_.GetTensorData(scores)); - - OrtTensorDimensions boxes_dim(ort_, boxes); - OrtTensorDimensions scores_dim(ort_, scores); - - // loop over batch - int64_t nbatch = boxes_dim[0]; - int64_t nboxes = boxes_dim[1]; - int64_t nclass = scores_dim[1]; - assert(boxes_dim[2] == 5); //(cx,cy,w,h,theta) - - // allocate tmp memory - float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nbatch * nboxes * 5); - float* sc = (float*)allocator_.Alloc(sizeof(float) * nbatch * nclass * nboxes); - bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes); - - memcpy(tmp_boxes, boxes_data, sizeof(float) * nbatch * nboxes * 5); - memcpy(sc, scores_data, sizeof(float) * nbatch * nclass * nboxes); - - // std::vector> res_order; - std::vector res_order; - for (int64_t k = 0; k < nbatch; k++) { - for (int64_t g = 0; g < nclass; g++) { - for (int64_t i = 0; i < nboxes; i++) { - select[i] = true; - } - // sort scores - std::vector tmp_sc; - for (int i = 0; i < nboxes; i++) { - tmp_sc.push_back(sc[k * nboxes * nclass + g * nboxes + i]); - } - std::vector order(tmp_sc.size()); - std::iota(order.begin(), order.end(), 
0); - std::sort(order.begin(), order.end(), - [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; }); - for (int64_t _i = 0; _i < nboxes; _i++) { - if (select[_i] == false) continue; - auto i = order[_i]; - for (int64_t _j = _i + 1; _j < nboxes; _j++) { - if (select[_j] == false) continue; - auto j = order[_j]; - RotatedBox box1, box2; - auto center_shift_x = - (tmp_boxes[k * nboxes * 5 + i * 5] + tmp_boxes[k * nboxes * 5 + j * 5]) / 2.0; - auto center_shift_y = - (tmp_boxes[k * nboxes * 5 + i * 5 + 1] + tmp_boxes[k * nboxes * 5 + j * 5 + 1]) / 2.0; - box1.x_ctr = tmp_boxes[k * nboxes * 5 + i * 5] - center_shift_x; - box1.y_ctr = tmp_boxes[k * nboxes * 5 + i * 5 + 1] - center_shift_y; - box1.w = tmp_boxes[k * nboxes * 5 + i * 5 + 2]; - box1.h = tmp_boxes[k * nboxes * 5 + i * 5 + 3]; - box1.a = tmp_boxes[k * nboxes * 5 + i * 5 + 4]; - box2.x_ctr = tmp_boxes[k * nboxes * 5 + j * 5] - center_shift_x; - box2.y_ctr = tmp_boxes[k * nboxes * 5 + j * 5 + 1] - center_shift_y; - box2.w = tmp_boxes[k * nboxes * 5 + j * 5 + 2]; - box2.h = tmp_boxes[k * nboxes * 5 + j * 5 + 3]; - box2.a = tmp_boxes[k * nboxes * 5 + j * 5 + 4]; - auto area1 = box1.w * box1.h; - auto area2 = box2.w * box2.h; - auto intersection = rotated_boxes_intersection(box1, box2); - float baseS = 1.0; - baseS = (area1 + area2 - intersection); - auto ovr = intersection / baseS; - if (ovr > iou_threshold) select[_j] = false; + + float polygon_area(const Point (&q)[24], const int& m) + { + if (m <= 2) + { + return 0; } - } - for (int i = 0; i < nboxes; i++) { - if (select[i] & (tmp_sc[order[i]] > score_threshold)) { - res_order.push_back(k); - res_order.push_back(g); - res_order.push_back(order[i]); + + float area = 0; + for (int i = 1; i < m - 1; i++) + { + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); } - } - } // class loop - } // batch loop - std::vector inds_dims({(int64_t)res_order.size() / 3, 3}); + return area / 2.0; + } + + float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2) + { + // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned + // from rotated_rect_intersection_pts + Point intersectPts[24], orderedPts[24]; - OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); - int64_t* res_data = ort_.GetTensorMutableData(res); + Point pts1[4]; + Point pts2[4]; + get_rotated_vertices(box1, pts1); + get_rotated_vertices(box2, pts2); - memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size()); + int num = get_intersection_points(pts1, pts2, intersectPts); - allocator_.Free(tmp_boxes); - allocator_.Free(sc); - allocator_.Free(select); -} + if (num <= 2) + { + return 0.0; + } + + // Convex Hull to order the intersection points in clockwise order and find + // the contour area. 
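
polygon_area above is the triangle-fan form of the shoelace formula: with the hull ordered around q[0], the triangles (q[0], q[i], q[i+1]) tile the convex polygon, and each |cross_2d| is twice one triangle's area. A quick standalone check on a unit square (the array contents are illustrative) before the hull call that follows:

    #include <cmath>
    #include <cstdio>

    struct Pt { float x, y; };
    static float cross2(Pt a, Pt b) { return a.x * b.y - a.y * b.x; }

    int main() {
        // Hull vertices in order, as convex_hull_graham(..., true) leaves them
        // (shifted so the pivot q[0] sits at the origin).
        Pt q[4] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}};
        float area = 0;
        for (int i = 1; i < 4 - 1; i++) {
            Pt a{q[i].x - q[0].x, q[i].y - q[0].y};
            Pt b{q[i + 1].x - q[0].x, q[i + 1].y - q[0].y};
            area += std::fabs(cross2(a, b));  // twice the triangle (q0, qi, qi+1)
        }
        std::printf("%g\n", area / 2.0f);  // prints 1
        return 0;
    }
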
+        int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true);
+        return polygon_area(orderedPts, num_convex);
+    }
+
+    NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info)
+        : ort_(api)
+        , info_(info)
+    {
+        iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
+        score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
+
+        // create allocator
+        allocator_ = Ort::AllocatorWithDefaultOptions();
+    }
+
+    void NMSRotatedKernel::Compute(OrtKernelContext* context)
+    {
+        const float iou_threshold = iou_threshold_;
+        const float score_threshold = score_threshold_;
+
+        const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0);
+        const float* boxes_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(boxes));
+        const OrtValue* scores = ort_.KernelContext_GetInput(context, 1);
+        const float* scores_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(scores));
+
+        OrtTensorDimensions boxes_dim(ort_, boxes);
+        OrtTensorDimensions scores_dim(ort_, scores);
+
+        // loop over batch
+        int64_t nbatch = boxes_dim[0];
+        int64_t nboxes = boxes_dim[1];
+        int64_t nclass = scores_dim[1];
+        assert(boxes_dim[2] == 5);  //(cx,cy,w,h,theta)
+
+        // allocate tmp memory
+        float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nbatch * nboxes * 5);
+        float* sc = (float*)allocator_.Alloc(sizeof(float) * nbatch * nclass * nboxes);
+        bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes);
+
+        memcpy(tmp_boxes, boxes_data, sizeof(float) * nbatch * nboxes * 5);
+        memcpy(sc, scores_data, sizeof(float) * nbatch * nclass * nboxes);
+
+        // std::vector> res_order;
+        std::vector<int64_t> res_order;
+        for (int64_t k = 0; k < nbatch; k++)
+        {
+            for (int64_t g = 0; g < nclass; g++)
+            {
+                for (int64_t i = 0; i < nboxes; i++)
+                {
+                    select[i] = true;
+                }
+                // sort scores
+                std::vector<float> tmp_sc;
+                for (int i = 0; i < nboxes; i++)
+                {
+                    tmp_sc.push_back(sc[k * nboxes * nclass + g * nboxes + i]);
+                }
+                std::vector<int64_t> order(tmp_sc.size());
+                std::iota(order.begin(), order.end(), 0);
+                std::sort(order.begin(), order.end(), [&tmp_sc](int64_t id1, int64_t id2)
+                          { return tmp_sc[id1] > tmp_sc[id2]; });
+                for (int64_t _i = 0; _i < nboxes; _i++)
+                {
+                    if (select[_i] == false) continue;
+                    auto i = order[_i];
+                    for (int64_t _j = _i + 1; _j < nboxes; _j++)
+                    {
+                        if (select[_j] == false) continue;
+                        auto j = order[_j];
+                        RotatedBox box1, box2;
+                        auto center_shift_x =
+                            (tmp_boxes[k * nboxes * 5 + i * 5] + tmp_boxes[k * nboxes * 5 + j * 5]) / 2.0;
+                        auto center_shift_y =
+                            (tmp_boxes[k * nboxes * 5 + i * 5 + 1] + tmp_boxes[k * nboxes * 5 + j * 5 + 1]) / 2.0;
+                        box1.x_ctr = tmp_boxes[k * nboxes * 5 + i * 5] - center_shift_x;
+                        box1.y_ctr = tmp_boxes[k * nboxes * 5 + i * 5 + 1] - center_shift_y;
+                        box1.w = tmp_boxes[k * nboxes * 5 + i * 5 + 2];
+                        box1.h = tmp_boxes[k * nboxes * 5 + i * 5 + 3];
+                        box1.a = tmp_boxes[k * nboxes * 5 + i * 5 + 4];
+                        box2.x_ctr = tmp_boxes[k * nboxes * 5 + j * 5] - center_shift_x;
+                        box2.y_ctr = tmp_boxes[k * nboxes * 5 + j * 5 + 1] - center_shift_y;
+                        box2.w = tmp_boxes[k * nboxes * 5 + j * 5 + 2];
+                        box2.h = tmp_boxes[k * nboxes * 5 + j * 5 + 3];
+                        box2.a = tmp_boxes[k * nboxes * 5 + j * 5 + 4];
+                        auto area1 = box1.w * box1.h;
+                        auto area2 = box2.w * box2.h;
+                        auto intersection = rotated_boxes_intersection(box1, box2);
+                        float baseS = 1.0;
+                        baseS = (area1 + area2 - intersection);
+                        auto ovr = intersection / baseS;
+                        if (ovr > iou_threshold) select[_j] = false;
+                    }
+                }
+                for (int i = 0; i < nboxes; i++)
+                {
+                    if (select[i] & (tmp_sc[order[i]] > score_threshold))
+                    {
+                        res_order.push_back(k);
+                        res_order.push_back(g);
+                        res_order.push_back(order[i]);
+                    }
+                }
+            }  // class loop
+        }      // batch loop
+
+        std::vector<int64_t> inds_dims({(int64_t)res_order.size() / 3, 3});
+
+        OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size());
+        int64_t* res_data = ort_.GetTensorMutableData<int64_t>(res);
+
+        memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size());
+
+        allocator_.Free(tmp_boxes);
+        allocator_.Free(sc);
+        allocator_.Free(select);
+    }

-REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSRotatedOp);
+    REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSRotatedOp);
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h b/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h
index 6ed44ce410..3b4aa856a5 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h
@@ -10,39 +10,57 @@
 #include
 #include
-namespace mmdeploy {
-struct NMSRotatedKernel {
-  NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info);
-
-  void Compute(OrtKernelContext* context);
-
- private:
-  Ort::CustomOpApi ort_;
-  const OrtKernelInfo* info_;
-  Ort::AllocatorWithDefaultOptions allocator_;
-  float iou_threshold_;
-  float score_threshold_;
-};
-
-struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-    return new NMSRotatedKernel(api, info);
-  }
-  const char* GetName() const { return "NMSRotated"; }
-
-  size_t GetInputTypeCount() const { return 2; }
-  ONNXTensorElementDataType GetInputType(size_t) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  }
-
-  size_t GetOutputTypeCount() const { return 1; }
-  ONNXTensorElementDataType GetOutputType(size_t) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  }
-
-  // force cpu
-  const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
-};
+namespace mmdeploy
+{
+    struct NMSRotatedKernel
+    {
+        NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info);
+
+        void Compute(OrtKernelContext* context);
+
+      private:
+        Ort::CustomOpApi ort_;
+        const OrtKernelInfo* info_;
+        Ort::AllocatorWithDefaultOptions allocator_;
+        float iou_threshold_;
+        float score_threshold_;
+    };
+
+    struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel>
+    {
+        void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const
+        {
+            return new NMSRotatedKernel(api, info);
+        }
+        const char* GetName() const
+        {
+            return "NMSRotated";
+        }
+
+        size_t GetInputTypeCount() const
+        {
+            return 2;
+        }
+        ONNXTensorElementDataType GetInputType(size_t) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+        }
+
+        size_t GetOutputTypeCount() const
+        {
+            return 1;
+        }
+        ONNXTensorElementDataType GetOutputType(size_t) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+        }
+
+        // force cpu
+        const char* GetExecutionProviderType() const
+        {
+            return "CPUExecutionProvider";
+        }
+    };
 }  // namespace mmdeploy
 #endif  // ONNXRUNTIME_NMS_ROTATED_H
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp
index f7b9cedff8..1159496843 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp
@@ -3,25 +3,30 @@
 #include "ort_utils.h"
-const char *c_MMDeployOpDomain = "mmdeploy";
+const char* c_MMDeployOpDomain = "mmdeploy";
-OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api) { - const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION); - OrtStatus *status = nullptr; - for (auto &_op_list_pair : mmdeploy::get_mmdeploy_custom_ops()) { - OrtCustomOpDomain *domain = nullptr; - if (auto status = kOrtApi->CreateCustomOpDomain(_op_list_pair.first.c_str(), &domain)) { - return status; +OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) +{ + const OrtApi* kOrtApi = api->GetApi(ORT_API_VERSION); + OrtStatus* status = nullptr; + for (auto& _op_list_pair : mmdeploy::get_mmdeploy_custom_ops()) + { + OrtCustomOpDomain* domain = nullptr; + if (auto status = kOrtApi->CreateCustomOpDomain(_op_list_pair.first.c_str(), &domain)) + { + return status; + } + auto& _op_list = _op_list_pair.second; + for (auto& _op : _op_list) + { + if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) + { + return status; + } + } + // TODO: figure out what will return if failed. + status = kOrtApi->AddCustomOpDomain(options, domain); } - auto &_op_list = _op_list_pair.second; - for (auto &_op : _op_list) { - if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) { - return status; - } - } - // TODO: figure out what will return if failed. - status = kOrtApi->AddCustomOpDomain(options, domain); - } - return status; + return status; } diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp index a8e7023fe1..4fbf6365d0 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp @@ -5,233 +5,245 @@ #include "ort_utils.h" -namespace mmdeploy { -// implementation taken from Caffe2 -struct PreCalc { - int pos1; - int pos2; - int pos3; - int pos4; - float w1; - float w2; - float w3; - float w4; -}; - -void pre_calc_for_bilinear_interpolate(const int height, const int width, const int pooled_height, - const int pooled_width, const int iy_upper, - const int ix_upper, float roi_start_h, float roi_start_w, - float bin_size_h, float bin_size_w, int roi_bin_grid_h, - int roi_bin_grid_w, float roi_center_h, float roi_center_w, - float cos_theta, float sin_theta, - std::vector &pre_calc) { - int pre_calc_index = 0; - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - for (int iy = 0; iy < iy_upper; iy++) { - const float yy = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < ix_upper; ix++) { - const float xx = - roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - - // Rotate by theta around the center and translate - // In image space, (y, x) is the order for Right Handed System, - // and this is essentially multiplying the point by a rotation matrix - // to rotate it counterclockwise through angle theta. 
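
The rotation in the loop above is an ordinary 2D rotation matrix applied in (y, x) order and followed by a translation to the RoI center; samples that land outside the feature map are given all-zero weights rather than being clamped away silently. A standalone sanity check with theta near pi/2 (all constants made up for the example):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float theta = 1.5707964f;                // ~pi/2
        float yy = 1.f, xx = 0.f;                      // bin-local sample offset
        float roi_center_h = 5.f, roi_center_w = 5.f;  // RoI center, feature-map scale
        float cos_theta = std::cos(theta), sin_theta = std::sin(theta);
        float y = yy * cos_theta - xx * sin_theta + roi_center_h;
        float x = yy * sin_theta + xx * cos_theta + roi_center_w;
        std::printf("(%g, %g)\n", y, x);  // ~(5, 6): the offset (1, 0) rotates onto (0, 1)
        return 0;
    }
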
- float y = yy * cos_theta - xx * sin_theta + roi_center_h; - float x = yy * sin_theta + xx * cos_theta + roi_center_w; - // deal with: inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - PreCalc pc; - pc.pos1 = 0; - pc.pos2 = 0; - pc.pos3 = 0; - pc.pos4 = 0; - pc.w1 = 0; - pc.w2 = 0; - pc.w3 = 0; - pc.w4 = 0; - pre_calc[pre_calc_index] = pc; - pre_calc_index += 1; - continue; - } - - if (y < 0) { - y = 0; - } - if (x < 0) { - x = 0; - } - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (float)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (float)x_low; - } else { - x_high = x_low + 1; - } - - float ly = y - y_low; - float lx = x - x_low; - float hy = 1. - ly, hx = 1. - lx; - float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - // save weights and indices - PreCalc pc; - pc.pos1 = y_low * width + x_low; - pc.pos2 = y_low * width + x_high; - pc.pos3 = y_high * width + x_low; - pc.pos4 = y_high * width + x_high; - pc.w1 = w1; - pc.w2 = w2; - pc.w3 = w3; - pc.w4 = w4; - pre_calc[pre_calc_index] = pc; - - pre_calc_index += 1; +namespace mmdeploy +{ + // implementation taken from Caffe2 + struct PreCalc + { + int pos1; + int pos2; + int pos3; + int pos4; + float w1; + float w2; + float w3; + float w4; + }; + + void pre_calc_for_bilinear_interpolate(const int height, const int width, const int pooled_height, const int pooled_width, const int iy_upper, const int ix_upper, float roi_start_h, float roi_start_w, float bin_size_h, float bin_size_w, int roi_bin_grid_h, int roi_bin_grid_w, float roi_center_h, float roi_center_w, float cos_theta, float sin_theta, std::vector& pre_calc) + { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) + { + for (int pw = 0; pw < pooled_width; pw++) + { + for (int iy = 0; iy < iy_upper; iy++) + { + const float yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) + { + const float xx = + roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + // In image space, (y, x) is the order for Right Handed System, + // and this is essentially multiplying the point by a rotation matrix + // to rotate it counterclockwise through angle theta. + float y = yy * cos_theta - xx * sin_theta + roi_center_h; + float x = yy * sin_theta + xx * cos_theta + roi_center_w; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) + { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y < 0) + { + y = 0; + } + if (x < 0) + { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) + { + y_high = y_low = height - 1; + y = (float)y_low; + } + else + { + y_high = y_low + 1; + } + + if (x_low >= width - 1) + { + x_high = x_low = width - 1; + x = (float)x_low; + } + else + { + x_high = x_low + 1; + } + + float ly = y - y_low; + float lx = x - x_low; + float hy = 1. - ly, hx = 1. 
- lx; + float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } } - } - } - } -} - -void ROIAlignRotatedForwardCPU(const int nthreads, const float *input, const float *rois, - float *output, const float &spatial_scale, const int aligned, - const int clockwise, const int channels, const int height, - const int width, const int pooled_height, const int pooled_width, - const int sampling_ratio) { - int n_rois = nthreads / channels / pooled_width / pooled_height; - // (n, c, ph, pw) is an element in the pooled output - // can be parallelized using omp - // #pragma omp parallel for num_threads(32) - for (int n = 0; n < n_rois; n++) { - int index_n = n * channels * pooled_width * pooled_height; - - const float *current_roi = rois + n * 6; - int roi_batch_ind = current_roi[0]; - - // Do not use rounding; this implementation detail is critical - float offset = aligned ? (float)0.5 : (float)0.0; - float roi_center_w = current_roi[1] * spatial_scale - offset; - float roi_center_h = current_roi[2] * spatial_scale - offset; - float roi_width = current_roi[3] * spatial_scale; - float roi_height = current_roi[4] * spatial_scale; - // float theta = current_roi[5] * M_PI / 180.0; - float theta = current_roi[5]; // Radian angle by default - if (clockwise) { - theta = -theta; } - float cos_theta = cos(theta); - float sin_theta = sin(theta); - if (!aligned) { // for backward-compatibility only - roi_width = std::max(roi_width, (float)1.); - roi_height = std::max(roi_height, (float)1.); - } - - float bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - float bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - const float count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 - - // we want to precalculate indices and weights shared by all channels, - // this is the key point of optimization - std::vector pre_calc(roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); - - // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). - // Appropriate translation needs to be applied after. 
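
Each PreCalc entry above stores the four corner indices of the cell containing a sample plus the standard bilinear coefficients; because the weights are products of complementary fractions they always sum to 1, so the interpolation is a convex combination of the corner values. A minimal check, independent of the code above:

    #include <cstdio>

    int main() {
        float y = 2.25f, x = 3.75f;            // sample position inside the map
        int y_low = (int)y, x_low = (int)x;    // containing cell corner (2, 3)
        float ly = y - y_low, lx = x - x_low;  // 0.25, 0.75
        float hy = 1.f - ly, hx = 1.f - lx;    // 0.75, 0.25
        float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
        std::printf("%g %g %g %g sum=%g\n", w1, w2, w3, w4, w1 + w2 + w3 + w4);
        // 0.1875 0.5625 0.0625 0.1875 sum=1
        return 0;
    }
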
- float roi_start_h = -roi_height / 2.0; - float roi_start_w = -roi_width / 2.0; - pre_calc_for_bilinear_interpolate(height, width, pooled_height, pooled_width, roi_bin_grid_h, - roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, - bin_size_w, roi_bin_grid_h, roi_bin_grid_w, roi_center_h, - roi_center_w, cos_theta, sin_theta, pre_calc); - - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * pooled_width * pooled_height; - const float *offset_input = input + (roi_batch_ind * channels + c) * height * width; - int pre_calc_index = 0; + void ROIAlignRotatedForwardCPU(const int nthreads, const float* input, const float* rois, float* output, const float& spatial_scale, const int aligned, const int clockwise, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio) + { + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) + { + int index_n = n * channels * pooled_width * pooled_height; + + const float* current_roi = rois + n * 6; + int roi_batch_ind = current_roi[0]; + + // Do not use rounding; this implementation detail is critical + float offset = aligned ? (float)0.5 : (float)0.0; + float roi_center_w = current_roi[1] * spatial_scale - offset; + float roi_center_h = current_roi[2] * spatial_scale - offset; + float roi_width = current_roi[3] * spatial_scale; + float roi_height = current_roi[4] * spatial_scale; + // float theta = current_roi[5] * M_PI / 180.0; + float theta = current_roi[5]; // Radian angle by default + if (clockwise) + { + theta = -theta; + } + float cos_theta = cos(theta); + float sin_theta = sin(theta); + if (!aligned) + { // for backward-compatibility only + roi_width = std::max(roi_width, (float)1.); + roi_height = std::max(roi_height, (float)1.); + } - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - int index = index_n_c + ph * pooled_width + pw; + float bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + float bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const float count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector pre_calc(roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. 
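
When sampling_ratio is not positive, ROIAlignRotatedForwardCPU above sizes the sampling grid adaptively, taking ceil(roi_extent / pooled_extent) samples per axis, and later divides each bin's accumulated value by the fixed per-bin count. A small worked example (the RoI dimensions are made up):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        float roi_height = 20.f, roi_width = 13.f;
        int pooled_height = 7, pooled_width = 7;
        int sampling_ratio = 0;  // <= 0 selects the adaptive grid
        int grid_h = (sampling_ratio > 0) ? sampling_ratio
                                          : (int)std::ceil(roi_height / pooled_height);
        int grid_w = (sampling_ratio > 0) ? sampling_ratio
                                          : (int)std::ceil(roi_width / pooled_width);
        int count = std::max(grid_h * grid_w, 1);  // samples averaged per output bin
        std::printf("%d x %d = %d samples per bin\n", grid_h, grid_w, count);  // 3 x 2 = 6
        return 0;
    }
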
+ float roi_start_h = -roi_height / 2.0; + float roi_start_w = -roi_width / 2.0; + + pre_calc_for_bilinear_interpolate(height, width, pooled_height, pooled_width, roi_bin_grid_h, roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w, roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta, sin_theta, pre_calc); + + for (int c = 0; c < channels; c++) + { + int index_n_c = index_n + c * pooled_width * pooled_height; + const float* offset_input = input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) + { + for (int pw = 0; pw < pooled_width; pw++) + { + int index = index_n_c + ph * pooled_width + pw; + + float output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) + { + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_input[pc.pos1] + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + output[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n + } - float output_val = 0.; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - PreCalc pc = pre_calc[pre_calc_index]; - output_val += pc.w1 * offset_input[pc.pos1] + pc.w2 * offset_input[pc.pos2] + - pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; + void MMCVRoIAlignRotatedKernel::Compute(OrtKernelContext* context) + { + // Setup inputs + const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); + const float* X_data = reinterpret_cast(ort_.GetTensorData(input_X)); + const OrtValue* input_rois = ort_.KernelContext_GetInput(context, 1); + const float* rois = + reinterpret_cast(ort_.GetTensorData(input_rois)); + + // Setup output + OrtTensorDimensions out_dimensions(ort_, input_X); + OrtTensorDimensions roi_dimensions(ort_, input_rois); + + int batch_size = out_dimensions.data()[0]; + int input_channels = out_dimensions.data()[1]; + int input_height = out_dimensions.data()[2]; + int input_width = out_dimensions.data()[3]; + + out_dimensions.data()[0] = roi_dimensions.data()[0]; + out_dimensions.data()[2] = aligned_height_; + out_dimensions.data()[3] = aligned_width_; + + OrtValue* output = + ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); + float* out = ort_.GetTensorMutableData(output); + OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output); + ort_.ReleaseTensorTypeAndShapeInfo(output_info); + + // TODO: forward here + int output_size = out_dimensions.data()[0]; + for (auto i = 1; i < out_dimensions.size(); ++i) + { + output_size *= out_dimensions.data()[i]; + } + ROIAlignRotatedForwardCPU(output_size, X_data, rois, out, spatial_scale_, aligned_, clockwise_, input_channels, input_height, input_width, aligned_height_, aligned_width_, sampling_ratio_); + } - pre_calc_index += 1; - } - } - output_val /= count; - - output[index] = output_val; - } // for pw - } // for ph - } // for c - } // for n -} - -void MMCVRoIAlignRotatedKernel::Compute(OrtKernelContext *context) { - // Setup inputs - const OrtValue *input_X = ort_.KernelContext_GetInput(context, 0); - const float *X_data = reinterpret_cast(ort_.GetTensorData(input_X)); - const OrtValue *input_rois = ort_.KernelContext_GetInput(context, 1); - const float *rois = - reinterpret_cast(ort_.GetTensorData(input_rois)); - - // Setup output - OrtTensorDimensions 
out_dimensions(ort_, input_X); - OrtTensorDimensions roi_dimensions(ort_, input_rois); - - int batch_size = out_dimensions.data()[0]; - int input_channels = out_dimensions.data()[1]; - int input_height = out_dimensions.data()[2]; - int input_width = out_dimensions.data()[3]; - - out_dimensions.data()[0] = roi_dimensions.data()[0]; - out_dimensions.data()[2] = aligned_height_; - out_dimensions.data()[3] = aligned_width_; - - OrtValue *output = - ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); - float *out = ort_.GetTensorMutableData(output); - OrtTensorTypeAndShapeInfo *output_info = ort_.GetTensorTypeAndShape(output); - ort_.ReleaseTensorTypeAndShapeInfo(output_info); - - // TODO: forward here - int output_size = out_dimensions.data()[0]; - for (auto i = 1; i < out_dimensions.size(); ++i) { - output_size *= out_dimensions.data()[i]; - } - ROIAlignRotatedForwardCPU(output_size, X_data, rois, out, spatial_scale_, aligned_, clockwise_, - input_channels, input_height, input_width, aligned_height_, - aligned_width_, sampling_ratio_); -} - -REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoIAlignRotatedCustomOp); + REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoIAlignRotatedCustomOp); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h index c0129d31f8..24a90e5321 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h @@ -10,50 +10,70 @@ #include #include -namespace mmdeploy { -struct MMCVRoIAlignRotatedKernel { - public: - MMCVRoIAlignRotatedKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) : ort_(ort) { - aligned_height_ = ort_.KernelInfoGetAttribute(info, "output_height"); - aligned_width_ = ort_.KernelInfoGetAttribute(info, "output_width"); - sampling_ratio_ = ort_.KernelInfoGetAttribute(info, "sampling_ratio"); - spatial_scale_ = ort_.KernelInfoGetAttribute(info, "spatial_scale"); - aligned_ = ort_.KernelInfoGetAttribute(info, "aligned"); - clockwise_ = ort_.KernelInfoGetAttribute(info, "clockwise"); - } - - void Compute(OrtKernelContext* context); - - private: - Ort::CustomOpApi ort_; - int aligned_height_; - int aligned_width_; - float spatial_scale_; - int sampling_ratio_; - int aligned_; - int clockwise_; -}; - -struct MMCVRoIAlignRotatedCustomOp - : Ort::CustomOpBase { - void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { - return new MMCVRoIAlignRotatedKernel(api, info); - } - const char* GetName() const { return "MMCVRoIAlignRotated"; } - - size_t GetInputTypeCount() const { return 2; } - ONNXTensorElementDataType GetInputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - - size_t GetOutputTypeCount() const { return 1; } - ONNXTensorElementDataType GetOutputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - - // force cpu - const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } -}; +namespace mmdeploy +{ + struct MMCVRoIAlignRotatedKernel + { + public: + MMCVRoIAlignRotatedKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) + : ort_(ort) + { + aligned_height_ = ort_.KernelInfoGetAttribute(info, "output_height"); + aligned_width_ = ort_.KernelInfoGetAttribute(info, "output_width"); + sampling_ratio_ = ort_.KernelInfoGetAttribute(info, "sampling_ratio"); + spatial_scale_ = ort_.KernelInfoGetAttribute(info, 
"spatial_scale"); + aligned_ = ort_.KernelInfoGetAttribute(info, "aligned"); + clockwise_ = ort_.KernelInfoGetAttribute(info, "clockwise"); + } + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + int aligned_height_; + int aligned_width_; + float spatial_scale_; + int sampling_ratio_; + int aligned_; + int clockwise_; + }; + + struct MMCVRoIAlignRotatedCustomOp + : Ort::CustomOpBase + { + void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const + { + return new MMCVRoIAlignRotatedKernel(api, info); + } + const char* GetName() const + { + return "MMCVRoIAlignRotated"; + } + + size_t GetInputTypeCount() const + { + return 2; + } + ONNXTensorElementDataType GetInputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const + { + return 1; + } + ONNXTensorElementDataType GetOutputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + // force cpu + const char* GetExecutionProviderType() const + { + return "CPUExecutionProvider"; + } + }; } // namespace mmdeploy #endif // ONNXRUNTIME_ROI_ALIGN_ROTATED_H diff --git a/csrc/mmdeploy/backend_ops/tensorrt/CMakeLists.txt b/csrc/mmdeploy/backend_ops/tensorrt/CMakeLists.txt index a221311acd..d43e8c4a1b 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/tensorrt/CMakeLists.txt @@ -4,28 +4,28 @@ project(mmdeploy_tensorrt_ops) include(${CMAKE_SOURCE_DIR}/cmake/tensorrt.cmake) # cub -if (NOT DEFINED CUB_ROOT_DIR) - if (CUDA_VERSION VERSION_LESS 11.0) - set(CUB_ROOT_DIR "${CMAKE_SOURCE_DIR}/third_party/cub") - endif () -endif () +if(NOT DEFINED CUB_ROOT_DIR) + if(CUDA_VERSION VERSION_LESS 11.0) + set(CUB_ROOT_DIR "${CMAKE_SOURCE_DIR}/third_party/cub") + endif() +endif() file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp *.cu) add_library(${PROJECT_NAME}_obj OBJECT "${BACKEND_OPS_SRCS}") -set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE 1) +set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE + 1) target_compile_definitions(${PROJECT_NAME}_obj - PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1) + PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1) target_include_directories(${PROJECT_NAME}_obj - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common) target_include_directories(${PROJECT_NAME}_obj - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common) target_include_directories(${PROJECT_NAME}_obj - PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include) + PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include) target_include_directories(${PROJECT_NAME}_obj PRIVATE ${TENSORRT_INCLUDE_DIR}) target_include_directories(${PROJECT_NAME}_obj PRIVATE ${CUDNN_DIR}/include) target_include_directories(${PROJECT_NAME}_obj PRIVATE ${CUB_ROOT_DIR}) -target_link_libraries(${PROJECT_NAME}_obj - PUBLIC ${TENSORRT_LIBS} cublas cudnn) +target_link_libraries(${PROJECT_NAME}_obj PUBLIC ${TENSORRT_LIBS} cublas cudnn) mmdeploy_export(${PROJECT_NAME}_obj) # Build module library. 
It is used to convert onnx model to tensorrt engine diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp index 431f2dd63b..3bb08a5e22 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp @@ -9,225 +9,314 @@ #include "nms/kernel.h" #include "trt_serialize.hpp" -namespace mmdeploy { -using namespace nvinfer1; -using nvinfer1::plugin::NMSParameters; - -namespace { -static const char* NMS_PLUGIN_VERSION{"1"}; -static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"}; -} // namespace - -TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params, bool returnIndex) - : TRTPluginBase(name), param(params), mReturnIndex(returnIndex) {} - -TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, ¶m); - deserialize_value(&data, &length, &mClipBoxes); - deserialize_value(&data, &length, &mReturnIndex); -} - -int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT { - int num = mReturnIndex ? 3 : 2; - return num; -} - -nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { - ASSERT(nbInputs == 2); - ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs()); - ASSERT(inputs[0].nbDims == 4); - ASSERT(inputs[1].nbDims == 3); - - nvinfer1::DimsExprs ret; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = exprBuilder.constant(param.keepTopK); - switch (outputIndex) { - case 0: - ret.nbDims = 3; - ret.d[2] = exprBuilder.constant(5); - break; - case 1: - ret.nbDims = 2; - break; - case 2: - ret.nbDims = 2; - default: - break; - } - - return ret; -} - -size_t TRTBatchedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT { - size_t batch_size = inputs[0].dims.d[0]; - size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3]; - size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2]; - size_t num_priors = inputs[0].dims.d[1]; - bool shareLocation = (inputs[0].dims.d[2] == 1); - int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? param.topK : inputs[1].dims.d[1]; - return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size, - param.numClasses, num_priors, topk, DataType::kFLOAT, - DataType::kFLOAT); -} - -int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, - void* const* outputs, void* workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - const void* const locData = inputs[0]; - const void* const confData = inputs[1]; - - void* nmsedDets = outputs[0]; - void* nmsedLabels = outputs[1]; - void* nmsedIndex = mReturnIndex ? outputs[2] : nullptr; - - size_t batch_size = inputDesc[0].dims.d[0]; - size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3]; - size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2]; - size_t num_priors = inputDesc[0].dims.d[1]; - bool shareLocation = (inputDesc[0].dims.d[2] == 1); - - int topk = - param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? 
param.topK : inputDesc[1].dims.d[1]; - bool rotated = false; - pluginStatus_t status = nmsInference( - stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId, - num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold, - DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nmsedIndex, - workSpace, param.isNormalized, false, mClipBoxes, rotated); - ASSERT(status == STATUS_SUCCESS); - - return 0; -} - -size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT { - // NMSParameters - return sizeof(NMSParameters) + sizeof(mClipBoxes) + sizeof(mReturnIndex); -} - -void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, param); - serialize_value(&buffer, mClipBoxes); - serialize_value(&buffer, mReturnIndex); -} - -void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments -} - -bool TRTBatchedNMS::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 3 || pos == 4) { - return ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - } - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT { return NMS_PLUGIN_NAME; } - -const char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT { return NMS_PLUGIN_VERSION; } - -IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT { - auto* plugin = new TRTBatchedNMS(mLayerName, param, mReturnIndex); - plugin->setPluginNamespace(mNamespace.c_str()); - plugin->setClipParam(mClipBoxes); - return plugin; -} - -nvinfer1::DataType TRTBatchedNMS::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT { - ASSERT(index >= 0 && index < this->getNbOutputs()); - if (index == 1 || index == 2) { - return nvinfer1::DataType::kINT32; - } - return inputTypes[0]; -} - -void TRTBatchedNMS::setClipParam(bool clip) { mClipBoxes = clip; } - -TRTBatchedNMSCreator::TRTBatchedNMSCreator() { - mPluginAttributes.emplace_back( - PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("topk", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back( - PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); - mPluginAttributes.emplace_back( - PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); - mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("return_index", nullptr, PluginFieldType::kINT32, 1)); - - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char* TRTBatchedNMSCreator::getPluginName() const TRT_NOEXCEPT { return NMS_PLUGIN_NAME; } - -const char* TRTBatchedNMSCreator::getPluginVersion() const TRT_NOEXCEPT { - return NMS_PLUGIN_VERSION; -} - 
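
createPlugin below matches attributes by name while walking the PluginFieldCollection, so a caller only supplies the fields it cares about; anything omitted keeps the defaults set at the top of the function (clip_boxes = true, return_index = false). A hedged sketch of driving the creator through the TensorRT plugin registry; the helper name and attribute values are hypothetical, and it assumes this library is loaded so that REGISTER_TENSORRT_PLUGIN has already run:

    #include <vector>
    #include <NvInfer.h>

    nvinfer1::IPluginV2* make_batched_nms() {
        auto* creator = getPluginRegistry()->getPluginCreator("TRTBatchedNMS", "1", "");
        if (creator == nullptr) return nullptr;  // plugin library not loaded / not registered
        int num_classes = 80, topk = 1000, keep_topk = 100;
        float iou_threshold = 0.5f, score_threshold = 0.05f;
        std::vector<nvinfer1::PluginField> fields{
            {"num_classes", &num_classes, nvinfer1::PluginFieldType::kINT32, 1},
            {"topk", &topk, nvinfer1::PluginFieldType::kINT32, 1},
            {"keep_topk", &keep_topk, nvinfer1::PluginFieldType::kINT32, 1},
            {"iou_threshold", &iou_threshold, nvinfer1::PluginFieldType::kFLOAT32, 1},
            {"score_threshold", &score_threshold, nvinfer1::PluginFieldType::kFLOAT32, 1}};
        nvinfer1::PluginFieldCollection fc{static_cast<int>(fields.size()), fields.data()};
        return creator->createPlugin("batched_nms", &fc);
    }
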
-IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name, - const PluginFieldCollection* fc) TRT_NOEXCEPT { - const PluginField* fields = fc->fields; - bool clipBoxes = true; - bool returnIndex = false; - nvinfer1::plugin::NMSParameters params{}; - - for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; - if (!strcmp(attrName, "background_label_id")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.backgroundLabelId = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "num_classes")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.numClasses = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "topk")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.topK = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "keep_topk")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.keepTopK = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "score_threshold")) { - ASSERT(fields[i].type == PluginFieldType::kFLOAT32); - params.scoreThreshold = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "iou_threshold")) { - ASSERT(fields[i].type == PluginFieldType::kFLOAT32); - params.iouThreshold = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "is_normalized")) { - params.isNormalized = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "clip_boxes")) { - clipBoxes = *(static_cast(fields[i].data)); - } else if (!strcmp(attrName, "return_index")) { - returnIndex = *(static_cast(fields[i].data)); - } - } - - TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params, returnIndex); - plugin->setClipParam(clipBoxes); - plugin->setPluginNamespace(mNamespace.c_str()); - return plugin; -} - -IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(const char* name, const void* serialData, - size_t serialLength) TRT_NOEXCEPT { - // This object will be deleted when the network is destroyed, which will - // call NMS::destroy() - TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength); - plugin->setPluginNamespace(mNamespace.c_str()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSCreator); +namespace mmdeploy +{ + using namespace nvinfer1; + using nvinfer1::plugin::NMSParameters; + + namespace + { + static const char* NMS_PLUGIN_VERSION{"1"}; + static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"}; + } // namespace + + TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params, bool returnIndex) + : TRTPluginBase(name) + , param(params) + , mReturnIndex(returnIndex) + { + } + + TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, ¶m); + deserialize_value(&data, &length, &mClipBoxes); + deserialize_value(&data, &length, &mReturnIndex); + } + + int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT + { + int num = mReturnIndex ? 
3 : 2; + return num; + } + + nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + ASSERT(nbInputs == 2); + ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs()); + ASSERT(inputs[0].nbDims == 4); + ASSERT(inputs[1].nbDims == 3); + + nvinfer1::DimsExprs ret; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = exprBuilder.constant(param.keepTopK); + switch (outputIndex) + { + case 0: + ret.nbDims = 3; + ret.d[2] = exprBuilder.constant(5); + break; + case 1: + ret.nbDims = 2; + break; + case 2: + ret.nbDims = 2; + default: + break; + } + + return ret; + } + + size_t TRTBatchedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + size_t batch_size = inputs[0].dims.d[0]; + size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3]; + size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2]; + size_t num_priors = inputs[0].dims.d[1]; + bool shareLocation = (inputs[0].dims.d[2] == 1); + int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? param.topK : inputs[1].dims.d[1]; + return detectionInferenceWorkspaceSize(shareLocation, + batch_size, + boxes_size, + score_size, + param.numClasses, + num_priors, + topk, + DataType::kFLOAT, + DataType::kFLOAT); + } + + int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + const void* const locData = inputs[0]; + const void* const confData = inputs[1]; + + void* nmsedDets = outputs[0]; + void* nmsedLabels = outputs[1]; + void* nmsedIndex = mReturnIndex ? outputs[2] : nullptr; + + size_t batch_size = inputDesc[0].dims.d[0]; + size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3]; + size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2]; + size_t num_priors = inputDesc[0].dims.d[1]; + bool shareLocation = (inputDesc[0].dims.d[2] == 1); + + int topk = + param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? 
param.topK : inputDesc[1].dims.d[1]; + bool rotated = false; + pluginStatus_t status = nmsInference(stream, + batch_size, + boxes_size, + score_size, + shareLocation, + param.backgroundLabelId, + num_priors, + param.numClasses, + topk, + param.keepTopK, + param.scoreThreshold, + param.iouThreshold, + DataType::kFLOAT, + locData, + DataType::kFLOAT, + confData, + nmsedDets, + nmsedLabels, + nmsedIndex, + workSpace, + param.isNormalized, + false, + mClipBoxes, + rotated); + ASSERT(status == STATUS_SUCCESS); + + return 0; + } + + size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT + { + // NMSParameters + return sizeof(NMSParameters) + sizeof(mClipBoxes) + sizeof(mReturnIndex); + } + + void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, param); + serialize_value(&buffer, mClipBoxes); + serialize_value(&buffer, mReturnIndex); + } + + void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + } + + bool TRTBatchedNMS::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 3 || pos == 4) + { + return ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + + const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT + { + return NMS_PLUGIN_NAME; + } + + const char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT + { + return NMS_PLUGIN_VERSION; + } + + IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT + { + auto* plugin = new TRTBatchedNMS(mLayerName, param, mReturnIndex); + plugin->setPluginNamespace(mNamespace.c_str()); + plugin->setClipParam(mClipBoxes); + return plugin; + } + + nvinfer1::DataType TRTBatchedNMS::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + ASSERT(index >= 0 && index < this->getNbOutputs()); + if (index == 1 || index == 2) + { + return nvinfer1::DataType::kINT32; + } + return inputTypes[0]; + } + + void TRTBatchedNMS::setClipParam(bool clip) + { + mClipBoxes = clip; + } + + TRTBatchedNMSCreator::TRTBatchedNMSCreator() + { + mPluginAttributes.emplace_back( + PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("topk", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back( + PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back( + PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("return_index", nullptr, PluginFieldType::kINT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTBatchedNMSCreator::getPluginName() const TRT_NOEXCEPT + { + return NMS_PLUGIN_NAME; + } + + const char* 
TRTBatchedNMSCreator::getPluginVersion() const TRT_NOEXCEPT + { + return NMS_PLUGIN_VERSION; + } + + IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name, + const PluginFieldCollection* fc) TRT_NOEXCEPT + { + const PluginField* fields = fc->fields; + bool clipBoxes = true; + bool returnIndex = false; + nvinfer1::plugin::NMSParameters params{}; + + for (int i = 0; i < fc->nbFields; ++i) + { + const char* attrName = fields[i].name; + if (!strcmp(attrName, "background_label_id")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.backgroundLabelId = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "num_classes")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.numClasses = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "topk")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.topK = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "keep_topk")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.keepTopK = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "score_threshold")) + { + ASSERT(fields[i].type == PluginFieldType::kFLOAT32); + params.scoreThreshold = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "iou_threshold")) + { + ASSERT(fields[i].type == PluginFieldType::kFLOAT32); + params.iouThreshold = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "is_normalized")) + { + params.isNormalized = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "clip_boxes")) + { + clipBoxes = *(static_cast(fields[i].data)); + } + else if (!strcmp(attrName, "return_index")) + { + returnIndex = *(static_cast(fields[i].data)); + } + } + + TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params, returnIndex); + plugin->setClipParam(clipBoxes); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + + IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + // This object will be deleted when the network is destroyed, which will + // call NMS::destroy() + TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp index d1e5d643db..b1d77a54d0 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp @@ -8,75 +8,94 @@ #include "NvInferPluginUtils.h" #include "trt_plugin_base.hpp" -namespace mmdeploy { +namespace mmdeploy +{ -enum NMSReturnType { RETURN_DETS = 1, RETURN_INDEX = 1 << 1 }; + enum NMSReturnType + { + RETURN_DETS = 1, + RETURN_INDEX = 1 << 1 + }; -class TRTBatchedNMS : public TRTPluginBase { - public: - TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param, bool returnIndex); + class TRTBatchedNMS : public TRTPluginBase + { + public: + TRTBatchedNMS(const std::string& name, + nvinfer1::plugin::NMSParameters param, + bool returnIndex); - TRTBatchedNMS(const std::string& name, const void* data, size_t length); + TRTBatchedNMS(const std::string& name, const void* data, size_t length); - ~TRTBatchedNMS() TRT_NOEXCEPT override = default; + ~TRTBatchedNMS() TRT_NOEXCEPT override = default; - int 
getNbOutputs() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, - int nbInputs, nvinfer1::IExprBuilder& exprBuilder) - TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, - void* const* outputs, void* workSpace, cudaStream_t stream) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void* buffer) const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* outputs, - int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; - const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginType() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType, - int nbInputs) const TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputType, + int nbInputs) const TRT_NOEXCEPT override; - void setClipParam(bool clip); + void setClipParam(bool clip); - private: - nvinfer1::plugin::NMSParameters param{}; - bool mClipBoxes{}; - bool mReturnIndex{}; -}; + private: + nvinfer1::plugin::NMSParameters param{}; + bool mClipBoxes{}; + bool mReturnIndex{}; + }; -class TRTBatchedNMSCreator : public TRTPluginCreatorBase { - public: - TRTBatchedNMSCreator(); + class TRTBatchedNMSCreator : public TRTPluginCreatorBase + { + public: + TRTBatchedNMSCreator(); - ~TRTBatchedNMSCreator() TRT_NOEXCEPT override = default; + ~TRTBatchedNMSCreator() TRT_NOEXCEPT override = default; - const char* getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() 
const TRT_NOEXCEPT override; - nvinfer1::IPluginV2Ext* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2Ext* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; - nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp index 9d977bc937..80b5be6abc 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp @@ -8,222 +8,295 @@ #include "nms/kernel.h" #include "trt_serialize.hpp" -namespace mmdeploy { -using namespace nvinfer1; -using nvinfer1::plugin::NMSParameters; - -namespace { -static const char* NMS_PLUGIN_VERSION{"1"}; -static const char* NMS_PLUGIN_NAME{"TRTBatchedRotatedNMS"}; -} // namespace - -TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, NMSParameters params) - : TRTPluginBase(name), param(params) {} - -TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &param); - deserialize_value(&data, &length, &mClipBoxes); -} - -int TRTBatchedRotatedNMS::getNbOutputs() const TRT_NOEXCEPT { return 2; } - -nvinfer1::DimsExprs TRTBatchedRotatedNMS::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { - ASSERT(nbInputs == 2); - ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs()); - ASSERT(inputs[0].nbDims == 4); - ASSERT(inputs[1].nbDims == 3); - - nvinfer1::DimsExprs ret; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = exprBuilder.constant(param.keepTopK); - switch (outputIndex) { - case 0: - ret.nbDims = 3; - ret.d[2] = exprBuilder.constant(6); - break; - case 1: - ret.nbDims = 2; - break; - default: - break; - } - - return ret; -} - -size_t TRTBatchedRotatedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT { - size_t batch_size = inputs[0].dims.d[0]; - size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3]; - size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2]; - size_t num_priors = inputs[0].dims.d[1]; - bool shareLocation = (inputs[0].dims.d[2] == 1); - int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? 
param.topK : inputs[1].dims.d[1]; - return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size, - param.numClasses, num_priors, topk, DataType::kFLOAT, - DataType::kFLOAT); -} - -int TRTBatchedRotatedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - const void* const locData = inputs[0]; - const void* const confData = inputs[1]; - - void* nmsedDets = outputs[0]; - void* nmsedLabels = outputs[1]; - - size_t batch_size = inputDesc[0].dims.d[0]; - size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3]; - size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2]; - size_t num_priors = inputDesc[0].dims.d[1]; - bool shareLocation = (inputDesc[0].dims.d[2] == 1); - - int topk = - param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? param.topK : inputDesc[1].dims.d[1]; - bool rotated = true; - pluginStatus_t status = nmsInference( - stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId, - num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold, - DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nullptr, - workSpace, param.isNormalized, false, mClipBoxes, rotated); - ASSERT(status == STATUS_SUCCESS); - - return 0; -} - -size_t TRTBatchedRotatedNMS::getSerializationSize() const TRT_NOEXCEPT { - // NMSParameters, - return sizeof(NMSParameters) + sizeof(bool); -} - -void TRTBatchedRotatedNMS::serialize(void* buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, param); - serialize_value(&buffer, mClipBoxes); -} - -void TRTBatchedRotatedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments -} - -bool TRTBatchedRotatedNMS::supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 3) { - return ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - } - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -const char* TRTBatchedRotatedNMS::getPluginType() const TRT_NOEXCEPT { return NMS_PLUGIN_NAME; } - -const char* TRTBatchedRotatedNMS::getPluginVersion() const TRT_NOEXCEPT { - return NMS_PLUGIN_VERSION; -} - -IPluginV2DynamicExt* TRTBatchedRotatedNMS::clone() const TRT_NOEXCEPT { - auto* plugin = new TRTBatchedRotatedNMS(mLayerName, param); - plugin->setPluginNamespace(mNamespace.c_str()); - plugin->setClipParam(mClipBoxes); - return plugin; -} - -nvinfer1::DataType TRTBatchedRotatedNMS::getOutputDataType(int index, - const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT { - ASSERT(index >= 0 && index < this->getNbOutputs()); - if (index == 1) { - return nvinfer1::DataType::kINT32; - } - return inputTypes[0]; -} - -void TRTBatchedRotatedNMS::setClipParam(bool clip) { mClipBoxes = clip; } - -TRTBatchedRotatedNMSCreator::TRTBatchedRotatedNMSCreator() { - mPluginAttributes.emplace_back( - PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("topk", 
nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back( - PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); - mPluginAttributes.emplace_back( - PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); - mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1)); - - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char* TRTBatchedRotatedNMSCreator::getPluginName() const TRT_NOEXCEPT { - return NMS_PLUGIN_NAME; -} - -const char* TRTBatchedRotatedNMSCreator::getPluginVersion() const TRT_NOEXCEPT { - return NMS_PLUGIN_VERSION; -} - -IPluginV2Ext* TRTBatchedRotatedNMSCreator::createPlugin( - const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT { - const PluginField* fields = fc->fields; - bool clipBoxes = true; - nvinfer1::plugin::NMSParameters params{}; - - for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; - if (!strcmp(attrName, "background_label_id")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.backgroundLabelId = *(static_cast<const int*>(fields[i].data)); - } else if (!strcmp(attrName, "num_classes")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.numClasses = *(static_cast<const int*>(fields[i].data)); - } else if (!strcmp(attrName, "topk")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.topK = *(static_cast<const int*>(fields[i].data)); - } else if (!strcmp(attrName, "keep_topk")) { - ASSERT(fields[i].type == PluginFieldType::kINT32); - params.keepTopK = *(static_cast<const int*>(fields[i].data)); - } else if (!strcmp(attrName, "score_threshold")) { - ASSERT(fields[i].type == PluginFieldType::kFLOAT32); - params.scoreThreshold = *(static_cast<const float*>(fields[i].data)); - } else if (!strcmp(attrName, "iou_threshold")) { - ASSERT(fields[i].type == PluginFieldType::kFLOAT32); - params.iouThreshold = *(static_cast<const float*>(fields[i].data)); - } else if (!strcmp(attrName, "is_normalized")) { - params.isNormalized = *(static_cast<const bool*>(fields[i].data)); - } else if (!strcmp(attrName, "clip_boxes")) { - clipBoxes = *(static_cast<const bool*>(fields[i].data)); - } - } - - TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, params); - plugin->setClipParam(clipBoxes); - plugin->setPluginNamespace(mNamespace.c_str()); - return plugin; -} - -IPluginV2Ext* TRTBatchedRotatedNMSCreator::deserializePlugin(const char* name, - const void* serialData, - size_t serialLength) TRT_NOEXCEPT { - // This object will be deleted when the network is destroyed, which will - // call NMS::destroy() - TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, serialData, serialLength); - plugin->setPluginNamespace(mNamespace.c_str()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTBatchedRotatedNMSCreator); +namespace mmdeploy +{ + using namespace nvinfer1; + using nvinfer1::plugin::NMSParameters; + + namespace + { + static const char* NMS_PLUGIN_VERSION{"1"}; + static const char* NMS_PLUGIN_NAME{"TRTBatchedRotatedNMS"}; + } // namespace + + TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, NMSParameters params) + : TRTPluginBase(name) + , param(params) + { + } + + TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &param); + 
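// note: values must be read back in the same order serialize() writes them (param, then mClipBoxes) +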
deserialize_value(&data, &length, &mClipBoxes); + } + + int TRTBatchedRotatedNMS::getNbOutputs() const TRT_NOEXCEPT + { + return 2; + } + + nvinfer1::DimsExprs TRTBatchedRotatedNMS::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + ASSERT(nbInputs == 2); + ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs()); + ASSERT(inputs[0].nbDims == 4); + ASSERT(inputs[1].nbDims == 3); + + nvinfer1::DimsExprs ret; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = exprBuilder.constant(param.keepTopK); + switch (outputIndex) + { + case 0: + ret.nbDims = 3; + ret.d[2] = exprBuilder.constant(6); + break; + case 1: + ret.nbDims = 2; + break; + default: + break; + } + + return ret; + } + + size_t TRTBatchedRotatedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + size_t batch_size = inputs[0].dims.d[0]; + size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3]; + size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2]; + size_t num_priors = inputs[0].dims.d[1]; + bool shareLocation = (inputs[0].dims.d[2] == 1); + int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? param.topK : inputs[1].dims.d[1]; + return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size, param.numClasses, num_priors, topk, DataType::kFLOAT, DataType::kFLOAT); + } + + int TRTBatchedRotatedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + const void* const locData = inputs[0]; + const void* const confData = inputs[1]; + + void* nmsedDets = outputs[0]; + void* nmsedLabels = outputs[1]; + + size_t batch_size = inputDesc[0].dims.d[0]; + size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3]; + size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2]; + size_t num_priors = inputDesc[0].dims.d[1]; + bool shareLocation = (inputDesc[0].dims.d[2] == 1); + + int topk = + param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? 
param.topK : inputDesc[1].dims.d[1]; + bool rotated = true; + pluginStatus_t status = nmsInference( + stream, + batch_size, + boxes_size, + score_size, + shareLocation, + param.backgroundLabelId, + num_priors, + param.numClasses, + topk, + param.keepTopK, + param.scoreThreshold, + param.iouThreshold, + DataType::kFLOAT, + locData, + DataType::kFLOAT, + confData, + nmsedDets, + nmsedLabels, + nullptr, + workSpace, + param.isNormalized, + false, + mClipBoxes, + rotated); + ASSERT(status == STATUS_SUCCESS); + + return 0; + } + + size_t TRTBatchedRotatedNMS::getSerializationSize() const TRT_NOEXCEPT + { + // NMSParameters, + return sizeof(NMSParameters) + sizeof(bool); + } + + void TRTBatchedRotatedNMS::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, param); + serialize_value(&buffer, mClipBoxes); + } + + void TRTBatchedRotatedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + } + + bool TRTBatchedRotatedNMS::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 3) + { + return ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + + const char* TRTBatchedRotatedNMS::getPluginType() const TRT_NOEXCEPT + { + return NMS_PLUGIN_NAME; + } + + const char* TRTBatchedRotatedNMS::getPluginVersion() const TRT_NOEXCEPT + { + return NMS_PLUGIN_VERSION; + } + + IPluginV2DynamicExt* TRTBatchedRotatedNMS::clone() const TRT_NOEXCEPT + { + auto* plugin = new TRTBatchedRotatedNMS(mLayerName, param); + plugin->setPluginNamespace(mNamespace.c_str()); + plugin->setClipParam(mClipBoxes); + return plugin; + } + + nvinfer1::DataType TRTBatchedRotatedNMS::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + ASSERT(index >= 0 && index < this->getNbOutputs()); + if (index == 1) + { + return nvinfer1::DataType::kINT32; + } + return inputTypes[0]; + } + + void TRTBatchedRotatedNMS::setClipParam(bool clip) + { + mClipBoxes = clip; + } + + TRTBatchedRotatedNMSCreator::TRTBatchedRotatedNMSCreator() + { + mPluginAttributes.emplace_back( + PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("topk", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back( + PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back( + PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTBatchedRotatedNMSCreator::getPluginName() const TRT_NOEXCEPT + { + return NMS_PLUGIN_NAME; + } + + const char* TRTBatchedRotatedNMSCreator::getPluginVersion() const TRT_NOEXCEPT + { + return NMS_PLUGIN_VERSION; + } + + 
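// createPlugin: translate the PluginFieldCollection attributes (label ids, topK, thresholds, flags) into NMSParameters before constructing the plugin +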
IPluginV2Ext* TRTBatchedRotatedNMSCreator::createPlugin( + const char* name, + const PluginFieldCollection* fc) TRT_NOEXCEPT + { + const PluginField* fields = fc->fields; + bool clipBoxes = true; + nvinfer1::plugin::NMSParameters params{}; + + for (int i = 0; i < fc->nbFields; ++i) + { + const char* attrName = fields[i].name; + if (!strcmp(attrName, "background_label_id")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.backgroundLabelId = *(static_cast<const int*>(fields[i].data)); + } + else if (!strcmp(attrName, "num_classes")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.numClasses = *(static_cast<const int*>(fields[i].data)); + } + else if (!strcmp(attrName, "topk")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.topK = *(static_cast<const int*>(fields[i].data)); + } + else if (!strcmp(attrName, "keep_topk")) + { + ASSERT(fields[i].type == PluginFieldType::kINT32); + params.keepTopK = *(static_cast<const int*>(fields[i].data)); + } + else if (!strcmp(attrName, "score_threshold")) + { + ASSERT(fields[i].type == PluginFieldType::kFLOAT32); + params.scoreThreshold = *(static_cast<const float*>(fields[i].data)); + } + else if (!strcmp(attrName, "iou_threshold")) + { + ASSERT(fields[i].type == PluginFieldType::kFLOAT32); + params.iouThreshold = *(static_cast<const float*>(fields[i].data)); + } + else if (!strcmp(attrName, "is_normalized")) + { + params.isNormalized = *(static_cast<const bool*>(fields[i].data)); + } + else if (!strcmp(attrName, "clip_boxes")) + { + clipBoxes = *(static_cast<const bool*>(fields[i].data)); + } + } + + TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, params); + plugin->setClipParam(clipBoxes); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + + IPluginV2Ext* TRTBatchedRotatedNMSCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + // This object will be deleted when the network is destroyed, which will + // call NMS::destroy() + TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, serialData, serialLength); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(TRTBatchedRotatedNMSCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp index 66479eb7e7..be156dc9c9 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp @@ -7,72 +7,85 @@ #include "NvInferPluginUtils.h" #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTBatchedRotatedNMS : public TRTPluginBase { - public: - TRTBatchedRotatedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param); +namespace mmdeploy +{ + class TRTBatchedRotatedNMS : public TRTPluginBase + { + public: + TRTBatchedRotatedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param); - TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length); + TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length); - ~TRTBatchedRotatedNMS() TRT_NOEXCEPT override = default; + ~TRTBatchedRotatedNMS() TRT_NOEXCEPT override = default; - int getNbOutputs() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, - int nbInputs, nvinfer1::IExprBuilder& exprBuilder) - TRT_NOEXCEPT 
override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, - void* const* outputs, void* workSpace, cudaStream_t stream) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void* buffer) const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* outputs, - int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; - const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginType() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType, - int nbInputs) const TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputType, + int nbInputs) const TRT_NOEXCEPT override; - void setClipParam(bool clip); + void setClipParam(bool clip); - private: - nvinfer1::plugin::NMSParameters param{}; - bool mClipBoxes{}; -}; + private: + nvinfer1::plugin::NMSParameters param{}; + bool mClipBoxes{}; + }; -class TRTBatchedRotatedNMSCreator : public TRTPluginCreatorBase { - public: - TRTBatchedRotatedNMSCreator(); + class TRTBatchedRotatedNMSCreator : public TRTPluginCreatorBase + { + public: + TRTBatchedRotatedNMSCreator(); - ~TRTBatchedRotatedNMSCreator() TRT_NOEXCEPT override = default; + ~TRTBatchedRotatedNMSCreator() TRT_NOEXCEPT override = default; - const char* getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2Ext* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2Ext* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) 
TRT_NOEXCEPT override; - nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp index 0f236e4956..6f46a9f295 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp @@ -10,176 +10,228 @@ #include "trt_serialize.hpp" using namespace nvinfer1; -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"TRTBicubicInterpolate"}; -} // namespace - -TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string &name, - std::vector<float> scale_factor, bool align_corners) - : TRTPluginBase(name), mScaleFactor(scale_factor), mAlignCorners(align_corners) {} - -TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string name, const void *data, - size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mScaleFactor); - deserialize_value(&data, &length, &mAlignCorners); - } - -nvinfer1::IPluginV2DynamicExt *TRTBicubicInterpolate::clone() const TRT_NOEXCEPT { - TRTBicubicInterpolate *plugin = - new TRTBicubicInterpolate(mLayerName, mScaleFactor, mAlignCorners); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTBicubicInterpolate::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[0].d[1]; - auto height = exprBuilder.constant(mScaleFactor[0]); - auto width = exprBuilder.constant(mScaleFactor[1]); - auto d2 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *height); - auto d3 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[3], *width); - ret.d[2] = d2; - ret.d[3] = d3; - - return ret; -} - -bool TRTBicubicInterpolate::supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -void TRTBicubicInterpolate::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t TRTBicubicInterpolate::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTBicubicInterpolate::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - int batch = inputDesc[0].dims.d[0]; - int channels = inputDesc[0].dims.d[1]; - int height = inputDesc[0].dims.d[2]; - int width = inputDesc[0].dims.d[3]; - - int height_out = outputDesc[0].dims.d[2]; - int width_out = 
outputDesc[0].dims.d[3]; - const void *x = inputs[0]; - void *output = outputs[0]; - - // TODO: add fp16 support - auto data_type = inputDesc[0].type; - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - bicubic_interpolate<float>((float *)x, (float *)output, batch, channels, height, width, - height_out, width_out, mAlignCorners, stream); - break; - default: - return 1; - break; - } - - return 0; -} - -nvinfer1::DataType TRTBicubicInterpolate::getOutputDataType(int index, - const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *TRTBicubicInterpolate::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTBicubicInterpolate::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int TRTBicubicInterpolate::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTBicubicInterpolate::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mScaleFactor) + serialized_size(mAlignCorners); -} - -void TRTBicubicInterpolate::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mScaleFactor); - serialize_value(&buffer, mAlignCorners); -} - -////////////////////// creator ///////////////////////////// - -TRTBicubicInterpolateCreator::TRTBicubicInterpolateCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(nvinfer1::PluginField("scale_factor")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTBicubicInterpolateCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTBicubicInterpolateCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - nvinfer1::Dims size{2, {1, 1}}; - std::vector<float> scale_factor; - bool align_corners = 1; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("scale_factor") == 0) { - int data_size = (fc->fields[i].length); - if (data_size != 2) { - data_size = data_size / sizeof(float); - } - ASSERT(data_size == 2) - const float *data_start = static_cast<const float *>(fc->fields[i].data); - scale_factor = std::vector<float>(data_start, data_start + data_size); - } - - if (field_name.compare("align_corners") == 0) { - align_corners = static_cast<const bool *>(fc->fields[i].data)[0]; - } - } - - TRTBicubicInterpolate *plugin = new TRTBicubicInterpolate(name, scale_factor, align_corners); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new TRTBicubicInterpolate(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(TRTBicubicInterpolateCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"TRTBicubicInterpolate"}; + } // namespace + + TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string& name, + std::vector<float> scale_factor, + bool align_corners) + : TRTPluginBase(name) + , mScaleFactor(scale_factor) + , mAlignCorners(align_corners) + { + } + + 
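// deserializing constructor: restores mScaleFactor and mAlignCorners in the order written by serialize() +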
TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mScaleFactor); + deserialize_value(&data, &length, &mAlignCorners); + } + + nvinfer1::IPluginV2DynamicExt* TRTBicubicInterpolate::clone() const TRT_NOEXCEPT + { + TRTBicubicInterpolate* plugin = + new TRTBicubicInterpolate(mLayerName, mScaleFactor, mAlignCorners); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTBicubicInterpolate::getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + auto height = exprBuilder.constant(mScaleFactor[0]); + auto width = exprBuilder.constant(mScaleFactor[1]); + auto d2 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *height); + auto d3 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[3], *width); + ret.d[2] = d2; + ret.d[3] = d3; + + return ret; + } + + bool TRTBicubicInterpolate::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + void TRTBicubicInterpolate::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + + size_t TRTBicubicInterpolate::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTBicubicInterpolate::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + + int height_out = outputDesc[0].dims.d[2]; + int width_out = outputDesc[0].dims.d[3]; + const void* x = inputs[0]; + void* output = outputs[0]; + + // TODO: add fp16 support + auto data_type = inputDesc[0].type; + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + bicubic_interpolate<float>((float*)x, + (float*)output, + batch, + channels, + height, + width, + height_out, + width_out, + mAlignCorners, + stream); + break; + default: + return 1; + break; + } + + return 0; + } + + nvinfer1::DataType TRTBicubicInterpolate::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* TRTBicubicInterpolate::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTBicubicInterpolate::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTBicubicInterpolate::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTBicubicInterpolate::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mScaleFactor) + serialized_size(mAlignCorners); + } + + void TRTBicubicInterpolate::serialize(void* 
buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mScaleFactor); + serialize_value(&buffer, mAlignCorners); + } + + ////////////////////// creator ///////////////////////////// + + TRTBicubicInterpolateCreator::TRTBicubicInterpolateCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("scale_factor")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTBicubicInterpolateCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTBicubicInterpolateCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTBicubicInterpolateCreator::createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + nvinfer1::Dims size{2, {1, 1}}; + std::vector<float> scale_factor; + bool align_corners = 1; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("scale_factor") == 0) + { + int data_size = (fc->fields[i].length); + if (data_size != 2) + { + data_size = data_size / sizeof(float); + } + ASSERT(data_size == 2) + const float* data_start = static_cast<const float*>(fc->fields[i].data); + scale_factor = std::vector<float>(data_start, data_start + data_size); + } + + if (field_name.compare("align_corners") == 0) + { + align_corners = static_cast<const bool*>(fc->fields[i].data)[0]; + } + } + + TRTBicubicInterpolate* plugin = new TRTBicubicInterpolate(name, scale_factor, align_corners); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTBicubicInterpolateCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new TRTBicubicInterpolate(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(TRTBicubicInterpolateCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp index 37ad7cf9ff..9a66c5e718 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp @@ -7,61 +7,78 @@ #include #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTBicubicInterpolate : public TRTPluginBase { - public: - TRTBicubicInterpolate(const std::string &name, std::vector<float> scale_factor, - bool align_corners); - - TRTBicubicInterpolate(const std::string name, const void *data, size_t length); - - TRTBicubicInterpolate() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 
- const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - - private: - std::vector<float> mScaleFactor; - bool mAlignCorners; -}; - -class TRTBicubicInterpolateCreator : public TRTPluginCreatorBase { - public: - TRTBicubicInterpolateCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class TRTBicubicInterpolate : public TRTPluginBase + { + public: + TRTBicubicInterpolate(const std::string& name, std::vector<float> scale_factor, bool align_corners); + + TRTBicubicInterpolate(const std::string name, const void* data, size_t length); + + TRTBicubicInterpolate() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + private: + std::vector<float> mScaleFactor; + bool mAlignCorners; + }; + + class TRTBicubicInterpolateCreator : public TRTPluginCreatorBase + { + public: + TRTBicubicInterpolateCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + 
TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_BICUBIC_INTERPOLATE_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu index efb078c431..2c189e0a45 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu @@ -12,159 +12,236 @@ // Based on // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm -template <typename scalar_t> -__device__ __forceinline__ static scalar_t cubic_convolution1(scalar_t x, scalar_t A) { - return ((A + 2) * x - (A + 3)) * x * x + 1; +template<typename scalar_t> +__device__ __forceinline__ static scalar_t cubic_convolution1(scalar_t x, scalar_t A) +{ + return ((A + 2) * x - (A + 3)) * x * x + 1; } -template <typename scalar_t> -__device__ __forceinline__ static scalar_t cubic_convolution2(scalar_t x, scalar_t A) { - return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +template<typename scalar_t> +__device__ __forceinline__ static scalar_t cubic_convolution2(scalar_t x, scalar_t A) +{ + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; } -template <typename scalar_t> +template<typename scalar_t> __device__ __forceinline__ static void get_cubic_upsample_coefficients(scalar_t coeffs[4], - scalar_t t) { - scalar_t A = -0.75; - - scalar_t x1 = t; - coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A); - coeffs[1] = cubic_convolution1<scalar_t>(x1, A); - - // opposite coefficients - scalar_t x2 = 1.0 - t; - coeffs[2] = cubic_convolution1<scalar_t>(x2, A); - coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A); + scalar_t t) +{ + scalar_t A = -0.75; + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A); + coeffs[1] = cubic_convolution1<scalar_t>(x1, A); + + // opposite coefficients + scalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1<scalar_t>(x2, A); + coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A); } -template <typename scalar_t> -__device__ __forceinline__ static scalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, - scalar_t x3, scalar_t t) { - scalar_t coeffs[4]; - get_cubic_upsample_coefficients<scalar_t>(coeffs, t); - - return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +template<typename scalar_t> +__device__ __forceinline__ static scalar_t cubic_interp1d(scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + scalar_t t) +{ + scalar_t coeffs[4]; + get_cubic_upsample_coefficients<scalar_t>(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; } /* Used by UpSampleBicubic2d.cu */ -template <typename scalar_t> -__device__ __forceinline__ static scalar_t upsample_get_value_bounded(const scalar_t *data, - int batch, int channel, - int batchsize, int channels, - int height, int width, int y, - int x) { - int access_y = max(min(y, height - 1), 0); - int access_x = max(min(x, width - 1), 0); - return data[batch * channels * height * width + channel * height * width + access_y * width + - access_x]; +template<typename scalar_t> +__device__ __forceinline__ static scalar_t upsample_get_value_bounded(const scalar_t* data, + int batch, + int channel, + int batchsize, + int channels, + int height, + int width, + int y, + int x) +{ + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + return data[batch * channels * height * width + channel * height * width + + access_y * width + + access_x]; } -template <typename scalar_t> -__device__ __forceinline__ scalar_t 
-area_pixel_compute_source_index(scalar_t scale, int64_t dst_index, bool align_corners, bool cubic) { - if (align_corners) { - return scale * dst_index; - } else { - scalar_t src_idx = scale * (dst_index + 0.5) - 0.5; - // [Note] Follow Opencv resize logic: - // We allow negative src_idx here and later will use - // dx = src_idx - floorf(src_idx) - // to compute the "distance"(which affects weights). - // For linear modes, weight distribution doesn't matter - // for negative indices as they use 2 pixels to interpolate. - // For example, [-1, 0], they both use pixel 0 value so it - // doesn't affect if we bound the src_idx to 0 or not. - // TODO: Our current linear mode impls use unbound indices - // where we should and then remove this cubic flag. - // This matters in cubic mode, as we might need [-1, 0, 1, 2] - // to interpolate and the weights can be affected. - return (!cubic && src_idx < 0) ? scalar_t(0) : src_idx; - } +template<typename scalar_t> +__device__ __forceinline__ scalar_t area_pixel_compute_source_index(scalar_t scale, + int64_t dst_index, + bool align_corners, + bool cubic) +{ + if (align_corners) + { + return scale * dst_index; + } + else + { + scalar_t src_idx = scale * (dst_index + 0.5) - 0.5; + // [Note] Follow Opencv resize logic: + // We allow negative src_idx here and later will use + // dx = src_idx - floorf(src_idx) + // to compute the "distance"(which affects weights). + // For linear modes, weight distribution doesn't matter + // for negative indices as they use 2 pixels to interpolate. + // For example, [-1, 0], they both use pixel 0 value so it + // doesn't affect if we bound the src_idx to 0 or not. + // TODO: Our current linear mode impls use unbound indices + // where we should and then remove this cubic flag. + // This matters in cubic mode, as we might need [-1, 0, 1, 2] + // to interpolate and the weights can be affected. + return (!cubic && src_idx < 0) ? 
scalar_t(0) : src_idx; + } } // cubic interpolation pytorch -template <typename scalar_t> -__global__ void resize_cubic_kernel_torch(const int num_elements, const scalar_t *src, - const int batchsize, const int channels, int srcWidth, - int srcHeight, scalar_t *dst, int dstWidth, int dstHeight, - bool align_corners, float height_scale, - float width_scale) { - CUDA_1D_KERNEL_LOOP(index, num_elements) { - // Special case: input and output are the same size, just copy - const int output_x = index % dstWidth; - const int output_y = index / dstWidth; - - if (srcHeight == dstHeight && srcWidth == dstWidth) { - for (int n = 0; n < batchsize; n++) { - for (int c = 0; c < channels; c++) { - const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + - output_y * dstWidth + output_x]; - dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + - output_x] = val; +template<typename scalar_t> +__global__ void resize_cubic_kernel_torch(const int num_elements, + const scalar_t* src, + const int batchsize, + const int channels, + int srcWidth, + int srcHeight, + scalar_t* dst, + int dstWidth, + int dstHeight, + bool align_corners, + float height_scale, + float width_scale) +{ + CUDA_1D_KERNEL_LOOP(index, num_elements) + { + // Special case: input and output are the same size, just copy + const int output_x = index % dstWidth; + const int output_y = index / dstWidth; + + if (srcHeight == dstHeight && srcWidth == dstWidth) + { + for (int n = 0; n < batchsize; n++) + { + for (int c = 0; c < channels; c++) + { + const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + + output_x]; + dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + + output_x] = val; + } + } + return; } - } - return; - } - // Interpolation kernel - scalar_t real_x = - area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true); - int in_x = floorf(real_x); - scalar_t t_x = real_x - in_x; - - scalar_t real_y = - area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true); - int in_y = floorf(real_y); - scalar_t t_y = real_y - in_y; - - for (int n = 0; n < batchsize; n++) { - for (int c = 0; c < channels; c++) { - scalar_t coefficients[4]; - - for (int k = 0; k < 4; k++) { - coefficients[k] = cubic_interp1d( - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x - 1), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 0), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 1), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 2), - t_x); + // Interpolation kernel + scalar_t real_x = + area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true); + int in_x = floorf(real_x); + scalar_t t_x = real_x - in_x; + + scalar_t real_y = + area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true); + int in_y = floorf(real_y); + scalar_t t_y = real_y - in_y; + + for (int n = 0; n < batchsize; n++) + { + for (int c = 0; c < channels; c++) + { + scalar_t coefficients[4]; + + for (int k = 0; k < 4; k++) + { + coefficients[k] = cubic_interp1d( + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, in_y - 1 + k, in_x - 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, 
srcWidth, in_y - 1 + k, in_x + 0), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, in_y - 1 + k, in_x + 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, in_y - 1 + k, in_x + 2), + t_x); + } + + dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + + output_x] = scalar_t(cubic_interp1d(coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + t_y)); + } } - - dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + - output_x] = scalar_t(cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], - coefficients[3], t_y)); - } } - } } -template <typename scalar_t> -void resizeGPU(const scalar_t *pIn_d, scalar_t *pOut_d, int batch, int channels, int srcWidth, - int srcHeight, int dstWidth, int dstHeight, bool align_corners, - cudaStream_t stream) { - float height_scale = float(srcHeight) / dstHeight; - float width_scale = float(srcWidth) / dstWidth; - if (align_corners && dstWidth > 1 && dstHeight > 1) { - height_scale = (float)(srcHeight - 1) / (dstHeight - 1); - width_scale = (float)(srcWidth - 1) / (dstWidth - 1); - } - int n = batch * dstWidth * dstHeight * channels; - resize_cubic_kernel_torch<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>( - dstWidth * dstHeight, pIn_d, batch, channels, srcWidth, srcHeight, pOut_d, dstWidth, - dstHeight, align_corners, height_scale, width_scale); +template<typename scalar_t> +void resizeGPU(const scalar_t* pIn_d, + scalar_t* pOut_d, + int batch, + int channels, + int srcWidth, + int srcHeight, + int dstWidth, + int dstHeight, + bool align_corners, + cudaStream_t stream) +{ + float height_scale = float(srcHeight) / dstHeight; + float width_scale = float(srcWidth) / dstWidth; + if (align_corners && dstWidth > 1 && dstHeight > 1) + { + height_scale = (float)(srcHeight - 1) / (dstHeight - 1); + width_scale = (float)(srcWidth - 1) / (dstWidth - 1); + } + int n = batch * dstWidth * dstHeight * channels; + resize_cubic_kernel_torch<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(dstWidth * dstHeight, + pIn_d, + batch, + channels, + srcWidth, + srcHeight, + pOut_d, + dstWidth, + dstHeight, + align_corners, + height_scale, + width_scale); } -template <typename scalar_t> -void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch, int channels, - int in_height, int in_width, int out_height, int out_width, - bool align_corners, cudaStream_t stream) { - resizeGPU(input, output, batch, channels, in_width, in_height, out_width, out_height, - align_corners, stream); +template<typename scalar_t> +void bicubic_interpolate(const scalar_t* input, + scalar_t* output, + int batch, + int channels, + int in_height, + int in_width, + int out_height, + int out_width, + bool align_corners, + cudaStream_t stream) +{ + resizeGPU(input, + output, + batch, + channels, + in_width, + in_height, + out_width, + out_height, + align_corners, + stream); } -template void bicubic_interpolate<float>(const float *input, float *output, int batch, int channels, - int in_height, int in_width, int out_height, int out_width, - bool align_corners, cudaStream_t stream); +template void bicubic_interpolate<float>(const float* input, + float* output, + int batch, + int channels, + int in_height, + int in_width, + int out_height, + int out_width, + bool align_corners, + cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp index 66560f59f5..4ecf16c5fe 100644 --- 
a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp @@ -4,8 +4,15 @@ #include "common_cuda_helper.hpp" -template <typename scalar_t> -void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch, int channels, - int in_height, int in_width, int out_height, int out_width, - bool align_corners, cudaStream_t stream); +template<typename scalar_t> +void bicubic_interpolate(const scalar_t* input, + scalar_t* output, + int batch, + int channels, + int in_height, + int in_width, + int out_height, + int out_width, + bool align_corners, + cudaStream_t stream); #endif // TRT_BICUBIC_INTERPOLATE_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp index c76cac8a32..c71de75638 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp @@ -9,25 +9,27 @@ #include #define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) #define THREADS_PER_BLOCK 512 #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) -inline int GET_BLOCKS(const int N) { - int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK); - int max_block_num = 4096; - return std::min(optimal_block_num, max_block_num); +inline int GET_BLOCKS(const int N) +{ + int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK); + int max_block_num = 4096; + return std::min(optimal_block_num, max_block_num); } -#define cudaCheckError() \ - { \ - cudaError_t e = cudaGetLastError(); \ - if (e != cudaSuccess) { \ - printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(0); \ - } \ - } +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) \ + { \ + printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(0); \ + } \ + } /** * Returns a view of the original tensor with its dimensions permuted. 
@@ -39,44 +41,61 @@ inline int GET_BLOCKS(const int N) { * @param[in] src_dim dim of src tensor * @param[in] stream cuda stream handle */ -template <typename scalar_t> -void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim, - cudaStream_t stream = 0); +template<typename scalar_t> +void memcpyPermute(scalar_t* dst, + const scalar_t* src, + int* src_size, + int* permute, + int src_dim, + cudaStream_t stream = 0); -template <typename scalar_t> -cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha, - const scalar_t* A, int lda, const scalar_t* B, int ldb, - const scalar_t* beta, scalar_t* C, int ldc); +template<typename scalar_t> +cublasStatus_t cublasGemmWrap(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const scalar_t* alpha, + const scalar_t* A, + int lda, + const scalar_t* B, + int ldb, + const scalar_t* beta, + scalar_t* C, + int ldc); -template <typename scalar_t> +template<typename scalar_t> __device__ __forceinline__ scalar_t bilinear_interpolate(const scalar_t* __restrict__ input, - const int height, const int width, - scalar_t y, scalar_t x) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; + const int height, + const int width, + scalar_t y, + scalar_t x) +{ + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; - y = min(scalar_t(height - 1), max(scalar_t(0), y)); - x = min(scalar_t(width - 1), max(scalar_t(0), x)); + y = min(scalar_t(height - 1), max(scalar_t(0), y)); + x = min(scalar_t(width - 1), max(scalar_t(0), x)); - const int y_low = floor(y); - const int x_low = floor(x); - const int y_high = ceil(y); - const int x_high = ceil(x); + const int y_low = floor(y); + const int x_low = floor(x); + const int y_high = ceil(y); + const int x_high = ceil(x); - const scalar_t v1 = input[y_low * width + x_low]; - const scalar_t v2 = input[y_low * width + x_high]; - const scalar_t v3 = input[y_high * width + x_low]; - const scalar_t v4 = input[y_high * width + x_high]; + const scalar_t v1 = input[y_low * width + x_low]; + const scalar_t v2 = input[y_low * width + x_high]; + const scalar_t v3 = input[y_high * width + x_low]; + const scalar_t v4 = input[y_high * width + x_high]; - // lerp can be performed by fma - const scalar_t ly = y - y_low; - const scalar_t lx = x - x_low; - const scalar_t v_low = fma(v2 - v1, lx, v1); - const scalar_t v_high = fma(v4 - v3, lx, v3); - const scalar_t val = fma(v_high - v_low, ly, v_low); + // lerp can be performed by fma + const scalar_t ly = y - y_low; + const scalar_t lx = x - x_low; + const scalar_t v_low = fma(v2 - v1, lx, v1); + const scalar_t v_high = fma(v4 - v3, lx, v3); + const scalar_t val = fma(v_high - v_low, ly, v_low); - return val; + return val; } #endif // COMMON_CUDA_HELPER diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp index 22cffa0605..542db78b96 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp @@ -6,14 +6,29 @@ #include "cuda_runtime_api.h" #include "kernel.h" -pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatchBoxesSize, - const int perBatchScoresSize, const bool shareLocation, - const int backgroundLabelId, const int numPredsPerClass, - 
const int numClasses, const int topK, const int keepTopK, - const float scoreThreshold, const float iouThreshold, - const DataType DT_BBOX, const void* locData, const DataType DT_SCORE, - const void* confData, void* nmsedDets, void* nmsedLabels, - void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid, - bool clipBoxes, bool rotated = false); +pluginStatus_t nmsInference(cudaStream_t stream, + const int N, + const int perBatchBoxesSize, + const int perBatchScoresSize, + const bool shareLocation, + const int backgroundLabelId, + const int numPredsPerClass, + const int numClasses, + const int topK, + const int keepTopK, + const float scoreThreshold, + const float iouThreshold, + const DataType DT_BBOX, + const void* locData, + const DataType DT_SCORE, + const void* confData, + void* nmsedDets, + void* nmsedLabels, + void* nmsedIndex, + void* workspace, + bool isNormalized, + bool confSigmoid, + bool clipBoxes, + bool rotated = false); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h index 93fd2a4fb9..19efec4ac5 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h @@ -2,14 +2,19 @@ // modify from // https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin #include "kernel.h" -template -size_t cubSortPairsWorkspaceSize(int num_items, int num_segments) { - size_t temp_storage_bytes = 0; - cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)NULL, temp_storage_bytes, - (const KeyT*)NULL, (KeyT*)NULL, - (const ValueT*)NULL, (ValueT*)NULL, - num_items, // # items - num_segments, // # segments - (const int*)NULL, (const int*)NULL); - return temp_storage_bytes; +template +size_t cubSortPairsWorkspaceSize(int num_items, int num_segments) +{ + size_t temp_storage_bytes = 0; + cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)NULL, + temp_storage_bytes, + (const KeyT*)NULL, + (KeyT*)NULL, + (const ValueT*)NULL, + (ValueT*)NULL, + num_items, // # items + num_segments, // # segments + (const int*)NULL, + (const int*)NULL); + return temp_storage_bytes; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h index 1b50fa4e9f..6e690731d9 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h @@ -15,72 +15,152 @@ using namespace nvinfer1; #define DEBUG_ENABLE 0 -template -struct Bbox { - T xmin, ymin, xmax, ymax; - Bbox(T xmin, T ymin, T xmax, T ymax) : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax) {} - Bbox() = default; +template +struct Bbox +{ + T xmin, ymin, xmax, ymax; + Bbox(T xmin, T ymin, T xmax, T ymax) + : xmin(xmin) + , ymin(ymin) + , xmax(xmax) + , ymax(ymax) + { + } + Bbox() = default; }; -size_t get_cuda_arch(int devID); - -int8_t* alignPtr(int8_t* ptr, uintptr_t to); - -int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize); - -void setUniformOffsets(cudaStream_t stream, int num_segments, int offset, int* d_offsets); - -pluginStatus_t allClassNMS(cudaStream_t stream, int num, int num_classes, int num_preds_per_class, - int top_k, float nms_threshold, bool share_location, bool isNormalized, - DataType DT_SCORE, DataType DT_BBOX, void* bbox_data, - void* beforeNMS_scores, void* beforeNMS_index_array, - void* afterNMS_scores, void* afterNMS_index_array, bool flipXY = false); - -pluginStatus_t allClassRotatedNMS(cudaStream_t 
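// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// cubSortPairsWorkspaceSize above relies on CUB's standard two-phase
// convention: when d_temp_storage is NULL, SortPairsDescending only writes
// the required byte count into temp_storage_bytes and returns without
// sorting. A hedged usage sketch (float keys / int values and the buffer
// names are our assumptions):
#include <cub/cub.cuh>

size_t query_sort_workspace(int num_items, int num_segments, const int* d_offsets)
{
    size_t temp_storage_bytes = 0;
    // Phase 1: null workspace -> pure size query, nothing is sorted.
    cub::DeviceSegmentedRadixSort::SortPairsDescending(
        nullptr, temp_storage_bytes,
        (const float*)nullptr, (float*)nullptr,  // keys in / out
        (const int*)nullptr, (int*)nullptr,      // values in / out
        num_items, num_segments,
        d_offsets, d_offsets + 1);               // segment begin / end offsets
    return temp_storage_bytes;  // phase 2 repeats the call with a real buffer
}
// ---------------------------------------------------------------------------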
stream, int num, int num_classes, - int num_preds_per_class, int top_k, float nms_threshold, - bool share_location, bool isNormalized, DataType DT_SCORE, - DataType DT_BBOX, void* bbox_data, void* beforeNMS_scores, - void* beforeNMS_index_array, void* afterNMS_scores, - void* afterNMS_index_array, bool flipXY = false); - -size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX); - -size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX); - -size_t sortScoresPerClassWorkspaceSize(int num, int num_classes, int num_preds_per_class, - DataType DT_CONF); - -size_t sortScoresPerImageWorkspaceSize(int num_images, int num_items_per_image, DataType DT_SCORE); - -pluginStatus_t sortScoresPerImage(cudaStream_t stream, int num_images, int num_items_per_image, - DataType DT_SCORE, void* unsorted_scores, - void* unsorted_bbox_indices, void* sorted_scores, - void* sorted_bbox_indices, void* workspace); - -pluginStatus_t sortScoresPerClass(cudaStream_t stream, int num, int num_classes, - int num_preds_per_class, int background_label_id, - float confidence_threshold, DataType DT_SCORE, - void* conf_scores_gpu, void* index_array_gpu, void* workspace); - -size_t calculateTotalWorkspaceSize(size_t* workspaces, int count); - -pluginStatus_t permuteData(cudaStream_t stream, int nthreads, int num_classes, int num_data, - int num_dim, DataType DT_DATA, bool confSigmoid, const void* data, - void* new_data); - -size_t detectionForwardPreNMSSize(int N, int C2); - -size_t detectionForwardPostNMSSize(int N, int numClasses, int topK); - -pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation, int numImages, - int numPredsPerClass, int numClasses, int topK, int keepTopK, - DataType DT_BBOX, DataType DT_SCORE, const void* indices, - const void* scores, const void* bboxData, void* nmsedDets, - void* nmsedLabels, void* nmsedIndex = nullptr, - bool clipBoxes = true, bool rotated = false); - -size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses, - int numPredsPerClass, int topK, DataType DT_BBOX, - DataType DT_SCORE); +size_t get_cuda_arch(int devID); + +int8_t* alignPtr(int8_t* ptr, + uintptr_t to); + +int8_t* nextWorkspacePtr(int8_t* ptr, + uintptr_t previousWorkspaceSize); + +void setUniformOffsets(cudaStream_t stream, + int num_segments, + int offset, + int* d_offsets); + +pluginStatus_t allClassNMS(cudaStream_t stream, + int num, + int num_classes, + int num_preds_per_class, + int top_k, + float nms_threshold, + bool share_location, + bool isNormalized, + DataType DT_SCORE, + DataType DT_BBOX, + void* bbox_data, + void* beforeNMS_scores, + void* beforeNMS_index_array, + void* afterNMS_scores, + void* afterNMS_index_array, + bool flipXY = false); + +pluginStatus_t allClassRotatedNMS(cudaStream_t stream, + int num, + int num_classes, + int num_preds_per_class, + int top_k, + float nms_threshold, + bool share_location, + bool isNormalized, + DataType DT_SCORE, + DataType DT_BBOX, + void* bbox_data, + void* beforeNMS_scores, + void* beforeNMS_index_array, + void* afterNMS_scores, + void* afterNMS_index_array, + bool flipXY = false); + +size_t detectionForwardBBoxDataSize(int N, + int C1, + DataType DT_BBOX); + +size_t detectionForwardBBoxPermuteSize(bool shareLocation, + int N, + int C1, + DataType DT_BBOX); + +size_t sortScoresPerClassWorkspaceSize(int num, + int num_classes, + int num_preds_per_class, + DataType DT_CONF); + +size_t sortScoresPerImageWorkspaceSize(int num_images, + int num_items_per_image, + 
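// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// alignPtr / nextWorkspacePtr declared above are the usual helpers for
// carving one device allocation into aligned sub-buffers. A plausible
// implementation sketch (the real definitions live in the .cu sources;
// 16-byte alignment is our assumption):
#include <cstdint>

inline int8_t* align_ptr_sketch(int8_t* ptr, uintptr_t to)
{
    uintptr_t addr = (uintptr_t)ptr;
    addr = (addr + to - 1) & ~(to - 1);  // round up to a multiple of `to` (power of two)
    return (int8_t*)addr;
}

inline int8_t* next_workspace_ptr_sketch(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
    // Step past the previous sub-buffer, then restore alignment.
    return align_ptr_sketch(ptr + previousWorkspaceSize, 16);
}
// ---------------------------------------------------------------------------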
DataType DT_SCORE); + +pluginStatus_t sortScoresPerImage(cudaStream_t stream, + int num_images, + int num_items_per_image, + DataType DT_SCORE, + void* unsorted_scores, + void* unsorted_bbox_indices, + void* sorted_scores, + void* sorted_bbox_indices, + void* workspace); + +pluginStatus_t sortScoresPerClass(cudaStream_t stream, + int num, + int num_classes, + int num_preds_per_class, + int background_label_id, + float confidence_threshold, + DataType DT_SCORE, + void* conf_scores_gpu, + void* index_array_gpu, + void* workspace); + +size_t calculateTotalWorkspaceSize(size_t* workspaces, + int count); + +pluginStatus_t permuteData(cudaStream_t stream, + int nthreads, + int num_classes, + int num_data, + int num_dim, + DataType DT_DATA, + bool confSigmoid, + const void* data, + void* new_data); + +size_t detectionForwardPreNMSSize(int N, + int C2); + +size_t detectionForwardPostNMSSize(int N, + int numClasses, + int topK); + +pluginStatus_t gatherNMSOutputs(cudaStream_t stream, + bool shareLocation, + int numImages, + int numPredsPerClass, + int numClasses, + int topK, + int keepTopK, + DataType DT_BBOX, + DataType DT_SCORE, + const void* indices, + const void* scores, + const void* bboxData, + void* nmsedDets, + void* nmsedLabels, + void* nmsedIndex = nullptr, + bool clipBoxes = true, + bool rotated = false); + +size_t detectionInferenceWorkspaceSize(bool shareLocation, + int N, + int C1, + int C2, + int numClasses, + int numPredsPerClass, + int topK, + DataType DT_BBOX, + DataType DT_SCORE); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp index 8440bb6219..cbe5c1a34c 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp @@ -5,73 +5,106 @@ #include "NvInferVersion.h" #include "trt_plugin_helper.hpp" -namespace mmdeploy { +namespace mmdeploy +{ #if NV_TENSORRT_MAJOR > 7 -#define TRT_NOEXCEPT noexcept + #define TRT_NOEXCEPT noexcept #else -#define TRT_NOEXCEPT + #define TRT_NOEXCEPT #endif -class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt { - public: - TRTPluginBase(const std::string &name) : mLayerName(name) {} - // IPluginV2 Methods - const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; } - int initialize() TRT_NOEXCEPT override { return STATUS_SUCCESS; } - void terminate() TRT_NOEXCEPT override {} - void destroy() TRT_NOEXCEPT override { delete this; } - void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { - mNamespace = pluginNamespace; - } - const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } + class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt + { + public: + TRTPluginBase(const std::string& name) + : mLayerName(name) + { + } + // IPluginV2 Methods + const char* getPluginVersion() const TRT_NOEXCEPT override + { + return "1"; + } + int initialize() TRT_NOEXCEPT override + { + return STATUS_SUCCESS; + } + void terminate() TRT_NOEXCEPT override {} + void destroy() TRT_NOEXCEPT override + { + delete this; + } + void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override + { + mNamespace = pluginNamespace; + } + const char* getPluginNamespace() const TRT_NOEXCEPT override + { + return mNamespace.c_str(); + } - virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override {} + 
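// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// calculateTotalWorkspaceSize above pairs with the pointer helpers: the
// per-stage workspace sizes are summed with padding so every sub-buffer can
// later be handed out aligned. Minimal sketch of that accumulation (the
// 16-byte boundary is our assumption):
#include <cstddef>

inline size_t total_workspace_size_sketch(const size_t* workspaces, int count)
{
    const size_t kAlign = 16;
    size_t total = 0;
    for (int i = 0; i < count; ++i)
        total += (workspaces[i] + kAlign - 1) / kAlign * kAlign;  // round up each stage
    return total;
}
// ---------------------------------------------------------------------------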
virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override {} - virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override { - return 0; - } + virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override + { + return 0; + } - virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {} + virtual void attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override {} - virtual void detachFromContext() TRT_NOEXCEPT override {} + virtual void detachFromContext() TRT_NOEXCEPT override {} - protected: - const std::string mLayerName; - std::string mNamespace; + protected: + const std::string mLayerName; + std::string mNamespace; #if NV_TENSORRT_MAJOR < 8 - protected: - // To prevent compiler warnings. - using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; - using nvinfer1::IPluginV2DynamicExt::enqueue; - using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; - using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; - using nvinfer1::IPluginV2DynamicExt::supportsFormat; + protected: + // To prevent compiler warnings. + using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::enqueue; + using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; + using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::supportsFormat; #endif -}; + }; -class TRTPluginCreatorBase : public nvinfer1::IPluginCreator { - public: - const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }; + class TRTPluginCreatorBase : public nvinfer1::IPluginCreator + { + public: + const char* getPluginVersion() const TRT_NOEXCEPT override + { + return "1"; + }; - const nvinfer1::PluginFieldCollection *getFieldNames() TRT_NOEXCEPT override { return &mFC; } + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override + { + return &mFC; + } - void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { - mNamespace = pluginNamespace; - } + void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override + { + mNamespace = pluginNamespace; + } - const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } + const char* getPluginNamespace() const TRT_NOEXCEPT override + { + return mNamespace.c_str(); + } - protected: - nvinfer1::PluginFieldCollection mFC; - std::vector mPluginAttributes; - std::string mNamespace; -}; + protected: + nvinfer1::PluginFieldCollection mFC; + std::vector mPluginAttributes; + std::string mNamespace; + }; } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp index 41b47acdbe..050c0dd308 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp @@ -11,145 +11,159 @@ cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t* cudnn_dtype); // Enumerator for 
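// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// convert_trt2cudnn_dtype above is only declared in this header; a plausible
// mapping for the common cases looks like the sketch below. The exact
// coverage and error handling are our assumptions, not copied from the repo
// (the enum values themselves are the real TensorRT / cuDNN ones):
#include <cudnn.h>
#include <NvInfer.h>

inline cudnnStatus_t convert_trt2cudnn_dtype_sketch(nvinfer1::DataType trt_dtype,
                                                    cudnnDataType_t*   cudnn_dtype)
{
    switch (trt_dtype)
    {
        case nvinfer1::DataType::kFLOAT: *cudnn_dtype = CUDNN_DATA_FLOAT; break;
        case nvinfer1::DataType::kHALF:  *cudnn_dtype = CUDNN_DATA_HALF;  break;
        default: return CUDNN_STATUS_BAD_PARAM;  // unsupported dtype
    }
    return CUDNN_STATUS_SUCCESS;
}
// ---------------------------------------------------------------------------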
status -typedef enum { - STATUS_SUCCESS = 0, - STATUS_FAILURE = 1, - STATUS_BAD_PARAM = 2, - STATUS_NOT_SUPPORTED = 3, - STATUS_NOT_INITIALIZED = 4 +typedef enum +{ + STATUS_SUCCESS = 0, + STATUS_FAILURE = 1, + STATUS_BAD_PARAM = 2, + STATUS_NOT_SUPPORTED = 3, + STATUS_NOT_INITIALIZED = 4 } pluginStatus_t; -#define ASSERT(assertion) \ - { \ - if (!(assertion)) { \ - std::cerr << "#assertion" << __FILE__ << "," << __LINE__ << std::endl; \ - abort(); \ - } \ - } - -#define CUASSERT(status_) \ - { \ - auto s_ = status_; \ - if (s_ != cudaSuccess) { \ - std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " << cudaGetErrorString(s_) \ - << std::endl; \ - } \ - } -#define CUBLASASSERT(status_) \ - { \ - auto s_ = status_; \ - if (s_ != CUBLAS_STATUS_SUCCESS) { \ - std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ - } \ - } -#define CUERRORMSG(status_) \ - { \ - auto s_ = status_; \ - if (s_ != 0) std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ - } +#define ASSERT(assertion) \ + { \ + if (!(assertion)) \ + { \ + std::cerr << "#assertion" << __FILE__ << "," << __LINE__ << std::endl; \ + abort(); \ + } \ + } + +#define CUASSERT(status_) \ + { \ + auto s_ = status_; \ + if (s_ != cudaSuccess) \ + { \ + std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " << cudaGetErrorString(s_) \ + << std::endl; \ + } \ + } +#define CUBLASASSERT(status_) \ + { \ + auto s_ = status_; \ + if (s_ != CUBLAS_STATUS_SUCCESS) \ + { \ + std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ + } \ + } +#define CUERRORMSG(status_) \ + { \ + auto s_ = status_; \ + if (s_ != 0) std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ + } #ifndef DEBUG -#define CHECK(status) \ - do { \ - if (status != 0) abort(); \ - } while (0) - -#define ASSERT_PARAM(exp) \ - do { \ - if (!(exp)) return STATUS_BAD_PARAM; \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do { \ - if (!(exp)) return STATUS_FAILURE; \ - } while (0) - -#define CSC(call, err) \ - do { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) { \ - return err; \ - } \ - } while (0) - -#define DEBUG_PRINTF(...) \ - do { \ - } while (0) + #define CHECK(status) \ + do { \ + if (status != 0) abort(); \ + } while (0) + + #define ASSERT_PARAM(exp) \ + do { \ + if (!(exp)) return STATUS_BAD_PARAM; \ + } while (0) + + #define ASSERT_FAILURE(exp) \ + do { \ + if (!(exp)) return STATUS_FAILURE; \ + } while (0) + + #define CSC(call, err) \ + do { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + return err; \ + } \ + } while (0) + + #define DEBUG_PRINTF(...) \ + do { \ + } while (0) #else -#define ASSERT_PARAM(exp) \ - do { \ - if (!(exp)) { \ - fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_BAD_PARAM; \ - } \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do { \ - if (!(exp)) { \ - fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_FAILURE; \ - } \ - } while (0) - -#define CSC(call, err) \ - do { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) { \ - printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ - return err; \ - } \ - } while (0) - -#define CHECK(status) \ - { \ - if (status != 0) { \ - DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ - abort(); \ - } \ - } - -#define DEBUG_PRINTF(...) 
\ - do { \ - printf(__VA_ARGS__); \ - } while (0) + #define ASSERT_PARAM(exp) \ + do { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_BAD_PARAM; \ + } \ + } while (0) + + #define ASSERT_FAILURE(exp) \ + do { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_FAILURE; \ + } \ + } while (0) + + #define CSC(call, err) \ + do { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ + return err; \ + } \ + } while (0) + + #define CHECK(status) \ + { \ + if (status != 0) \ + { \ + DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ + abort(); \ + } \ + } + + #define DEBUG_PRINTF(...) \ + do { \ + printf(__VA_ARGS__); \ + } while (0) #endif -namespace mmdeploy { - -const int MAXTENSORDIMS = 10; - -struct TensorDesc { - int shape[MAXTENSORDIMS]; - int stride[MAXTENSORDIMS]; - int dim; -}; - -inline unsigned int getElementSize(nvinfer1::DataType t) { - switch (t) { - case nvinfer1::DataType::kINT32: - return 4; - case nvinfer1::DataType::kFLOAT: - return 4; - case nvinfer1::DataType::kHALF: - return 2; - // case nvinfer1::DataType::kBOOL: - case nvinfer1::DataType::kINT8: - return 1; - default: - throw std::runtime_error("Invalid DataType."); - } - throw std::runtime_error("Invalid DataType."); - return 0; -} - -inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) { - return size_t((origin_size + aligned_number - 1) / aligned_number) * aligned_number; -} +namespace mmdeploy +{ + + const int MAXTENSORDIMS = 10; + + struct TensorDesc + { + int shape[MAXTENSORDIMS]; + int stride[MAXTENSORDIMS]; + int dim; + }; + + inline unsigned int getElementSize(nvinfer1::DataType t) + { + switch (t) + { + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + // case nvinfer1::DataType::kBOOL: + case nvinfer1::DataType::kINT8: + return 1; + default: + throw std::runtime_error("Invalid DataType."); + } + throw std::runtime_error("Invalid DataType."); + return 0; + } + + inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) + { + return size_t((origin_size + aligned_number - 1) / aligned_number) * aligned_number; + } } // namespace mmdeploy #endif // TRT_PLUGIN_HELPER_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp index db88184432..c059a7cfb8 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp @@ -9,89 +9,117 @@ #include #include -template +template inline void serialize_value(void** buffer, T const& value); -template +template inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value); -namespace { - -template -struct Serializer {}; - -template -struct Serializer::value || std::is_enum::value || - std::is_pod::value>::type> { - static size_t serialized_size(T const& value) { return sizeof(T); } - static void serialize(void** buffer, T const& value) { - ::memcpy(*buffer, &value, sizeof(T)); - reinterpret_cast(*buffer) += sizeof(T); - } - static void deserialize(void const** buffer, size_t* buffer_size, T* value) { - assert(*buffer_size >= sizeof(T)); - ::memcpy(value, *buffer, sizeof(T)); - reinterpret_cast(*buffer) += 
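// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// The status macros above compose into the usual plugin-launcher shape:
// validate arguments, run CUDA calls through CSC so errors early-return,
// and let DEBUG_PRINTF compile to a no-op in release builds. Hypothetical
// sketch (assumes this header plus <cuda_runtime_api.h> are in scope):
pluginStatus_t example_launcher_sketch(cudaStream_t stream, int n, void* d_buf)
{
    ASSERT_PARAM(n > 0);             // bad argument  -> STATUS_BAD_PARAM
    ASSERT_PARAM(d_buf != nullptr);
    CSC(cudaMemsetAsync(d_buf, 0, (size_t)n, stream), STATUS_FAILURE);  // CUDA error -> STATUS_FAILURE
    DEBUG_PRINTF("launched with n=%d\n", n);  // no-op unless DEBUG is defined
    return STATUS_SUCCESS;
}

// getAlignedSize above is plain round-up arithmetic; worked values with the
// default 16-byte alignment: 1 -> 16, 16 -> 16, 17 -> 32.
// ---------------------------------------------------------------------------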
sizeof(T); - *buffer_size -= sizeof(T); - } -}; - -template <> -struct Serializer { - static size_t serialized_size(const char* value) { return strlen(value) + 1; } - static void serialize(void** buffer, const char* value) { - ::strcpy(static_cast(*buffer), value); - reinterpret_cast(*buffer) += strlen(value) + 1; - } - static void deserialize(void const** buffer, size_t* buffer_size, const char** value) { - *value = static_cast(*buffer); - size_t data_size = strnlen(*value, *buffer_size) + 1; - assert(*buffer_size >= data_size); - reinterpret_cast(*buffer) += data_size; - *buffer_size -= data_size; - } -}; - -template -struct Serializer, - typename std::enable_if::value || std::is_enum::value || - std::is_pod::value>::type> { - static size_t serialized_size(std::vector const& value) { - return sizeof(value.size()) + value.size() * sizeof(T); - } - static void serialize(void** buffer, std::vector const& value) { - serialize_value(buffer, value.size()); - size_t nbyte = value.size() * sizeof(T); - ::memcpy(*buffer, value.data(), nbyte); - reinterpret_cast(*buffer) += nbyte; - } - static void deserialize(void const** buffer, size_t* buffer_size, std::vector* value) { - size_t size; - deserialize_value(buffer, buffer_size, &size); - value->resize(size); - size_t nbyte = value->size() * sizeof(T); - assert(*buffer_size >= nbyte); - ::memcpy(value->data(), *buffer, nbyte); - reinterpret_cast(*buffer) += nbyte; - *buffer_size -= nbyte; - } -}; +namespace +{ + + template + struct Serializer + { + }; + + template + struct Serializer::value || std::is_enum::value || + std::is_pod::value>::type> + { + static size_t serialized_size(T const& value) + { + return sizeof(T); + } + + static void serialize(void** buffer, T const& value) + { + ::memcpy(*buffer, &value, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + } + + static void deserialize(void const** buffer, size_t* buffer_size, T* value) + { + assert(*buffer_size >= sizeof(T)); + ::memcpy(value, *buffer, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + *buffer_size -= sizeof(T); + } + }; + + template<> + struct Serializer + { + static size_t serialized_size(const char* value) + { + return strlen(value) + 1; + } + + static void serialize(void** buffer, const char* value) + { + ::strcpy(static_cast(*buffer), value); + reinterpret_cast(*buffer) += strlen(value) + 1; + } + + static void deserialize(void const** buffer, size_t* buffer_size, const char** value) + { + *value = static_cast(*buffer); + size_t data_size = strnlen(*value, *buffer_size) + 1; + assert(*buffer_size >= data_size); + reinterpret_cast(*buffer) += data_size; + *buffer_size -= data_size; + } + }; + + template + struct Serializer, + typename std::enable_if::value || std::is_enum::value || + std::is_pod::value>::type> + { + static size_t serialized_size(std::vector const& value) + { + return sizeof(value.size()) + value.size() * sizeof(T); + } + + static void serialize(void** buffer, std::vector const& value) + { + serialize_value(buffer, value.size()); + size_t nbyte = value.size() * sizeof(T); + ::memcpy(*buffer, value.data(), nbyte); + reinterpret_cast(*buffer) += nbyte; + } + + static void deserialize(void const** buffer, size_t* buffer_size, std::vector* value) + { + size_t size; + deserialize_value(buffer, buffer_size, &size); + value->resize(size); + size_t nbyte = value->size() * sizeof(T); + assert(*buffer_size >= nbyte); + ::memcpy(value->data(), *buffer, nbyte); + reinterpret_cast(*buffer) += nbyte; + *buffer_size -= nbyte; + } + }; } // namespace -template 
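// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// Usage sketch for the serialize/deserialize helpers in this header: a
// plugin streams its POD members and vectors through one contiguous buffer,
// keeping the same statement order on the write and read sides. The struct
// and field names below are hypothetical:
#include <cstddef>
#include <vector>

struct DemoParams
{
    float            iou_threshold;
    std::vector<int> strides;
};

size_t demo_serialized_size(const DemoParams& p)
{
    return serialized_size(p.iou_threshold) + serialized_size(p.strides);
}

void demo_serialize(void* buffer, const DemoParams& p)
{
    serialize_value(&buffer, p.iou_threshold);  // each call advances `buffer`
    serialize_value(&buffer, p.strides);
}

void demo_deserialize(const void* buffer, size_t size, DemoParams* p)
{
    deserialize_value(&buffer, &size, &p->iou_threshold);  // `size` shrinks as we read
    deserialize_value(&buffer, &size, &p->strides);
}
// ---------------------------------------------------------------------------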
-inline size_t serialized_size(T const& value) { - return Serializer::serialized_size(value); +template +inline size_t serialized_size(T const& value) +{ + return Serializer::serialized_size(value); } -template -inline void serialize_value(void** buffer, T const& value) { - return Serializer::serialize(buffer, value); +template +inline void serialize_value(void** buffer, T const& value) +{ + return Serializer::serialize(buffer, value); } -template -inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value) { - return Serializer::deserialize(buffer, buffer_size, value); +template +inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value) +{ + return Serializer::deserialize(buffer, buffer_size, value); } #endif // TRT_SERIALIZE_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu index 44c08152db..99aba5704c 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu @@ -7,261 +7,381 @@ const static int BS = 512; -template -__device__ T_BBOX bboxSize(const Bbox &bbox, const bool normalized, T_BBOX offset) { - if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) { - // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. - return 0; - } else { - T_BBOX width = bbox.xmax - bbox.xmin; - T_BBOX height = bbox.ymax - bbox.ymin; - if (normalized) { - return width * height; - } else { - // If bbox is not within range [0, 1]. - return (width + offset) * (height + offset); +template +__device__ T_BBOX bboxSize(const Bbox& bbox, + const bool normalized, + T_BBOX offset) +{ + if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) + { + // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. + return 0; + } + else + { + T_BBOX width = bbox.xmax - bbox.xmin; + T_BBOX height = bbox.ymax - bbox.ymin; + if (normalized) + { + return width * height; + } + else + { + // If bbox is not within range [0, 1]. + return (width + offset) * (height + offset); + } } - } } -template -__device__ void intersectBbox(const Bbox &bbox1, const Bbox &bbox2, - Bbox *intersect_bbox) { - if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin || bbox2.ymin > bbox1.ymax || - bbox2.ymax < bbox1.ymin) { - // Return [0, 0, 0, 0] if there is no intersection. - intersect_bbox->xmin = T_BBOX(0); - intersect_bbox->ymin = T_BBOX(0); - intersect_bbox->xmax = T_BBOX(0); - intersect_bbox->ymax = T_BBOX(0); - } else { - intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin); - intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin); - intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax); - intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax); - } +template +__device__ void intersectBbox(const Bbox& bbox1, + const Bbox& bbox2, + Bbox* intersect_bbox) +{ + if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin || bbox2.ymin > bbox1.ymax || + bbox2.ymax < bbox1.ymin) + { + // Return [0, 0, 0, 0] if there is no intersection. 
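// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// bboxSize / intersectBbox / jaccardOverlap above implement standard
// axis-aligned IoU. Scalar condensation for the normalized case (so no
// +offset term); the type and names are ours:
#include <algorithm>

struct BoxDemo { float xmin, ymin, xmax, ymax; };

inline float iou_demo(const BoxDemo& a, const BoxDemo& b)
{
    const float iw = std::min(a.xmax, b.xmax) - std::max(a.xmin, b.xmin);
    const float ih = std::min(a.ymax, b.ymax) - std::max(a.ymin, b.ymin);
    if (iw <= 0.f || ih <= 0.f) return 0.f;  // disjoint boxes
    const float inter  = iw * ih;
    const float area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
    const float area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
    return inter / (area_a + area_b - inter);
}
// ---------------------------------------------------------------------------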
+ intersect_bbox->xmin = T_BBOX(0); + intersect_bbox->ymin = T_BBOX(0); + intersect_bbox->xmax = T_BBOX(0); + intersect_bbox->ymax = T_BBOX(0); + } + else + { + intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin); + intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin); + intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax); + intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax); + } } -template -__device__ float jaccardOverlap(const Bbox &bbox1, const Bbox &bbox2, - const bool normalized, T_BBOX offset) { - Bbox intersect_bbox; - intersectBbox(bbox1, bbox2, &intersect_bbox); - float intersect_width, intersect_height; - if (normalized) { - intersect_width = intersect_bbox.xmax - intersect_bbox.xmin; - intersect_height = intersect_bbox.ymax - intersect_bbox.ymin; - } else { - intersect_width = intersect_bbox.xmax - intersect_bbox.xmin + offset; - intersect_height = intersect_bbox.ymax - intersect_bbox.ymin + offset; - } - if (intersect_width > 0 && intersect_height > 0) { - float intersect_size = intersect_width * intersect_height; - float bbox1_size = bboxSize(bbox1, normalized, offset); - float bbox2_size = bboxSize(bbox2, normalized, offset); - return intersect_size / (bbox1_size + bbox2_size - intersect_size); - } else { - return 0.; - } +template +__device__ float jaccardOverlap(const Bbox& bbox1, + const Bbox& bbox2, + const bool normalized, + T_BBOX offset) +{ + Bbox intersect_bbox; + intersectBbox(bbox1, bbox2, &intersect_bbox); + float intersect_width, intersect_height; + if (normalized) + { + intersect_width = intersect_bbox.xmax - intersect_bbox.xmin; + intersect_height = intersect_bbox.ymax - intersect_bbox.ymin; + } + else + { + intersect_width = intersect_bbox.xmax - intersect_bbox.xmin + offset; + intersect_height = intersect_bbox.ymax - intersect_bbox.ymin + offset; + } + if (intersect_width > 0 && intersect_height > 0) + { + float intersect_size = intersect_width * intersect_height; + float bbox1_size = bboxSize(bbox1, normalized, offset); + float bbox2_size = bboxSize(bbox2, normalized, offset); + return intersect_size / (bbox1_size + bbox2_size - intersect_size); + } + else + { + return 0.; + } } /********** new NMS for only score and index array **********/ -// clang-format off -template +template __global__ void #ifdef __CUDA_ARCH__ -#if __CUDA_ARCH__ == 620 || __CUDA_ARCH__ == 530 -__launch_bounds__(512) + #if __CUDA_ARCH__ == 620 || __CUDA_ARCH__ == 530 + __launch_bounds__(512) + #endif #endif -#endif -allClassNMS_kernel(const int num, const int num_classes, const int num_preds_per_class, - const int top_k, const float nms_threshold, const bool share_location, - const bool isNormalized, - T_BBOX *bbox_data, // bbox_data should be float to preserve - // location information - T_SCORE *beforeNMS_scores, int *beforeNMS_index_array, - T_SCORE *afterNMS_scores, int *afterNMS_index_array, bool flipXY = false) { - // clang-format on - //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; - __shared__ bool kept_bboxinfo_flag[TSIZE * BS]; - for (int i = 0; i < num; i++) { - const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; - const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation - const int bbox_idx_offset = - share_location ? 
(i * num_preds_per_class) : (i * num_classes * num_preds_per_class); - - // local thread data - int loc_bboxIndex[TSIZE]; - Bbox loc_bbox[TSIZE]; - - // initialize Bbox, Bboxinfo, kept_bboxinfo_flag - // Eliminate shared memory RAW hazard - __syncthreads(); + allClassNMS_kernel(const int num, + const int num_classes, + const int num_preds_per_class, + const int top_k, + const float nms_threshold, + const bool share_location, + const bool isNormalized, + T_BBOX* bbox_data, // bbox_data should be float to preserve location information + T_SCORE* beforeNMS_scores, + int* beforeNMS_index_array, + T_SCORE* afterNMS_scores, + int* afterNMS_index_array, + bool flipXY = false) +{ + //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; + __shared__ bool kept_bboxinfo_flag[TSIZE * BS]; + for (int i = 0; i < num; i++) + { + const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; + const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation + const int bbox_idx_offset = + share_location ? (i * num_preds_per_class) : (i * num_classes * num_preds_per_class); + + // local thread data + int loc_bboxIndex[TSIZE]; + Bbox loc_bbox[TSIZE]; + + // initialize Bbox, Bboxinfo, kept_bboxinfo_flag + // Eliminate shared memory RAW hazard + __syncthreads(); #pragma unroll - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int item_idx = offset + cur_idx; + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if (item_idx < max_idx) + { + loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + + if (loc_bboxIndex[t] >= 0) + // if (loc_bboxIndex[t] != -1) + { + const int bbox_data_idx = share_location ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) : loc_bboxIndex[t]; + + loc_bbox[t].xmin = + flipXY ? bbox_data[bbox_data_idx * 4 + 1] : bbox_data[bbox_data_idx * 4 + 0]; + loc_bbox[t].ymin = + flipXY ? bbox_data[bbox_data_idx * 4 + 0] : bbox_data[bbox_data_idx * 4 + 1]; + loc_bbox[t].xmax = + flipXY ? bbox_data[bbox_data_idx * 4 + 3] : bbox_data[bbox_data_idx * 4 + 2]; + loc_bbox[t].ymax = + flipXY ? bbox_data[bbox_data_idx * 4 + 2] : bbox_data[bbox_data_idx * 4 + 3]; + kept_bboxinfo_flag[cur_idx] = true; + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } - if (item_idx < max_idx) { - loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + // filter out overlapped boxes with lower scores + int ref_item_idx = offset; + int ref_bbox_idx = share_location ? + (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) : + beforeNMS_index_array[ref_item_idx]; - if (loc_bboxIndex[t] >= 0) - // if (loc_bboxIndex[t] != -1) + while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) { - const int bbox_data_idx = share_location - ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) - : loc_bboxIndex[t]; - - loc_bbox[t].xmin = - flipXY ? bbox_data[bbox_data_idx * 4 + 1] : bbox_data[bbox_data_idx * 4 + 0]; - loc_bbox[t].ymin = - flipXY ? bbox_data[bbox_data_idx * 4 + 0] : bbox_data[bbox_data_idx * 4 + 1]; - loc_bbox[t].xmax = - flipXY ? bbox_data[bbox_data_idx * 4 + 3] : bbox_data[bbox_data_idx * 4 + 2]; - loc_bbox[t].ymax = - flipXY ? bbox_data[bbox_data_idx * 4 + 2] : bbox_data[bbox_data_idx * 4 + 3]; - kept_bboxinfo_flag[cur_idx] = true; - } else { - kept_bboxinfo_flag[cur_idx] = false; + Bbox ref_bbox; + ref_bbox.xmin = flipXY ? 
bbox_data[ref_bbox_idx * 4 + 1] : bbox_data[ref_bbox_idx * 4 + 0]; + ref_bbox.ymin = flipXY ? bbox_data[ref_bbox_idx * 4 + 0] : bbox_data[ref_bbox_idx * 4 + 1]; + ref_bbox.xmax = flipXY ? bbox_data[ref_bbox_idx * 4 + 3] : bbox_data[ref_bbox_idx * 4 + 2]; + ref_bbox.ymax = flipXY ? bbox_data[ref_bbox_idx * 4 + 2] : bbox_data[ref_bbox_idx * 4 + 3]; + + // Eliminate shared memory RAW hazard + __syncthreads(); + + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) + { + // TODO: may need to add bool normalized as argument, HERE true means + // normalized + if (jaccardOverlap(ref_bbox, loc_bbox[t], isNormalized, T_BBOX(0)) > nms_threshold) + { + kept_bboxinfo_flag[cur_idx] = false; + } + } + } + __syncthreads(); + + do { + ref_item_idx++; + } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]); + + ref_bbox_idx = + share_location ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) : beforeNMS_index_array[ref_item_idx]; } - } else { - kept_bboxinfo_flag[cur_idx] = false; - } - } - // filter out overlapped boxes with lower scores - int ref_item_idx = offset; - int ref_bbox_idx = - share_location - ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) - : beforeNMS_index_array[ref_item_idx]; - - while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) { - Bbox ref_bbox; - ref_bbox.xmin = flipXY ? bbox_data[ref_bbox_idx * 4 + 1] : bbox_data[ref_bbox_idx * 4 + 0]; - ref_bbox.ymin = flipXY ? bbox_data[ref_bbox_idx * 4 + 0] : bbox_data[ref_bbox_idx * 4 + 1]; - ref_bbox.xmax = flipXY ? bbox_data[ref_bbox_idx * 4 + 3] : bbox_data[ref_bbox_idx * 4 + 2]; - ref_bbox.ymax = flipXY ? bbox_data[ref_bbox_idx * 4 + 2] : bbox_data[ref_bbox_idx * 4 + 3]; - - // Eliminate shared memory RAW hazard - __syncthreads(); - - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int item_idx = offset + cur_idx; - - if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) { - // TODO: may need to add bool normalized as argument, HERE true means - // normalized - if (jaccardOverlap(ref_bbox, loc_bbox[t], isNormalized, T_BBOX(0)) > nms_threshold) { - kept_bboxinfo_flag[cur_idx] = false; - } + // store data + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int read_item_idx = offset + cur_idx; + const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx; + /* + * If not not keeping the bbox + * Set the score to 0 + * Set the bounding box index to -1 + */ + if (read_item_idx < max_idx) + { + afterNMS_scores[write_item_idx] = + kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f; + afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1; + } } - } - __syncthreads(); - - do { - ref_item_idx++; - } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]); - - ref_bbox_idx = - share_location - ? 
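// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// Scalar reference of the greedy suppression that allClassNMS_kernel
// parallelizes: boxes arrive sorted by descending score, and each surviving
// box suppresses every later box whose IoU exceeds the threshold. Reuses
// the hypothetical BoxDemo / iou_demo from the sketch above:
#include <cstddef>
#include <vector>

std::vector<int> nms_reference(const std::vector<BoxDemo>& boxes,  // score-sorted
                               float                        iou_threshold)
{
    std::vector<bool> kept(boxes.size(), true);
    std::vector<int>  out;
    for (size_t i = 0; i < boxes.size(); ++i)
    {
        if (!kept[i]) continue;  // already suppressed by an earlier box
        out.push_back((int)i);
        for (size_t j = i + 1; j < boxes.size(); ++j)
            if (kept[j] && iou_demo(boxes[i], boxes[j]) > iou_threshold)
                kept[j] = false;
    }
    return out;  // indices of kept boxes; the kernel writes score 0 / index -1 instead
}
// ---------------------------------------------------------------------------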
(beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) - : beforeNMS_index_array[ref_item_idx]; } - - // store data - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int read_item_idx = offset + cur_idx; - const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx; - /* - * If not not keeping the bbox - * Set the score to 0 - * Set the bounding box index to -1 - */ - if (read_item_idx < max_idx) { - afterNMS_scores[write_item_idx] = - kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f; - afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1; - } - } - } } -template -pluginStatus_t allClassNMS_gpu(cudaStream_t stream, const int num, const int num_classes, - const int num_preds_per_class, const int top_k, - const float nms_threshold, const bool share_location, - const bool isNormalized, void *bbox_data, void *beforeNMS_scores, - void *beforeNMS_index_array, void *afterNMS_scores, - void *afterNMS_index_array, bool flipXY = false) { +template +pluginStatus_t allClassNMS_gpu(cudaStream_t stream, + const int num, + const int num_classes, + const int num_preds_per_class, + const int top_k, + const float nms_threshold, + const bool share_location, + const bool isNormalized, + void* bbox_data, + void* beforeNMS_scores, + void* beforeNMS_index_array, + void* afterNMS_scores, + void* afterNMS_index_array, + bool flipXY = false) +{ #define P(tsize) allClassNMS_kernel - void (*kernel[10])(const int, const int, const int, const int, const float, const bool, - const bool, float *, T_SCORE *, int *, T_SCORE *, int *, bool) = { - P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10), - }; - - const int GS = num_classes; - const int t_size = (top_k + BS - 1) / BS; - - ASSERT(t_size <= 10); - kernel[t_size - 1]<<>>( - num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized, - (T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array, - (T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array, flipXY); - - cudaError_t code = cudaGetLastError(); - CUASSERT(code); - CSC(code, STATUS_FAILURE); - return STATUS_SUCCESS; + void (*kernel[10])(const int, + const int, + const int, + const int, + const float, + const bool, + const bool, + float*, + T_SCORE*, + int*, + T_SCORE*, + int*, + bool) = { + P(1), + P(2), + P(3), + P(4), + P(5), + P(6), + P(7), + P(8), + P(9), + P(10), + }; + + const int GS = num_classes; + const int t_size = (top_k + BS - 1) / BS; + + ASSERT(t_size <= 10); + kernel[t_size - 1]<<>>( + num, + num_classes, + num_preds_per_class, + top_k, + nms_threshold, + share_location, + isNormalized, + (T_BBOX*)bbox_data, + (T_SCORE*)beforeNMS_scores, + (int*)beforeNMS_index_array, + (T_SCORE*)afterNMS_scores, + (int*)afterNMS_index_array, + flipXY); + + cudaError_t code = cudaGetLastError(); + CUASSERT(code); + CSC(code, STATUS_FAILURE); + return STATUS_SUCCESS; } // allClassNMS LAUNCH CONFIG -typedef pluginStatus_t (*nmsFunc)(cudaStream_t, const int, const int, const int, const int, - const float, const bool, const bool, void *, void *, void *, - void *, void *, bool); - -struct nmsLaunchConfigSSD { - DataType t_score; - DataType t_bbox; - nmsFunc function; - - nmsLaunchConfigSSD(DataType t_score, DataType t_bbox) : t_score(t_score), t_bbox(t_bbox) {} - nmsLaunchConfigSSD(DataType t_score, DataType t_bbox, nmsFunc function) - : t_score(t_score), t_bbox(t_bbox), function(function) {} - bool 
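// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// The P(tsize) table above holds one kernel instantiation per TSIZE, the
// number of candidate boxes each thread owns. Host-side restatement of the
// selection arithmetic (BS = 512, as defined at the top of this file):
#include <cassert>

inline int select_tsize_sketch(int top_k, int block_size = 512)
{
    const int t_size = (top_k + block_size - 1) / block_size;  // ceil(top_k / BS)
    assert(t_size >= 1 && t_size <= 10);  // the table only holds P(1)..P(10)
    return t_size;  // kernel[t_size - 1] is then launched with num_classes blocks
}
// ---------------------------------------------------------------------------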
operator==(const nmsLaunchConfigSSD &other) { - return t_score == other.t_score && t_bbox == other.t_bbox; - } +typedef pluginStatus_t (*nmsFunc)(cudaStream_t, + const int, + const int, + const int, + const int, + const float, + const bool, + const bool, + void*, + void*, + void*, + void*, + void*, + bool); + +struct nmsLaunchConfigSSD +{ + DataType t_score; + DataType t_bbox; + nmsFunc function; + + nmsLaunchConfigSSD(DataType t_score, DataType t_bbox) + : t_score(t_score) + , t_bbox(t_bbox) + { + } + nmsLaunchConfigSSD(DataType t_score, DataType t_bbox, nmsFunc function) + : t_score(t_score) + , t_bbox(t_bbox) + , function(function) + { + } + bool operator==(const nmsLaunchConfigSSD& other) + { + return t_score == other.t_score && t_bbox == other.t_bbox; + } }; static std::vector nmsFuncVec; -bool nmsInit() { - nmsFuncVec.push_back( - nmsLaunchConfigSSD(DataType::kFLOAT, DataType::kFLOAT, allClassNMS_gpu)); - return true; +bool nmsInit() +{ + nmsFuncVec.push_back( + nmsLaunchConfigSSD(DataType::kFLOAT, DataType::kFLOAT, allClassNMS_gpu)); + return true; } -static bool initialized = nmsInit(); - -pluginStatus_t allClassNMS(cudaStream_t stream, const int num, const int num_classes, - const int num_preds_per_class, const int top_k, - const float nms_threshold, const bool share_location, - const bool isNormalized, const DataType DT_SCORE, const DataType DT_BBOX, - void *bbox_data, void *beforeNMS_scores, void *beforeNMS_index_array, - void *afterNMS_scores, void *afterNMS_index_array, bool flipXY) { - nmsLaunchConfigSSD lc(DT_SCORE, DT_BBOX); - for (unsigned i = 0; i < nmsFuncVec.size(); ++i) { - if (lc == nmsFuncVec[i]) { - DEBUG_PRINTF("all class nms kernel %d\n", i); - return nmsFuncVec[i].function(stream, num, num_classes, num_preds_per_class, top_k, - nms_threshold, share_location, isNormalized, bbox_data, - beforeNMS_scores, beforeNMS_index_array, afterNMS_scores, - afterNMS_index_array, flipXY); +static bool initialized = nmsInit(); + +pluginStatus_t allClassNMS(cudaStream_t stream, + const int num, + const int num_classes, + const int num_preds_per_class, + const int top_k, + const float nms_threshold, + const bool share_location, + const bool isNormalized, + const DataType DT_SCORE, + const DataType DT_BBOX, + void* bbox_data, + void* beforeNMS_scores, + void* beforeNMS_index_array, + void* afterNMS_scores, + void* afterNMS_index_array, + bool flipXY) +{ + nmsLaunchConfigSSD lc(DT_SCORE, DT_BBOX); + for (unsigned i = 0; i < nmsFuncVec.size(); ++i) + { + if (lc == nmsFuncVec[i]) + { + DEBUG_PRINTF("all class nms kernel %d\n", i); + return nmsFuncVec[i].function(stream, + num, + num_classes, + num_preds_per_class, + top_k, + nms_threshold, + share_location, + isNormalized, + bbox_data, + beforeNMS_scores, + beforeNMS_index_array, + afterNMS_scores, + afterNMS_index_array, + flipXY); + } } - } - return STATUS_BAD_PARAM; + return STATUS_BAD_PARAM; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu index 0edea2bfaf..e8c1cd2187 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu @@ -6,490 +6,636 @@ #include "nms/kernel.h" -template -struct RotatedBox { - T x_ctr, y_ctr, w, h, a; +template +struct RotatedBox +{ + T x_ctr, y_ctr, w, h, a; }; -template -struct Point { - T x, y; - __host__ __device__ __forceinline__ Point(const T &px = 0, const T &py = 0) : x(px), y(py) 
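// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// nmsFuncVec / nmsInit / `static bool initialized` above form a small
// load-time registry: the static initializer forces registration before
// main(), and the dispatcher linearly scans for a matching (score, bbox)
// dtype pair. Standalone sketch of the same pattern, all names hypothetical:
#include <vector>

using DemoFn = int (*)(int);
struct DemoEntry { int key; DemoFn fn; };
static std::vector<DemoEntry> g_demo_registry;

static bool demo_register()
{
    g_demo_registry.push_back({0, [](int x) { return x + 1; }});  // key 0 ~ (kFLOAT, kFLOAT)
    return true;
}
static bool g_demo_registered = demo_register();  // runs during static init

int demo_dispatch(int key, int arg)
{
    for (const DemoEntry& e : g_demo_registry)
        if (e.key == key) return e.fn(arg);
    return -1;  // mirrors returning STATUS_BAD_PARAM when no entry matches
}
// ---------------------------------------------------------------------------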
{} - __host__ __device__ __forceinline__ Point operator+(const Point &p) const { - return Point(x + p.x, y + p.y); - } - __host__ __device__ __forceinline__ Point &operator+=(const Point &p) { - x += p.x; - y += p.y; - return *this; - } - __host__ __device__ __forceinline__ Point operator-(const Point &p) const { - return Point(x - p.x, y - p.y); - } - __host__ __device__ __forceinline__ Point operator*(const T coeff) const { - return Point(x * coeff, y * coeff); - } +template +struct Point +{ + T x, y; + __host__ __device__ __forceinline__ Point(const T& px = 0, const T& py = 0) + : x(px) + , y(py) + { + } + + __host__ __device__ __forceinline__ Point operator+(const Point& p) const + { + return Point(x + p.x, y + p.y); + } + + __host__ __device__ __forceinline__ Point& operator+=(const Point& p) + { + x += p.x; + y += p.y; + return *this; + } + + __host__ __device__ __forceinline__ Point operator-(const Point& p) const + { + return Point(x - p.x, y - p.y); + } + + __host__ __device__ __forceinline__ Point operator*(const T coeff) const + { + return Point(x * coeff, y * coeff); + } }; -template -__host__ __device__ __forceinline__ T dot_2d(const Point &A, const Point &B) { - return A.x * B.x + A.y * B.y; +template +__host__ __device__ __forceinline__ T dot_2d(const Point& A, const Point& B) +{ + return A.x * B.x + A.y * B.y; } -template -__host__ __device__ __forceinline__ T cross_2d(const Point &A, const Point &B) { - return A.x * B.y - B.x * A.y; +template +__host__ __device__ __forceinline__ T cross_2d(const Point& A, const Point& B) +{ + return A.x * B.y - B.x * A.y; } -template -__host__ __device__ __forceinline__ void get_rotated_vertices(const RotatedBox &box, - Point (&pts)[4]) { - // M_PI / 180. == 0.01745329251 - // double theta = box.a * 0.01745329251; - // MODIFIED - double theta = box.a; - T cosTheta2 = (T)cos(theta) * 0.5f; - T sinTheta2 = (T)sin(theta) * 0.5f; - - // y: top --> down; x: left --> right - pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; - pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; - pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; - pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; - pts[2].x = 2 * box.x_ctr - pts[0].x; - pts[2].y = 2 * box.y_ctr - pts[0].y; - pts[3].x = 2 * box.x_ctr - pts[1].x; - pts[3].y = 2 * box.y_ctr - pts[1].y; +template +__host__ __device__ __forceinline__ void get_rotated_vertices(const RotatedBox& box, + Point (&pts)[4]) +{ + // M_PI / 180. 
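// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// dot_2d / cross_2d above are the 2-D primitives the rotated-IoU code is
// built on; the sign of the scalar cross product encodes orientation, which
// both the intersection test and the convex hull below depend on:
constexpr float cross2_demo(float ax, float ay, float bx, float by)
{
    return ax * by - bx * ay;  // same convention as cross_2d(A, B)
}
static_assert(cross2_demo(1.f, 0.f, 0.f, 1.f) > 0.f,
              "(0,1) lies counter-clockwise from (1,0): positive cross product");
static_assert(cross2_demo(0.f, 1.f, 1.f, 0.f) < 0.f,
              "swapping the operands flips the sign (clockwise)");
// ---------------------------------------------------------------------------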
== 0.01745329251 + // double theta = box.a * 0.01745329251; + // MODIFIED + double theta = box.a; + T cosTheta2 = (T)cos(theta) * 0.5f; + T sinTheta2 = (T)sin(theta) * 0.5f; + + // y: top --> down; x: left --> right + pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; + pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; + pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; + pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; + pts[2].x = 2 * box.x_ctr - pts[0].x; + pts[2].y = 2 * box.y_ctr - pts[0].y; + pts[3].x = 2 * box.x_ctr - pts[1].x; + pts[3].y = 2 * box.y_ctr - pts[1].y; } -template +template __host__ __device__ __forceinline__ int get_intersection_points(const Point (&pts1)[4], const Point (&pts2)[4], - Point (&intersections)[24]) { - // Line vector - // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] - Point vec1[4], vec2[4]; - for (int i = 0; i < 4; i++) { - vec1[i] = pts1[(i + 1) % 4] - pts1[i]; - vec2[i] = pts2[(i + 1) % 4] - pts2[i]; - } - - // Line test - test all line combos for intersection - int num = 0; // number of intersections - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // Solve for 2x2 Ax=b - T det = cross_2d(vec2[j], vec1[i]); - - // This takes care of parallel lines - if (fabs(det) <= 1e-14) { - continue; - } - - auto vec12 = pts2[j] - pts1[i]; - - T t1 = cross_2d(vec2[j], vec12) / det; - T t2 = cross_2d(vec1[i], vec12) / det; - - if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) { - intersections[num++] = pts1[i] + vec1[i] * t1; - } + Point (&intersections)[24]) +{ + // Line vector + // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] + Point vec1[4], vec2[4]; + for (int i = 0; i < 4; i++) + { + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; } - } - - // Check for vertices of rect1 inside rect2 - { - const auto &AB = vec2[0]; - const auto &DA = vec2[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - // assume ABCD is the rectangle, and P is the point to be judged - // P is inside ABCD iff. 
P's projection on AB lies within AB - // and P's projection on AD lies within AD - - auto AP = pts1[i] - pts2[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts1[i]; - } + + // Line test - test all line combos for intersection + int num = 0; // number of intersections + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + // Solve for 2x2 Ax=b + T det = cross_2d(vec2[j], vec1[i]); + + // This takes care of parallel lines + if (fabs(det) <= 1e-14) + { + continue; + } + + auto vec12 = pts2[j] - pts1[i]; + + T t1 = cross_2d(vec2[j], vec12) / det; + T t2 = cross_2d(vec1[i], vec12) / det; + + if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) + { + intersections[num++] = pts1[i] + vec1[i] * t1; + } + } } - } - - // Reverse the check - check for vertices of rect2 inside rect1 - { - const auto &AB = vec1[0]; - const auto &DA = vec1[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - auto AP = pts2[i] - pts1[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts2[i]; - } + + // Check for vertices of rect1 inside rect2 + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. P's projection on AB lies within AB + // and P's projection on AD lies within AD + + auto AP = pts1[i] - pts2[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts1[i]; + } + } } - } - return num; + // Reverse the check - check for vertices of rect2 inside rect1 + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + auto AP = pts2[i] - pts1[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts2[i]; + } + } + } + + return num; } -template +template __host__ __device__ __forceinline__ int convex_hull_graham(const Point (&p)[24], - const int &num_in, Point (&q)[24], - bool shift_to_zero = false) { - assert(num_in >= 2); - - // Step 1: - // Find point with minimum y - // if more than 1 points have the same minimum y, - // pick the one with the minimum x. - int t = 0; - for (int i = 1; i < num_in; i++) { - if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) { - t = i; + const int& num_in, + Point (&q)[24], + bool shift_to_zero = false) +{ + assert(num_in >= 2); + + // Step 1: + // Find point with minimum y + // if more than 1 points have the same minimum y, + // pick the one with the minimum x. 
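// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// Compact restatement of the segment-intersection solve inside
// get_intersection_points above: p1 + v1*t1 == p2 + v2*t2 is solved with
// 2-D cross products, and a hit counts only when both parameters fall in
// [0, 1]. The type and names below are ours:
struct P2Demo { float x, y; };

inline bool segment_intersect_demo(P2Demo p1, P2Demo v1, P2Demo p2, P2Demo v2, P2Demo* hit)
{
    const float det = v2.x * v1.y - v1.x * v2.y;        // cross_2d(vec2, vec1)
    if (det <= 1e-14f && det >= -1e-14f) return false;  // (near-)parallel segments
    const P2Demo d{p2.x - p1.x, p2.y - p1.y};           // vec12 = pts2[j] - pts1[i]
    const float t1 = (v2.x * d.y - d.x * v2.y) / det;   // cross_2d(vec2, vec12) / det
    const float t2 = (v1.x * d.y - d.x * v1.y) / det;   // cross_2d(vec1, vec12) / det
    if (t1 < 0.f || t1 > 1.f || t2 < 0.f || t2 > 1.f) return false;
    *hit = P2Demo{p1.x + v1.x * t1, p1.y + v1.y * t1};  // point on segment 1
    return true;
}
// ---------------------------------------------------------------------------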
+ int t = 0; + for (int i = 1; i < num_in; i++) + { + if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) + { + t = i; + } } - } - auto &start = p[t]; // starting point - - // Step 2: - // Subtract starting point from every points (for sorting in the next step) - for (int i = 0; i < num_in; i++) { - q[i] = p[i] - start; - } - - // Swap the starting point to position 0 - auto tmp = q[0]; - q[0] = q[t]; - q[t] = tmp; - - // Step 3: - // Sort point 1 ~ num_in according to their relative cross-product values - // (essentially sorting according to angles) - // If the angles are the same, sort according to their distance to origin - T dist[24]; - for (int i = 0; i < num_in; i++) { - dist[i] = dot_2d(q[i], q[i]); - } - - for (int i = 1; i < num_in - 1; i++) { - for (int j = i + 1; j < num_in; j++) { - T crossProduct = cross_2d(q[i], q[j]); - if ((crossProduct < -1e-6) || (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) { - auto q_tmp = q[i]; - q[i] = q[j]; - q[j] = q_tmp; - auto dist_tmp = dist[i]; - dist[i] = dist[j]; - dist[j] = dist_tmp; - } + auto& start = p[t]; // starting point + + // Step 2: + // Subtract starting point from every points (for sorting in the next step) + for (int i = 0; i < num_in; i++) + { + q[i] = p[i] - start; } - } - - // Step 4: - // Make sure there are at least 2 points (that don't overlap with each other) - // in the stack - int k; // index of the non-overlapped second point - for (k = 1; k < num_in; k++) { - if (dist[k] > 1e-8) { - break; + + // Swap the starting point to position 0 + auto tmp = q[0]; + q[0] = q[t]; + q[t] = tmp; + + // Step 3: + // Sort point 1 ~ num_in according to their relative cross-product values + // (essentially sorting according to angles) + // If the angles are the same, sort according to their distance to origin + T dist[24]; + for (int i = 0; i < num_in; i++) + { + dist[i] = dot_2d(q[i], q[i]); } - } - if (k == num_in) { - // We reach the end, which means the convex hull is just one point - q[0] = p[t]; - return 1; - } - q[1] = q[k]; - int m = 2; // 2 points in the stack - // Step 5: - // Finally we can start the scanning process. - // When a non-convex relationship between the 3 points is found - // (either concave shape or duplicated points), - // we pop the previous point from the stack - // until the 3-point relationship is convex again, or - // until the stack only contains two points - for (int i = k + 1; i < num_in; i++) { - while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) { - m--; + + for (int i = 1; i < num_in - 1; i++) + { + for (int j = i + 1; j < num_in; j++) + { + T crossProduct = cross_2d(q[i], q[j]); + if ((crossProduct < -1e-6) || (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) + { + auto q_tmp = q[i]; + q[i] = q[j]; + q[j] = q_tmp; + auto dist_tmp = dist[i]; + dist[i] = dist[j]; + dist[j] = dist_tmp; + } + } } - q[m++] = q[i]; - } - - // Step 6 (Optional): - // In general sense we need the original coordinates, so we - // need to shift the points back (reverting Step 2) - // But if we're only interested in getting the area/perimeter of the shape - // We can simply return. 
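// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// The scan loop above is the classic Graham invariant: with q sorted by
// angle around the pivot, each candidate q[i] pops the stack while the
// triple (q[m-2], q[m-1], q[i]) fails to make a strict turn in the hull's
// winding direction. Because the pop condition uses ">= 0" rather than
// "> 0", collinear and duplicated points are discarded as well, so the
// returned hull carries no degenerate edges.
// ---------------------------------------------------------------------------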
- if (!shift_to_zero) { - for (int i = 0; i < m; i++) { - q[i] += start; + + // Step 4: + // Make sure there are at least 2 points (that don't overlap with each other) + // in the stack + int k; // index of the non-overlapped second point + for (k = 1; k < num_in; k++) + { + if (dist[k] > 1e-8) + { + break; + } + } + if (k == num_in) + { + // We reach the end, which means the convex hull is just one point + q[0] = p[t]; + return 1; + } + q[1] = q[k]; + int m = 2; // 2 points in the stack + // Step 5: + // Finally we can start the scanning process. + // When a non-convex relationship between the 3 points is found + // (either concave shape or duplicated points), + // we pop the previous point from the stack + // until the 3-point relationship is convex again, or + // until the stack only contains two points + for (int i = k + 1; i < num_in; i++) + { + while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) + { + m--; + } + q[m++] = q[i]; } - } - return m; + // Step 6 (Optional): + // In general sense we need the original coordinates, so we + // need to shift the points back (reverting Step 2) + // But if we're only interested in getting the area/perimeter of the shape + // We can simply return. + if (!shift_to_zero) + { + for (int i = 0; i < m; i++) + { + q[i] += start; + } + } + + return m; } -template -__host__ __device__ __forceinline__ T polygon_area(const Point (&q)[24], const int &m) { - if (m <= 2) { - return 0; - } +template +__host__ __device__ __forceinline__ T polygon_area(const Point (&q)[24], const int& m) +{ + if (m <= 2) + { + return 0; + } - T area = 0; - for (int i = 1; i < m - 1; i++) { - area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); - } + T area = 0; + for (int i = 1; i < m - 1; i++) + { + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); + } - return area / 2.0; + return area / 2.0; } -template -__host__ __device__ __forceinline__ T rotated_boxes_intersection(const RotatedBox &box1, - const RotatedBox &box2) { - // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned - // from rotated_rect_intersection_pts - Point intersectPts[24], orderedPts[24]; +template +__host__ __device__ __forceinline__ T rotated_boxes_intersection(const RotatedBox& box1, + const RotatedBox& box2) +{ + // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned + // from rotated_rect_intersection_pts + Point intersectPts[24], orderedPts[24]; - Point pts1[4]; - Point pts2[4]; - get_rotated_vertices(box1, pts1); - get_rotated_vertices(box2, pts2); + Point pts1[4]; + Point pts2[4]; + get_rotated_vertices(box1, pts1); + get_rotated_vertices(box2, pts2); - int num = get_intersection_points(pts1, pts2, intersectPts); + int num = get_intersection_points(pts1, pts2, intersectPts); - if (num <= 2) { - return 0.0; - } + if (num <= 2) + { + return 0.0; + } - // Convex Hull to order the intersection points in clockwise order and find - // the contour area. - int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); - return polygon_area(orderedPts, num_convex); + // Convex Hull to order the intersection points in clockwise order and find + // the contour area. 
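// ---------------------------------------------------------------------------
// [Editor's note -- illustrative aside, not part of the patch]
// polygon_area above is the triangle-fan form of the shoelace formula over
// an already-ordered convex vertex list. Scalar restatement, reusing the
// hypothetical P2Demo from the earlier sketch:
#include <cmath>

inline float polygon_area_demo(const P2Demo* q, int m)
{
    if (m <= 2) return 0.f;  // a degenerate hull has no area
    float area = 0.f;
    for (int i = 1; i < m - 1; ++i)
    {
        // |cross| of the fan triangle (q0, qi, qi+1) is twice its area
        const float ax = q[i].x - q[0].x,     ay = q[i].y - q[0].y;
        const float bx = q[i + 1].x - q[0].x, by = q[i + 1].y - q[0].y;
        area += std::fabs(ax * by - bx * ay);
    }
    return area / 2.f;
}
// ---------------------------------------------------------------------------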
+ int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); + return polygon_area(orderedPts, num_convex); } -template -__host__ __device__ __forceinline__ T single_box_iou_rotated(T const *const box1_raw, - T const *const box2_raw) { - // shift center to the middle point to achieve higher precision in result - RotatedBox box1, box2; - auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; - auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0; - box1.x_ctr = box1_raw[0] - center_shift_x; - box1.y_ctr = box1_raw[1] - center_shift_y; - box1.w = box1_raw[2]; - box1.h = box1_raw[3]; - box1.a = box1_raw[4]; - box2.x_ctr = box2_raw[0] - center_shift_x; - box2.y_ctr = box2_raw[1] - center_shift_y; - box2.w = box2_raw[2]; - box2.h = box2_raw[3]; - box2.a = box2_raw[4]; - - const T area1 = box1.w * box1.h; - const T area2 = box2.w * box2.h; - if (area1 < 1e-14 || area2 < 1e-14) { - return 1.0f; - } - - const T intersection = rotated_boxes_intersection(box1, box2); - T baseS = 1.0; - baseS = (area1 + area2 - intersection); - const T iou = intersection / baseS; - return iou; +template +__host__ __device__ __forceinline__ T single_box_iou_rotated(T const* const box1_raw, + T const* const box2_raw) +{ + // shift center to the middle point to achieve higher precision in result + RotatedBox box1, box2; + auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; + auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0; + box1.x_ctr = box1_raw[0] - center_shift_x; + box1.y_ctr = box1_raw[1] - center_shift_y; + box1.w = box1_raw[2]; + box1.h = box1_raw[3]; + box1.a = box1_raw[4]; + box2.x_ctr = box2_raw[0] - center_shift_x; + box2.y_ctr = box2_raw[1] - center_shift_y; + box2.w = box2_raw[2]; + box2.h = box2_raw[3]; + box2.a = box2_raw[4]; + + const T area1 = box1.w * box1.h; + const T area2 = box2.w * box2.h; + if (area1 < 1e-14 || area2 < 1e-14) + { + return 1.0f; + } + + const T intersection = rotated_boxes_intersection(box1, box2); + T baseS = 1.0; + baseS = (area1 + area2 - intersection); + const T iou = intersection / baseS; + return iou; } /********** new NMS for only score and index array **********/ -template -__global__ void allClassRotatedNMS_kernel(const int num, const int num_classes, - const int num_preds_per_class, const int top_k, - const float nms_threshold, const bool share_location, - const bool isNormalized, - T_BBOX *bbox_data, // bbox_data should be float to - // preserve location information - T_SCORE *beforeNMS_scores, int *beforeNMS_index_array, - T_SCORE *afterNMS_scores, int *afterNMS_index_array) { - //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; - extern __shared__ bool kept_bboxinfo_flag[]; - for (int i = 0; i < num; i++) { - const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; - const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation - const int bbox_idx_offset = - share_location ? 
(i * num_preds_per_class) : (i * num_classes * num_preds_per_class); - - // local thread data - int loc_bboxIndex[TSIZE]; - T_BBOX loc_bbox[TSIZE * 5]; - - // initialize Bbox, Bboxinfo, kept_bboxinfo_flag - // Eliminate shared memory RAW hazard - __syncthreads(); +template +__global__ void allClassRotatedNMS_kernel(const int num, + const int num_classes, + const int num_preds_per_class, + const int top_k, + const float nms_threshold, + const bool share_location, + const bool isNormalized, + T_BBOX* bbox_data, // bbox_data should be float to preserve location information + T_SCORE* beforeNMS_scores, + int* beforeNMS_index_array, + T_SCORE* afterNMS_scores, + int* afterNMS_index_array) +{ + //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; + extern __shared__ bool kept_bboxinfo_flag[]; + for (int i = 0; i < num; i++) + { + const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; + const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation + const int bbox_idx_offset = + share_location ? (i * num_preds_per_class) : (i * num_classes * num_preds_per_class); + + // local thread data + int loc_bboxIndex[TSIZE]; + T_BBOX loc_bbox[TSIZE * 5]; + + // initialize Bbox, Bboxinfo, kept_bboxinfo_flag + // Eliminate shared memory RAW hazard + __syncthreads(); #pragma unroll - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int item_idx = offset + cur_idx; + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if (item_idx < max_idx) + { + loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + + if (loc_bboxIndex[t] >= 0) + // if (loc_bboxIndex[t] != -1) + { + const int bbox_data_idx = share_location ? + (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) : + loc_bboxIndex[t]; + memcpy(&loc_bbox[t * 5], &bbox_data[bbox_data_idx * 5], 5 * sizeof(T_BBOX)); + kept_bboxinfo_flag[cur_idx] = true; + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } - if (item_idx < max_idx) { - loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + // filter out overlapped boxes with lower scores + int ref_item_idx = offset; + int ref_bbox_idx = share_location ? + (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) : + beforeNMS_index_array[ref_item_idx]; - if (loc_bboxIndex[t] >= 0) - // if (loc_bboxIndex[t] != -1) + while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) { - const int bbox_data_idx = share_location - ? 
(loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset)
-                                       : loc_bboxIndex[t];
-          memcpy(&loc_bbox[t * 5], &bbox_data[bbox_data_idx * 5], 5 * sizeof(T_BBOX));
-          kept_bboxinfo_flag[cur_idx] = true;
-        } else {
-          kept_bboxinfo_flag[cur_idx] = false;
+            T_BBOX ref_bbox[5];
+            memcpy(&ref_bbox[0], &bbox_data[ref_bbox_idx * 5], 5 * sizeof(T_BBOX));
+
+            // Eliminate shared memory RAW hazard
+            __syncthreads();
+
+            for (int t = 0; t < TSIZE; t++)
+            {
+                const int cur_idx = threadIdx.x + blockDim.x * t;
+                const int item_idx = offset + cur_idx;
+
+                if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx))
+                {
+                    // TODO: may need to add bool normalized as argument, HERE true means
+                    // normalized
+                    if (single_box_iou_rotated(&ref_bbox[0], loc_bbox + t * 5) > nms_threshold)
+                    {
+                        kept_bboxinfo_flag[cur_idx] = false;
+                    }
+                }
+            }
+            __syncthreads();
+
+            do {
+                ref_item_idx++;
+            } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]);
+
+            ref_bbox_idx = share_location ?
+                               (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) :
+                               beforeNMS_index_array[ref_item_idx];
         }
-      } else {
-        kept_bboxinfo_flag[cur_idx] = false;
-      }
-    }
-    // filter out overlapped boxes with lower scores
-    int ref_item_idx = offset;
-    int ref_bbox_idx =
-        share_location
-            ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset)
-            : beforeNMS_index_array[ref_item_idx];
-
-    while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) {
-      T_BBOX ref_bbox[5];
-      memcpy(&ref_bbox[0], &bbox_data[ref_bbox_idx * 5], 5 * sizeof(T_BBOX));
-
-      // Eliminate shared memory RAW hazard
-      __syncthreads();
-
-      for (int t = 0; t < TSIZE; t++) {
-        const int cur_idx = threadIdx.x + blockDim.x * t;
-        const int item_idx = offset + cur_idx;
-
-        if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) {
-          // TODO: may need to add bool normalized as argument, HERE true means
-          // normalized
-          if (single_box_iou_rotated(&ref_bbox[0], loc_bbox + t * 5) > nms_threshold) {
-            kept_bboxinfo_flag[cur_idx] = false;
-          }
+        // store data
+        for (int t = 0; t < TSIZE; t++)
+        {
+            const int cur_idx = threadIdx.x + blockDim.x * t;
+            const int read_item_idx = offset + cur_idx;
+            const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx;
+            /*
+             * If not keeping the bbox
+             * Set the score to 0
+             * Set the bounding box index to -1
+             */
+            if (read_item_idx < max_idx)
+            {
+                afterNMS_scores[write_item_idx] = kept_bboxinfo_flag[cur_idx] ?
+                                                      beforeNMS_scores[read_item_idx] :
+                                                      0.0f;
+                afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1;
+            }
+        }
     }
-      }
-      __syncthreads();
-
-      do {
-        ref_item_idx++;
-      } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]);
-
-      ref_bbox_idx =
-          share_location
-              ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset)
-              : beforeNMS_index_array[ref_item_idx];
-    }
-
-    // store data
-    for (int t = 0; t < TSIZE; t++) {
-      const int cur_idx = threadIdx.x + blockDim.x * t;
-      const int read_item_idx = offset + cur_idx;
-      const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx;
-      /*
-       * If not not keeping the bbox
-       * Set the score to 0
-       * Set the bounding box index to -1
-       */
-      if (read_item_idx < max_idx) {
-        afterNMS_scores[write_item_idx] =
-            kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f;
-        afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1;
-      }
-    }
-  }
-  }
 }
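Each thread keeps up to TSIZE candidate boxes in registers (loc_bboxIndex, loc_bbox) while the block shares a single kept_bboxinfo_flag array, so one block of BS threads covers BS * TSIZE = top_k candidates of a class. The launcher below instantiates the kernel for TSIZE = 1..10 and selects the variant at run time from top_k; a self-contained sketch of that arithmetic:

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    const int BS = 512;  // threads per block, as in allClassRotatedNMS_gpu below
    for (int top_k : {256, 512, 1000, 4096, 5120})
    {
        const int t_size = (top_k + BS - 1) / BS;  // boxes per thread, rounded up
        assert(t_size <= 10);                      // only P(1)..P(10) are instantiated
        std::printf("top_k=%4d -> TSIZE=%d\n", top_k, t_size);
    }
}
```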
-template <typename T_SCORE, typename T_BBOX>
-pluginStatus_t allClassRotatedNMS_gpu(cudaStream_t stream, const int num, const int num_classes,
-                                      const int num_preds_per_class, const int top_k,
-                                      const float nms_threshold, const bool share_location,
-                                      const bool isNormalized, void *bbox_data,
-                                      void *beforeNMS_scores, void *beforeNMS_index_array,
-                                      void *afterNMS_scores, void *afterNMS_index_array) {
+template <typename T_SCORE, typename T_BBOX>
+pluginStatus_t allClassRotatedNMS_gpu(cudaStream_t stream,
+                                      const int num,
+                                      const int num_classes,
+                                      const int num_preds_per_class,
+                                      const int top_k,
+                                      const float nms_threshold,
+                                      const bool share_location,
+                                      const bool isNormalized,
+                                      void* bbox_data,
+                                      void* beforeNMS_scores,
+                                      void* beforeNMS_index_array,
+                                      void* afterNMS_scores,
+                                      void* afterNMS_index_array)
+{
 #define P(tsize) allClassRotatedNMS_kernel<T_SCORE, T_BBOX, (tsize)>
-  void (*kernel[10])(const int, const int, const int, const int, const float, const bool,
-                     const bool, float *, T_SCORE *, int *, T_SCORE *, int *) = {
-      P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10),
-  };
-
-  const int BS = 512;
-  const int GS = num_classes;
-  const int t_size = (top_k + BS - 1) / BS;
-
-  ASSERT(t_size <= 10);
-  kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
-      num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized,
-      (T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
-      (T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array);
-
-  CSC(cudaGetLastError(), STATUS_FAILURE);
-  return STATUS_SUCCESS;
+    void (*kernel[10])(const int,
+                       const int,
+                       const int,
+                       const int,
+                       const float,
+                       const bool,
+                       const bool,
+                       float*,
+                       T_SCORE*,
+                       int*,
+                       T_SCORE*,
+                       int*) = {
+        P(1),
+        P(2),
+        P(3),
+        P(4),
+        P(5),
+        P(6),
+        P(7),
+        P(8),
+        P(9),
+        P(10),
+    };
+
+    const int BS = 512;
+    const int GS = num_classes;
+    const int t_size = (top_k + BS - 1) / BS;
+
+    ASSERT(t_size <= 10);
+    kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(num,
+                                                                       num_classes,
+                                                                       num_preds_per_class,
+                                                                       top_k,
+                                                                       nms_threshold,
+                                                                       share_location,
+                                                                       isNormalized,
+                                                                       (T_BBOX*)bbox_data,
+                                                                       (T_SCORE*)beforeNMS_scores,
+                                                                       (int*)beforeNMS_index_array,
+                                                                       (T_SCORE*)afterNMS_scores,
+                                                                       (int*)afterNMS_index_array);
+
+    CSC(cudaGetLastError(), STATUS_FAILURE);
+    return STATUS_SUCCESS;
 }

 // allClassNMS LAUNCH CONFIG
-typedef pluginStatus_t (*rotatedNmsFunc)(cudaStream_t, const int, const int, const int, const int,
-                                         const float, const bool, const bool, void *, void *,
-                                         void *, void *, void *);
-
-struct rotatedNmsLaunchConfig {
-  DataType t_score;
-  DataType t_bbox;
-  rotatedNmsFunc function;
-
-  rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox) : t_score(t_score), t_bbox(t_bbox) {}
-  rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox, rotatedNmsFunc function)
-      : t_score(t_score), t_bbox(t_bbox), function(function) {}
-  bool operator==(const rotatedNmsLaunchConfig &other) {
-    return t_score == other.t_score && t_bbox == other.t_bbox;
-  }
+typedef pluginStatus_t (*rotatedNmsFunc)(cudaStream_t,
+                                         const int,
+                                         const int,
+                                         const int,
+                                         const int,
+                                         const float,
+                                         const bool,
+                                         const bool,
+                                         void*,
+                                         void*,
+                                         void*,
+                                         void*,
+                                         void*);
+
+struct rotatedNmsLaunchConfig
+{
+    DataType t_score;
+    DataType t_bbox;
+    rotatedNmsFunc function;
+
+    rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox)
+        : t_score(t_score)
+        , t_bbox(t_bbox)
+    {
+    }
+    rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox, rotatedNmsFunc function)
+        : t_score(t_score)
+        , t_bbox(t_bbox)
+        , function(function)
+    {
+    }
+    bool operator==(const rotatedNmsLaunchConfig& other)
+    {
+        return t_score == other.t_score && t_bbox == other.t_bbox;
+    }
 };

 static std::vector<rotatedNmsLaunchConfig> rotatedNmsFuncVec;

-bool rotatedNmsInit() {
-  rotatedNmsFuncVec.push_back(rotatedNmsLaunchConfig(DataType::kFLOAT, DataType::kFLOAT,
-                                                     allClassRotatedNMS_gpu<float, float>));
-  return true;
+bool rotatedNmsInit()
+{
+    rotatedNmsFuncVec.push_back(rotatedNmsLaunchConfig(DataType::kFLOAT,
+                                                       DataType::kFLOAT,
+                                                       allClassRotatedNMS_gpu<float, float>));
+    return true;
 }

-static bool initialized = rotatedNmsInit();
-
-pluginStatus_t allClassRotatedNMS(cudaStream_t stream, const int num, const int num_classes,
-                                  const int num_preds_per_class, const int top_k,
-                                  const float nms_threshold, const bool share_location,
-                                  const bool isNormalized, const DataType DT_SCORE,
-                                  const DataType DT_BBOX, void *bbox_data, void *beforeNMS_scores,
-                                  void *beforeNMS_index_array, void *afterNMS_scores,
-                                  void *afterNMS_index_array, bool) {
-  auto __cuda_arch__ = get_cuda_arch(0);  // assume there is only one arch 7.2 device
-  if (__cuda_arch__ == 720 && top_k >= 1000) {
-    printf("Warning: pre_top_k need to be reduced for devices with arch 7.2, got pre_top_k=%d\n",
-           top_k);
-  }
-  rotatedNmsLaunchConfig lc(DT_SCORE, DT_BBOX);
-
-  for (unsigned i = 0; i < rotatedNmsFuncVec.size(); ++i) {
-    if (lc == rotatedNmsFuncVec[i]) {
-      DEBUG_PRINTF("all class rotated nms kernel %d\n", i);
-      return rotatedNmsFuncVec[i].function(stream, num, num_classes, num_preds_per_class, top_k,
-                                           nms_threshold, share_location, isNormalized, bbox_data,
-                                           beforeNMS_scores, beforeNMS_index_array, afterNMS_scores,
-                                           afterNMS_index_array);
+static bool initialized = rotatedNmsInit();
+
+pluginStatus_t allClassRotatedNMS(cudaStream_t stream,
+                                  const int num,
+                                  const int num_classes,
+                                  const int num_preds_per_class,
+                                  const int top_k,
+                                  const float nms_threshold,
+                                  const bool share_location,
+                                  const bool isNormalized,
+                                  const DataType DT_SCORE,
+                                  const DataType DT_BBOX,
+                                  void* bbox_data,
+                                  void* beforeNMS_scores,
+                                  void* beforeNMS_index_array,
+                                  void* afterNMS_scores,
+                                  void* afterNMS_index_array,
+                                  bool)
+{
+    auto __cuda_arch__ = get_cuda_arch(0);  // assume there is only one arch 7.2 device
+    if (__cuda_arch__ == 720 && top_k >= 1000)
+    {
+        printf("Warning: pre_top_k needs to be reduced for devices with arch 7.2, got pre_top_k=%d\n",
+               top_k);
+    }
+    rotatedNmsLaunchConfig lc(DT_SCORE, DT_BBOX);
+
+    for (unsigned i = 0; i < rotatedNmsFuncVec.size(); ++i)
+    {
+        if (lc == rotatedNmsFuncVec[i])
+        {
+            DEBUG_PRINTF("all class rotated nms kernel %d\n", i);
+            return rotatedNmsFuncVec[i].function(stream,
+                                                 num,
+                                                 num_classes,
+                                                 num_preds_per_class,
+                                                 top_k,
+                                                 nms_threshold,
+                                                 share_location,
+                                                 isNormalized,
+                                                 bbox_data,
+                                                 beforeNMS_scores,
+                                                 beforeNMS_index_array,
+                                                 afterNMS_scores,
+                                                 afterNMS_index_array);
+        }
     }
-  }
-  return STATUS_BAD_PARAM;
+    return STATUS_BAD_PARAM;
 }
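allClassRotatedNMS resolves its implementation through a small registry keyed by the score and box data types, and the same launch-config pattern reappears below for gatherNMSOutputs, permuteData and the two sort kernels. A stripped-down sketch of the pattern (all names here are illustrative, not part of the plugin API):

```cpp
#include <cstdio>
#include <vector>

enum class DataType { kFLOAT, kHALF };
typedef int (*launchFunc)(int);  // stand-in for the real launcher signature

struct LaunchConfig
{
    DataType t_score;
    launchFunc function;
    bool operator==(const LaunchConfig& other) const { return t_score == other.t_score; }
};

static std::vector<LaunchConfig> registry = {
    {DataType::kFLOAT, [](int n) { std::printf("fp32 kernel, n=%d\n", n); return 0; }},
};

int dispatch(DataType t, int n)
{
    LaunchConfig key{t, nullptr};
    for (const auto& cfg : registry)
        if (cfg == key) return cfg.function(n);
    return -1;  // plays the role of STATUS_BAD_PARAM
}

int main() { return dispatch(DataType::kFLOAT, 8); }
```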
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp
index 71cb7a8592..b5f880d87f 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp
@@ -3,123 +3,215 @@
 // https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
 #include "nms/batched_nms_kernel.hpp"

-pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatchBoxesSize,
-                            const int perBatchScoresSize, const bool shareLocation,
-                            const int backgroundLabelId, const int numPredsPerClass,
-                            const int numClasses, const int topK, const int keepTopK,
-                            const float scoreThreshold, const float iouThreshold,
-                            const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
-                            const void* confData, void* nmsedDets, void* nmsedLabels,
-                            void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid,
-                            bool clipBoxes, bool rotated) {
-  const int topKVal = topK < 0 ? numPredsPerClass : topK;
-  const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
-  // locCount = batch_size * number_boxes_per_sample * 4
-  const int locCount = N * perBatchBoxesSize;
-  /*
-   * shareLocation
-   * Bounding box are shared among all classes, i.e., a bounding box could be
-   * classified as any candidate class. Otherwise Bounding box are designed for
-   * specific classes, i.e., a bounding box could be classified as one certain
-   * class or not (binary classification).
-   */
-  const int numLocClasses = shareLocation ? 1 : numClasses;
-
-  size_t bboxDataSize = detectionForwardBBoxDataSize(N, perBatchBoxesSize, DataType::kFLOAT);
-  void* bboxDataRaw = workspace;
-  cudaMemcpyAsync(bboxDataRaw, locData, bboxDataSize, cudaMemcpyDeviceToDevice, stream);
-  pluginStatus_t status;
-
-  /*
-   * bboxDataRaw format:
-   * [batch size, numPriors (per sample), numLocClasses, 4]
-   */
-  // float for now
-  void* bboxData;
-  size_t bboxPermuteSize =
-      detectionForwardBBoxPermuteSize(shareLocation, N, perBatchBoxesSize, DataType::kFLOAT);
-  void* bboxPermute = nextWorkspacePtr((int8_t*)bboxDataRaw, bboxDataSize);
-
-  /*
-   * After permutation, bboxData format:
-   * [batch_size, numLocClasses, numPriors (per sample) (numPredsPerClass), 4]
-   * This is equivalent to swapping axis
-   */
-  if (!shareLocation) {
-    status = permuteData(stream, locCount, numLocClasses, numPredsPerClass, rotated ? 5 : 4,
-                         DataType::kFLOAT, false, bboxDataRaw, bboxPermute);
+pluginStatus_t nmsInference(cudaStream_t stream,
+                            const int N,
+                            const int perBatchBoxesSize,
+                            const int perBatchScoresSize,
+                            const bool shareLocation,
+                            const int backgroundLabelId,
+                            const int numPredsPerClass,
+                            const int numClasses,
+                            const int topK,
+                            const int keepTopK,
+                            const float scoreThreshold,
+                            const float iouThreshold,
+                            const DataType DT_BBOX,
+                            const void* locData,
+                            const DataType DT_SCORE,
+                            const void* confData,
+                            void* nmsedDets,
+                            void* nmsedLabels,
+                            void* nmsedIndex,
+                            void* workspace,
+                            bool isNormalized,
+                            bool confSigmoid,
+                            bool clipBoxes,
+                            bool rotated)
+{
+    const int topKVal = topK < 0 ? numPredsPerClass : topK;
+    const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
+    // locCount = batch_size * number_boxes_per_sample * 4
+    const int locCount = N * perBatchBoxesSize;
+    /*
+     * shareLocation
+     * Bounding boxes are shared among all classes, i.e., a bounding box could be
+     * classified as any candidate class. Otherwise, bounding boxes are designed for
+     * specific classes, i.e., a bounding box could be classified as one certain
+     * class or not (binary classification).
+     */
+    const int numLocClasses = shareLocation ? 1 : numClasses;
+
+    size_t bboxDataSize = detectionForwardBBoxDataSize(N, perBatchBoxesSize, DataType::kFLOAT);
+    void* bboxDataRaw = workspace;
+    cudaMemcpyAsync(bboxDataRaw, locData, bboxDataSize, cudaMemcpyDeviceToDevice, stream);
+    pluginStatus_t status;
+
+    /*
+     * bboxDataRaw format:
+     * [batch size, numPriors (per sample), numLocClasses, 4]
+     */
+    // float for now
+    void* bboxData;
+    size_t bboxPermuteSize = detectionForwardBBoxPermuteSize(shareLocation,
+                                                             N,
+                                                             perBatchBoxesSize,
+                                                             DataType::kFLOAT);
+    void* bboxPermute = nextWorkspacePtr((int8_t*)bboxDataRaw, bboxDataSize);
+
+    /*
+     * After permutation, bboxData format:
+     * [batch_size, numLocClasses, numPriors (per sample) (numPredsPerClass), 4]
+     * This is equivalent to swapping axis
+     */
+    if (!shareLocation)
+    {
+        status = permuteData(stream,
+                             locCount,
+                             numLocClasses,
+                             numPredsPerClass,
+                             rotated ? 5 : 4,
+                             DataType::kFLOAT,
+                             false,
+                             bboxDataRaw,
+                             bboxPermute);
+        ASSERT_FAILURE(status == STATUS_SUCCESS);
+        bboxData = bboxPermute;
+    }
+    /*
+     * If shareLocation, numLocClasses = 1
+     * No need to permute data on linear memory
+     */
+    else
+    {
+        bboxData = bboxDataRaw;
+    }
+
+    /*
+     * Conf data format
+     * [batch size, numPriors * param.numClasses, 1, 1]
+     */
+    const int numScores = N * perBatchScoresSize;
+    size_t totalScoresSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
+    void* scores = nextWorkspacePtr((int8_t*)bboxPermute, bboxPermuteSize);
+
+    // the conf scores need the same permutation
+    /*
+     * After permutation, score data format:
+     * [batch_size, numClasses, numPredsPerClass, 1]
+     */
+    status = permuteData(stream,
+                         numScores,
+                         numClasses,
+                         numPredsPerClass,
+                         1,
+                         DataType::kFLOAT,
+                         confSigmoid,
+                         confData,
+                         scores);
    ASSERT_FAILURE(status == STATUS_SUCCESS);
-    bboxData = bboxPermute;
-  }
-  /*
-   * If shareLocation, numLocClasses = 1
-   * No need to permute data on linear memory
-   */
-  else {
-    bboxData = bboxDataRaw;
-  }
-
-  /*
-   * Conf data format
-   * [batch size, numPriors * param.numClasses, 1, 1]
-   */
-  const int numScores = N * perBatchScoresSize;
-  size_t totalScoresSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
-  void* scores = nextWorkspacePtr((int8_t*)bboxPermute, bboxPermuteSize);
-
-  // need a conf_scores
-  /*
-   * After permutation, bboxData format:
-   * [batch_size, numClasses, numPredsPerClass, 1]
-   */
-  status = permuteData(stream, numScores, numClasses, numPredsPerClass, 1, DataType::kFLOAT,
-                       confSigmoid, confData, scores);
-  ASSERT_FAILURE(status == STATUS_SUCCESS);
-
-  size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
-  void* indices = nextWorkspacePtr((int8_t*)scores, totalScoresSize);
-
-  size_t postNMSScoresSize = detectionForwardPostNMSSize(N, numClasses, topKVal);
-  size_t postNMSIndicesSize = detectionForwardPostNMSSize(N, numClasses, topKVal);
-  void* postNMSScores = nextWorkspacePtr((int8_t*)indices, indicesSize);
-  void* postNMSIndices = nextWorkspacePtr((int8_t*)postNMSScores, postNMSScoresSize);
-
-  void* sortingWorkspace = nextWorkspacePtr((int8_t*)postNMSIndices, postNMSIndicesSize);
-  // Sort the scores so that the following NMS could be applied.
-
-  status = sortScoresPerClass(stream, N, numClasses, numPredsPerClass, backgroundLabelId,
-                              scoreThreshold, DataType::kFLOAT, scores, indices, sortingWorkspace);
-  ASSERT_FAILURE(status == STATUS_SUCCESS);
-
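Every buffer in this function (bboxPermute, scores, and the indices, postNMS and sorting regions) is carved out of a single pre-sized workspace allocation by nextWorkspacePtr, which advances past the previous region and rounds up to the 256-byte CUDA_MEM_ALIGN boundary defined in kernel.cu further down in this diff. A minimal sketch of that pointer arithmetic:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors alignPtr/nextWorkspacePtr from kernel.cu below.
static int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
    uintptr_t addr = (uintptr_t)ptr;
    if (addr % to) addr += to - addr % to;
    return (int8_t*)addr;
}

static int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t prevSize)
{
    return alignPtr(ptr + prevSize, 256);  // 256 == CUDA_MEM_ALIGN
}

int main()
{
    alignas(256) int8_t workspace[4096];
    int8_t* scores = workspace;                        // region 0: say 1000 bytes
    int8_t* indices = nextWorkspacePtr(scores, 1000);  // starts at offset 1024
    std::printf("indices offset = %td\n", indices - workspace);
}
```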
-  // This is set to true as the input bounding boxes are of the format [ymin,
-  // xmin, ymax, xmax]. The default implementation assumes [xmin, ymin, xmax,
-  // ymax]
-  bool flipXY = false;
-  // NMS
-  if (rotated) {
-    status = allClassRotatedNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold,
-                                shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT,
-                                bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY);
-  } else {
-    status = allClassNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold,
-                         shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT, bboxData,
-                         scores, indices, postNMSScores, postNMSIndices, flipXY);
-  }
-
-  ASSERT_FAILURE(status == STATUS_SUCCESS);
-
-  // Sort the bounding boxes after NMS using scores
-  status = sortScoresPerImage(stream, N, numClasses * topKVal, DataType::kFLOAT, postNMSScores,
-                              postNMSIndices, scores, indices, sortingWorkspace);
-
-  ASSERT_FAILURE(status == STATUS_SUCCESS);
-
-  // Gather data from the sorted bounding boxes after NMS
-  status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topKVal,
-                            keepTopKVal, DataType::kFLOAT, DataType::kFLOAT, indices, scores,
-                            bboxData, nmsedDets, nmsedLabels, nmsedIndex, clipBoxes, rotated);
-
-  ASSERT_FAILURE(status == STATUS_SUCCESS);
-
-  return STATUS_SUCCESS;
+    size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize);
+    void* indices = nextWorkspacePtr((int8_t*)scores, totalScoresSize);
+
+    size_t postNMSScoresSize = detectionForwardPostNMSSize(N, numClasses, topKVal);
+    size_t postNMSIndicesSize = detectionForwardPostNMSSize(N, numClasses, topKVal);
+    void* postNMSScores = nextWorkspacePtr((int8_t*)indices, indicesSize);
+    void* postNMSIndices = nextWorkspacePtr((int8_t*)postNMSScores, postNMSScoresSize);
+
+    void* sortingWorkspace = nextWorkspacePtr((int8_t*)postNMSIndices, postNMSIndicesSize);
+    // Sort the scores so that the following NMS could be applied.
+
+    status = sortScoresPerClass(stream,
+                                N,
+                                numClasses,
+                                numPredsPerClass,
+                                backgroundLabelId,
+                                scoreThreshold,
+                                DataType::kFLOAT,
+                                scores,
+                                indices,
+                                sortingWorkspace);
+    ASSERT_FAILURE(status == STATUS_SUCCESS);
+
+    // flipXY stays false here: the input bounding boxes are already in
+    // [xmin, ymin, xmax, ymax] order; it would only need to be true for
+    // inputs in [ymin, xmin, ymax, xmax] order.
+    bool flipXY = false;
+    // NMS
+    if (rotated)
+    {
+        status = allClassRotatedNMS(stream,
+                                    N,
+                                    numClasses,
+                                    numPredsPerClass,
+                                    topKVal,
+                                    iouThreshold,
+                                    shareLocation,
+                                    isNormalized,
+                                    DataType::kFLOAT,
+                                    DataType::kFLOAT,
+                                    bboxData,
+                                    scores,
+                                    indices,
+                                    postNMSScores,
+                                    postNMSIndices,
+                                    flipXY);
+    }
+    else
+    {
+        status = allClassNMS(stream,
+                             N,
+                             numClasses,
+                             numPredsPerClass,
+                             topKVal,
+                             iouThreshold,
+                             shareLocation,
+                             isNormalized,
+                             DataType::kFLOAT,
+                             DataType::kFLOAT,
+                             bboxData,
+                             scores,
+                             indices,
+                             postNMSScores,
+                             postNMSIndices,
+                             flipXY);
+    }
+
+    ASSERT_FAILURE(status == STATUS_SUCCESS);
+
+    // Sort the bounding boxes after NMS using scores
+    status = sortScoresPerImage(stream,
+                                N,
+                                numClasses * topKVal,
+                                DataType::kFLOAT,
+                                postNMSScores,
+                                postNMSIndices,
+                                scores,
+                                indices,
+                                sortingWorkspace);
+
+    ASSERT_FAILURE(status == STATUS_SUCCESS);
+
+    // Gather data from the sorted bounding boxes after NMS
+    status = gatherNMSOutputs(stream,
+                              shareLocation,
+                              N,
+                              numPredsPerClass,
+                              numClasses,
+                              topKVal,
+                              keepTopKVal,
+                              DataType::kFLOAT,
+                              DataType::kFLOAT,
+                              indices,
+                              scores,
+                              bboxData,
+                              nmsedDets,
+                              nmsedLabels,
+                              nmsedIndex,
+                              clipBoxes,
+                              rotated);
+
+    ASSERT_FAILURE(status == STATUS_SUCCESS);
+
+    return STATUS_SUCCESS;
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu
index 58419f8c16..803924a4ee 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu
@@ -6,159 +6,237 @@
 #include "nms/kernel.h"
 #include "trt_plugin_helper.hpp"

-template <typename T_BBOX, typename T_SCORE, unsigned nthds_per_cta, bool rotated>
-__launch_bounds__(nthds_per_cta) __global__
-    void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages,
-                                 const int numPredsPerClass, const int numClasses, const int topK,
-                                 const int keepTopK, const int *indices, const T_SCORE *scores,
-                                 const T_BBOX *bboxData, T_BBOX *nmsedDets, int *nmsedLabels,
-                                 int *nmsedIndex, bool clipBoxes) {
-  if (keepTopK > topK) return;
-  for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK;
-       i += gridDim.x * nthds_per_cta) {
-    const int imgId = i / keepTopK;
-    const int detId = i % keepTopK;
-    const int offset = imgId * numClasses * topK;
-    const int index = indices[offset + detId];
-    const T_SCORE score = scores[offset + detId];
-    if (index == -1) {
-      nmsedLabels[i] = -1;
-      if (nmsedIndex != nullptr) {
-        nmsedIndex[i] = -1;
-      }
-      if (rotated) {
-        nmsedDets[i * 6] = 0;
-        nmsedDets[i * 6 + 1] = 0;
-        nmsedDets[i * 6 + 2] = 0;
-        nmsedDets[i * 6 + 3] = 0;
-        nmsedDets[i * 6 + 4] = 0;
-        nmsedDets[i * 6 + 5] = 0;
-      } else {
-        nmsedDets[i * 5] = 0;
-        nmsedDets[i * 5 + 1] = 0;
-        nmsedDets[i * 5 + 2] = 0;
-        nmsedDets[i * 5 + 3] = 0;
-        nmsedDets[i * 5 + 4] = 0;
-      }
-    } else {
-      const int bboxOffset =
-          imgId * (shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass));
-      nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass;  // label
-      if (rotated) {
-        const int bboxId = ((shareLocation ?
(index % numPredsPerClass) - : index % (numClasses * numPredsPerClass)) + - bboxOffset) * - 5; - if (nmsedIndex != nullptr) { - nmsedIndex[i] = bboxId / 5 - bboxOffset; +template +__launch_bounds__(nthds_per_cta) __global__ void gatherNMSOutputs_kernel(const bool shareLocation, + const int numImages, + const int numPredsPerClass, + const int numClasses, + const int topK, + const int keepTopK, + const int* indices, + const T_SCORE* scores, + const T_BBOX* bboxData, + T_BBOX* nmsedDets, + int* nmsedLabels, + int* nmsedIndex, + bool clipBoxes) +{ + if (keepTopK > topK) return; + + for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK; i += gridDim.x * nthds_per_cta) + { + const int imgId = i / keepTopK; + const int detId = i % keepTopK; + const int offset = imgId * numClasses * topK; + const int index = indices[offset + detId]; + const T_SCORE score = scores[offset + detId]; + + if (index == -1) + { + nmsedLabels[i] = -1; + if (nmsedIndex != nullptr) + { + nmsedIndex[i] = -1; + } + if (rotated) + { + nmsedDets[i * 6] = 0; + nmsedDets[i * 6 + 1] = 0; + nmsedDets[i * 6 + 2] = 0; + nmsedDets[i * 6 + 3] = 0; + nmsedDets[i * 6 + 4] = 0; + nmsedDets[i * 6 + 5] = 0; + } + else + { + nmsedDets[i * 5] = 0; + nmsedDets[i * 5 + 1] = 0; + nmsedDets[i * 5 + 2] = 0; + nmsedDets[i * 5 + 3] = 0; + nmsedDets[i * 5 + 4] = 0; + } } - // clipped bbox xmin - nmsedDets[i * 6] = - clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; - // clipped bbox ymin - nmsedDets[i * 6 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 1]; - // clipped bbox xmax - nmsedDets[i * 6 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 2]; - // clipped bbox ymax - nmsedDets[i * 6 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 3]; - // clipped bbox angle - nmsedDets[i * 6 + 4] = clipBoxes ? max(min(bboxData[bboxId + 4], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 4]; - nmsedDets[i * 6 + 5] = score; - } else { - const int bboxId = ((shareLocation ? (index % numPredsPerClass) - : index % (numClasses * numPredsPerClass)) + - bboxOffset) * - 4; - if (nmsedIndex != nullptr) { - nmsedIndex[i] = bboxId / 4 - bboxOffset; + else + { + const int bboxOffset = + imgId * (shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass)); + nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass; // label + if (rotated) + { + const int bboxId = ((shareLocation ? (index % numPredsPerClass) : index % (numClasses * numPredsPerClass)) + + bboxOffset) * + 5; + if (nmsedIndex != nullptr) + { + nmsedIndex[i] = bboxId / 5 - bboxOffset; + } + // clipped bbox xmin + nmsedDets[i * 6] = + clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; + // clipped bbox ymin + nmsedDets[i * 6 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 1]; + // clipped bbox xmax + nmsedDets[i * 6 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 2]; + // clipped bbox ymax + nmsedDets[i * 6 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 3]; + // clipped bbox angle + nmsedDets[i * 6 + 4] = clipBoxes ? max(min(bboxData[bboxId + 4], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 4]; + nmsedDets[i * 6 + 5] = score; + } + else + { + const int bboxId = ((shareLocation ? 
(index % numPredsPerClass) : index % (numClasses * numPredsPerClass)) + + bboxOffset) * + 4; + if (nmsedIndex != nullptr) + { + nmsedIndex[i] = bboxId / 4 - bboxOffset; + } + // clipped bbox xmin + nmsedDets[i * 5] = + clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; + // clipped bbox ymin + nmsedDets[i * 5 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 1]; + // clipped bbox xmax + nmsedDets[i * 5 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 2]; + // clipped bbox ymax + nmsedDets[i * 5 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 3]; + nmsedDets[i * 5 + 4] = score; + } } - // clipped bbox xmin - nmsedDets[i * 5] = - clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; - // clipped bbox ymin - nmsedDets[i * 5 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 1]; - // clipped bbox xmax - nmsedDets[i * 5 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 2]; - // clipped bbox ymax - nmsedDets[i * 5 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 3]; - nmsedDets[i * 5 + 4] = score; - } } - } } -template -pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocation, - const int numImages, const int numPredsPerClass, - const int numClasses, const int topK, const int keepTopK, - const void *indices, const void *scores, const void *bboxData, - void *nmsedDets, void *nmsedLabels, void *nmsedIndex, - bool clipBoxes) { - const int BS = 32; - const int GS = 32; - gatherNMSOutputs_kernel<<>>( - shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK, (int *)indices, - (T_SCORE *)scores, (T_BBOX *)bboxData, (T_BBOX *)nmsedDets, (int *)nmsedLabels, - (int *)nmsedIndex, clipBoxes); +template +pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, + const bool shareLocation, + const int numImages, + const int numPredsPerClass, + const int numClasses, + const int topK, + const int keepTopK, + const void* indices, + const void* scores, + const void* bboxData, + void* nmsedDets, + void* nmsedLabels, + void* nmsedIndex, + bool clipBoxes) +{ + const int BS = 32; + const int GS = 32; + gatherNMSOutputs_kernel<<>>( + shareLocation, + numImages, + numPredsPerClass, + numClasses, + topK, + keepTopK, + (int*)indices, + (T_SCORE*)scores, + (T_BBOX*)bboxData, + (T_BBOX*)nmsedDets, + (int*)nmsedLabels, + (int*)nmsedIndex, + clipBoxes); - CSC(cudaGetLastError(), STATUS_FAILURE); - return STATUS_SUCCESS; + CSC(cudaGetLastError(), STATUS_FAILURE); + return STATUS_SUCCESS; } // gatherNMSOutputs LAUNCH CONFIG {{{ -typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int, - const int, const int, const void *, const void *, const void *, - void *, void *, void *, bool); -struct nmsOutLaunchConfig { - DataType t_bbox; - DataType t_score; - bool rotated; - nmsOutFunc function; +typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, + const bool, + const int, + const int, + const int, + const int, + const int, + const void*, + const void*, + const void*, + void*, + void*, + void*, + bool); +struct nmsOutLaunchConfig +{ + DataType t_bbox; + DataType t_score; + bool rotated; + nmsOutFunc function; - nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated) - : t_bbox(t_bbox), t_score(t_score), 
rotated(rotated) {} - nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated, nmsOutFunc function) - : t_bbox(t_bbox), t_score(t_score), rotated(rotated), function(function) {} - bool operator==(const nmsOutLaunchConfig &other) { - return t_bbox == other.t_bbox && t_score == other.t_score && rotated == other.rotated; - } + nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated) + : t_bbox(t_bbox) + , t_score(t_score) + , rotated(rotated) + { + } + nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated, nmsOutFunc function) + : t_bbox(t_bbox) + , t_score(t_score) + , rotated(rotated) + , function(function) + { + } + bool operator==(const nmsOutLaunchConfig& other) + { + return t_bbox == other.t_bbox && t_score == other.t_score && rotated == other.rotated; + } }; using nvinfer1::DataType; static std::vector nmsOutFuncVec; -bool nmsOutputInit() { - nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, false, - gatherNMSOutputs_gpu)); - nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, true, - gatherNMSOutputs_gpu)); - return true; +bool nmsOutputInit() +{ + nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, false, gatherNMSOutputs_gpu)); + nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, true, gatherNMSOutputs_gpu)); + return true; } -static bool initialized = nmsOutputInit(); +static bool initialized = nmsOutputInit(); -pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, const int numImages, - const int numPredsPerClass, const int numClasses, const int topK, - const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE, - const void *indices, const void *scores, const void *bboxData, - void *nmsedDets, void *nmsedLabels, void *nmsedIndex, - bool clipBoxes, bool rotated) { - nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE, rotated); - for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) { - if (lc == nmsOutFuncVec[i]) { - DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i); - return nmsOutFuncVec[i].function(stream, shareLocation, numImages, numPredsPerClass, - numClasses, topK, keepTopK, indices, scores, bboxData, - nmsedDets, nmsedLabels, nmsedIndex, clipBoxes); +pluginStatus_t gatherNMSOutputs(cudaStream_t stream, + const bool shareLocation, + const int numImages, + const int numPredsPerClass, + const int numClasses, + const int topK, + const int keepTopK, + const DataType DT_BBOX, + const DataType DT_SCORE, + const void* indices, + const void* scores, + const void* bboxData, + void* nmsedDets, + void* nmsedLabels, + void* nmsedIndex, + bool clipBoxes, + bool rotated) +{ + nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE, rotated); + for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) + { + if (lc == nmsOutFuncVec[i]) + { + DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i); + return nmsOutFuncVec[i].function(stream, + shareLocation, + numImages, + numPredsPerClass, + numClasses, + topK, + keepTopK, + indices, + scores, + bboxData, + nmsedDets, + nmsedLabels, + nmsedIndex, + clipBoxes); + } } - } - return STATUS_BAD_PARAM; + return STATUS_BAD_PARAM; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu index f0e1c9d0cc..36228de174 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu @@ -12,96 +12,118 @@ #define 
CUDA_MEM_ALIGN 256 // return cuda arch -size_t get_cuda_arch(int devID) { - int computeMode = -1, major = 0, minor = 0; - CUASSERT(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, devID)); - CUASSERT(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); - CUASSERT(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); - return major * 100 + minor * 10; +size_t get_cuda_arch(int devID) +{ + int computeMode = -1, major = 0, minor = 0; + CUASSERT(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, devID)); + CUASSERT(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); + CUASSERT(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); + return major * 100 + minor * 10; } // ALIGNPTR -int8_t *alignPtr(int8_t *ptr, uintptr_t to) { - uintptr_t addr = (uintptr_t)ptr; - if (addr % to) { - addr += to - addr % to; - } - return (int8_t *)addr; +int8_t* alignPtr(int8_t* ptr, uintptr_t to) +{ + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) + { + addr += to - addr % to; + } + return (int8_t*)addr; } // NEXTWORKSPACEPTR -int8_t *nextWorkspacePtr(int8_t *ptr, uintptr_t previousWorkspaceSize) { - uintptr_t addr = (uintptr_t)ptr; - addr += previousWorkspaceSize; - return alignPtr((int8_t *)addr, CUDA_MEM_ALIGN); +int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) +{ + uintptr_t addr = (uintptr_t)ptr; + addr += previousWorkspaceSize; + return alignPtr((int8_t*)addr, CUDA_MEM_ALIGN); } // CALCULATE TOTAL WORKSPACE SIZE -size_t calculateTotalWorkspaceSize(size_t *workspaces, int count) { - size_t total = 0; - for (int i = 0; i < count; i++) { - total += workspaces[i]; - if (workspaces[i] % CUDA_MEM_ALIGN) { - total += CUDA_MEM_ALIGN - (workspaces[i] % CUDA_MEM_ALIGN); +size_t calculateTotalWorkspaceSize(size_t* workspaces, int count) +{ + size_t total = 0; + for (int i = 0; i < count; i++) + { + total += workspaces[i]; + if (workspaces[i] % CUDA_MEM_ALIGN) + { + total += CUDA_MEM_ALIGN - (workspaces[i] % CUDA_MEM_ALIGN); + } } - } - return total; + return total; } using nvinfer1::DataType; -template -__launch_bounds__(nthds_per_cta) __global__ - void setUniformOffsets_kernel(const int num_segments, const int offset, int *d_offsets) { - const int idx = blockIdx.x * nthds_per_cta + threadIdx.x; - if (idx <= num_segments) d_offsets[idx] = idx * offset; +template +__launch_bounds__(nthds_per_cta) __global__ void setUniformOffsets_kernel(const int num_segments, + const int offset, + int* d_offsets) +{ + const int idx = blockIdx.x * nthds_per_cta + threadIdx.x; + if (idx <= num_segments) d_offsets[idx] = idx * offset; } -void setUniformOffsets(cudaStream_t stream, const int num_segments, const int offset, - int *d_offsets) { - const int BS = 32; - const int GS = (num_segments + 1 + BS - 1) / BS; - setUniformOffsets_kernel<<>>(num_segments, offset, d_offsets); +void setUniformOffsets(cudaStream_t stream, const int num_segments, const int offset, int* d_offsets) +{ + const int BS = 32; + const int GS = (num_segments + 1 + BS - 1) / BS; + setUniformOffsets_kernel<<>>(num_segments, offset, d_offsets); } -size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX) { - if (DT_BBOX == DataType::kFLOAT) { - return N * C1 * sizeof(float); - } +size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX) +{ + if (DT_BBOX == DataType::kFLOAT) + { + return N * C1 * sizeof(float); + } - printf("Only FP32 type bounding boxes are supported.\n"); - return (size_t)-1; + printf("Only FP32 
type bounding boxes are supported.\n"); + return (size_t)-1; } -size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX) { - if (DT_BBOX == DataType::kFLOAT) { - return shareLocation ? 0 : N * C1 * sizeof(float); - } - printf("Only FP32 type bounding boxes are supported.\n"); - return (size_t)-1; +size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX) +{ + if (DT_BBOX == DataType::kFLOAT) + { + return shareLocation ? 0 : N * C1 * sizeof(float); + } + printf("Only FP32 type bounding boxes are supported.\n"); + return (size_t)-1; } -size_t detectionForwardPreNMSSize(int N, int C2) { - ASSERT(sizeof(float) == sizeof(int)); - return N * C2 * sizeof(float); +size_t detectionForwardPreNMSSize(int N, int C2) +{ + ASSERT(sizeof(float) == sizeof(int)); + return N * C2 * sizeof(float); } -size_t detectionForwardPostNMSSize(int N, int numClasses, int topK) { - ASSERT(sizeof(float) == sizeof(int)); - return N * numClasses * topK * sizeof(float); +size_t detectionForwardPostNMSSize(int N, int numClasses, int topK) +{ + ASSERT(sizeof(float) == sizeof(int)); + return N * numClasses * topK * sizeof(float); } -size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses, - int numPredsPerClass, int topK, DataType DT_BBOX, - DataType DT_SCORE) { - size_t wss[7]; - wss[0] = detectionForwardBBoxDataSize(N, C1, DT_BBOX); - wss[1] = detectionForwardBBoxPermuteSize(shareLocation, N, C1, DT_BBOX); - wss[2] = detectionForwardPreNMSSize(N, C2); - wss[3] = detectionForwardPreNMSSize(N, C2); - wss[4] = detectionForwardPostNMSSize(N, numClasses, topK); - wss[5] = detectionForwardPostNMSSize(N, numClasses, topK); - wss[6] = std::max(sortScoresPerClassWorkspaceSize(N, numClasses, numPredsPerClass, DT_SCORE), - sortScoresPerImageWorkspaceSize(N, numClasses * topK, DT_SCORE)); - return calculateTotalWorkspaceSize(wss, 7); +size_t detectionInferenceWorkspaceSize(bool shareLocation, + int N, + int C1, + int C2, + int numClasses, + int numPredsPerClass, + int topK, + DataType DT_BBOX, + DataType DT_SCORE) +{ + size_t wss[7]; + wss[0] = detectionForwardBBoxDataSize(N, C1, DT_BBOX); + wss[1] = detectionForwardBBoxPermuteSize(shareLocation, N, C1, DT_BBOX); + wss[2] = detectionForwardPreNMSSize(N, C2); + wss[3] = detectionForwardPreNMSSize(N, C2); + wss[4] = detectionForwardPostNMSSize(N, numClasses, topK); + wss[5] = detectionForwardPostNMSSize(N, numClasses, topK); + wss[6] = std::max(sortScoresPerClassWorkspaceSize(N, numClasses, numPredsPerClass, DT_SCORE), + sortScoresPerImageWorkspaceSize(N, numClasses * topK, DT_SCORE)); + return calculateTotalWorkspaceSize(wss, 7); } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu index 659c964970..327536d8b1 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu @@ -5,72 +5,120 @@ #include "nms/kernel.h" -template -__launch_bounds__(nthds_per_cta) __global__ - void permuteData_kernel(const int nthreads, const int num_classes, const int num_data, - const int num_dim, bool confSigmoid, const Dtype *data, - Dtype *new_data) { - // data format: [batch_size, num_data, num_classes, num_dim] - for (int index = blockIdx.x * nthds_per_cta + threadIdx.x; index < nthreads; - index += nthds_per_cta * gridDim.x) { - const int i = index % num_dim; - const int c = (index / num_dim) % num_classes; - 
const int d = (index / num_dim / num_classes) % num_data; - const int n = index / num_dim / num_classes / num_data; - const int new_index = ((n * num_classes + c) * num_data + d) * num_dim + i; - float result = data[index]; - if (confSigmoid) result = exp(result) / (1 + exp(result)); +template +__launch_bounds__(nthds_per_cta) __global__ void permuteData_kernel(const int nthreads, + const int num_classes, + const int num_data, + const int num_dim, + bool confSigmoid, + const Dtype* data, + Dtype* new_data) +{ + // data format: [batch_size, num_data, num_classes, num_dim] + for (int index = blockIdx.x * nthds_per_cta + threadIdx.x; index < nthreads; + index += nthds_per_cta * gridDim.x) + { + const int i = index % num_dim; + const int c = (index / num_dim) % num_classes; + const int d = (index / num_dim / num_classes) % num_data; + const int n = index / num_dim / num_classes / num_data; + const int new_index = ((n * num_classes + c) * num_data + d) * num_dim + i; + float result = data[index]; + if (confSigmoid) result = exp(result) / (1 + exp(result)); - new_data[new_index] = result; - } - // new data format: [batch_size, num_classes, num_data, num_dim] + new_data[new_index] = result; + } + // new data format: [batch_size, num_classes, num_data, num_dim] } -template -pluginStatus_t permuteData_gpu(cudaStream_t stream, const int nthreads, const int num_classes, - const int num_data, const int num_dim, bool confSigmoid, - const void *data, void *new_data) { - const int BS = 512; - const int GS = (nthreads + BS - 1) / BS; - permuteData_kernel<<>>(nthreads, num_classes, num_data, num_dim, - confSigmoid, (const Dtype *)data, - (Dtype *)new_data); - CSC(cudaGetLastError(), STATUS_FAILURE); - return STATUS_SUCCESS; +template +pluginStatus_t permuteData_gpu(cudaStream_t stream, + const int nthreads, + const int num_classes, + const int num_data, + const int num_dim, + bool confSigmoid, + const void* data, + void* new_data) +{ + const int BS = 512; + const int GS = (nthreads + BS - 1) / BS; + permuteData_kernel<<>>(nthreads, + num_classes, + num_data, + num_dim, + confSigmoid, + (const Dtype*)data, + (Dtype*)new_data); + CSC(cudaGetLastError(), STATUS_FAILURE); + return STATUS_SUCCESS; } // permuteData LAUNCH CONFIG -typedef pluginStatus_t (*pdFunc)(cudaStream_t, const int, const int, const int, const int, bool, - const void *, void *); +typedef pluginStatus_t (*pdFunc)(cudaStream_t, + const int, + const int, + const int, + const int, + bool, + const void*, + void*); -struct pdLaunchConfig { - DataType t_data; - pdFunc function; +struct pdLaunchConfig +{ + DataType t_data; + pdFunc function; - pdLaunchConfig(DataType t_data) : t_data(t_data) {} - pdLaunchConfig(DataType t_data, pdFunc function) : t_data(t_data), function(function) {} - bool operator==(const pdLaunchConfig &other) { return t_data == other.t_data; } + pdLaunchConfig(DataType t_data) + : t_data(t_data) + { + } + pdLaunchConfig(DataType t_data, pdFunc function) + : t_data(t_data) + , function(function) + { + } + bool operator==(const pdLaunchConfig& other) + { + return t_data == other.t_data; + } }; static std::vector pdFuncVec; -bool permuteDataInit() { - pdFuncVec.push_back(pdLaunchConfig(DataType::kFLOAT, permuteData_gpu)); - return true; +bool permuteDataInit() +{ + pdFuncVec.push_back(pdLaunchConfig(DataType::kFLOAT, permuteData_gpu)); + return true; } -static bool initialized = permuteDataInit(); +static bool initialized = permuteDataInit(); -pluginStatus_t permuteData(cudaStream_t stream, const int nthreads, const int 
num_classes, - const int num_data, const int num_dim, const DataType DT_DATA, - bool confSigmoid, const void *data, void *new_data) { - pdLaunchConfig lc = pdLaunchConfig(DT_DATA); - for (unsigned i = 0; i < pdFuncVec.size(); ++i) { - if (lc == pdFuncVec[i]) { - DEBUG_PRINTF("permuteData kernel %d\n", i); - return pdFuncVec[i].function(stream, nthreads, num_classes, num_data, num_dim, confSigmoid, - data, new_data); +pluginStatus_t permuteData(cudaStream_t stream, + const int nthreads, + const int num_classes, + const int num_data, + const int num_dim, + const DataType DT_DATA, + bool confSigmoid, + const void* data, + void* new_data) +{ + pdLaunchConfig lc = pdLaunchConfig(DT_DATA); + for (unsigned i = 0; i < pdFuncVec.size(); ++i) + { + if (lc == pdFuncVec[i]) + { + DEBUG_PRINTF("permuteData kernel %d\n", i); + return pdFuncVec[i].function(stream, + nthreads, + num_classes, + num_data, + num_dim, + confSigmoid, + data, + new_data); + } } - } - return STATUS_BAD_PARAM; + return STATUS_BAD_PARAM; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu index e72f040cc9..df506d3896 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu @@ -8,134 +8,209 @@ #include "nms/kernel.h" #include "trt_plugin_helper.hpp" -template -__launch_bounds__(nthds_per_cta) __global__ - void prepareSortData(const int num, const int num_classes, const int num_preds_per_class, - const int background_label_id, const float confidence_threshold, - T_SCORE *conf_scores_gpu, T_SCORE *temp_scores, int *temp_idx, - int *d_offsets) { - // Prepare scores data for sort - const int cur_idx = blockIdx.x * nthds_per_cta + threadIdx.x; - const int numPredsPerBatch = num_classes * num_preds_per_class; - if (cur_idx < numPredsPerBatch) { - const int class_idx = cur_idx / num_preds_per_class; - for (int i = 0; i < num; i++) { - const int targetIdx = i * numPredsPerBatch + cur_idx; - const T_SCORE score = conf_scores_gpu[targetIdx]; +template +__launch_bounds__(nthds_per_cta) __global__ void prepareSortData(const int num, + const int num_classes, + const int num_preds_per_class, + const int background_label_id, + const float confidence_threshold, + T_SCORE* conf_scores_gpu, + T_SCORE* temp_scores, + int* temp_idx, + int* d_offsets) +{ + // Prepare scores data for sort + const int cur_idx = blockIdx.x * nthds_per_cta + threadIdx.x; + const int numPredsPerBatch = num_classes * num_preds_per_class; + if (cur_idx < numPredsPerBatch) + { + const int class_idx = cur_idx / num_preds_per_class; + for (int i = 0; i < num; i++) + { + const int targetIdx = i * numPredsPerBatch + cur_idx; + const T_SCORE score = conf_scores_gpu[targetIdx]; - // "Clear" background labeled score and index - // Because we do not care about background - if (class_idx == background_label_id) { - // Set scores to 0 - // Set label = -1 - temp_scores[targetIdx] = 0.0f; - temp_idx[targetIdx] = -1; - conf_scores_gpu[targetIdx] = 0.0f; - } - // "Clear" scores lower than threshold - else { - if (score > confidence_threshold) { - temp_scores[targetIdx] = score; - temp_idx[targetIdx] = cur_idx + i * numPredsPerBatch; - } else { - // Set scores to 0 - // Set label = -1 - temp_scores[targetIdx] = 0.0f; - temp_idx[targetIdx] = -1; - conf_scores_gpu[targetIdx] = 0.0f; - // TODO: HERE writing memory too many times - } - } + // "Clear" background labeled score and 
index + // Because we do not care about background + if (class_idx == background_label_id) + { + // Set scores to 0 + // Set label = -1 + temp_scores[targetIdx] = 0.0f; + temp_idx[targetIdx] = -1; + conf_scores_gpu[targetIdx] = 0.0f; + } + // "Clear" scores lower than threshold + else + { + if (score > confidence_threshold) + { + temp_scores[targetIdx] = score; + temp_idx[targetIdx] = cur_idx + i * numPredsPerBatch; + } + else + { + // Set scores to 0 + // Set label = -1 + temp_scores[targetIdx] = 0.0f; + temp_idx[targetIdx] = -1; + conf_scores_gpu[targetIdx] = 0.0f; + // TODO: HERE writing memory too many times + } + } - if ((cur_idx % num_preds_per_class) == 0) { - const int offset_ct = i * num_classes + cur_idx / num_preds_per_class; - d_offsets[offset_ct] = offset_ct * num_preds_per_class; - // set the last element in d_offset - if (blockIdx.x == 0 && threadIdx.x == 0) - d_offsets[num * num_classes] = num * numPredsPerBatch; - } + if ((cur_idx % num_preds_per_class) == 0) + { + const int offset_ct = i * num_classes + cur_idx / num_preds_per_class; + d_offsets[offset_ct] = offset_ct * num_preds_per_class; + // set the last element in d_offset + if (blockIdx.x == 0 && threadIdx.x == 0) + d_offsets[num * num_classes] = num * numPredsPerBatch; + } + } } - } } -template -pluginStatus_t sortScoresPerClass_gpu(cudaStream_t stream, const int num, const int num_classes, - const int num_preds_per_class, const int background_label_id, - const float confidence_threshold, void *conf_scores_gpu, - void *index_array_gpu, void *workspace) { - const int num_segments = num * num_classes; - void *temp_scores = workspace; - const int arrayLen = num * num_classes * num_preds_per_class; - void *temp_idx = nextWorkspacePtr((int8_t *)temp_scores, arrayLen * sizeof(T_SCORE)); - void *d_offsets = nextWorkspacePtr((int8_t *)temp_idx, arrayLen * sizeof(int)); - size_t cubOffsetSize = (num_segments + 1) * sizeof(int); - void *cubWorkspace = nextWorkspacePtr((int8_t *)d_offsets, cubOffsetSize); +template +pluginStatus_t sortScoresPerClass_gpu(cudaStream_t stream, + const int num, + const int num_classes, + const int num_preds_per_class, + const int background_label_id, + const float confidence_threshold, + void* conf_scores_gpu, + void* index_array_gpu, + void* workspace) +{ + const int num_segments = num * num_classes; + void* temp_scores = workspace; + const int arrayLen = num * num_classes * num_preds_per_class; + void* temp_idx = nextWorkspacePtr((int8_t*)temp_scores, arrayLen * sizeof(T_SCORE)); + void* d_offsets = nextWorkspacePtr((int8_t*)temp_idx, arrayLen * sizeof(int)); + size_t cubOffsetSize = (num_segments + 1) * sizeof(int); + void* cubWorkspace = nextWorkspacePtr((int8_t*)d_offsets, cubOffsetSize); - const int BS = 512; - const int GS = (num_classes * num_preds_per_class + BS - 1) / BS; - prepareSortData<<>>( - num, num_classes, num_preds_per_class, background_label_id, confidence_threshold, - (T_SCORE *)conf_scores_gpu, (T_SCORE *)temp_scores, (int *)temp_idx, (int *)d_offsets); + const int BS = 512; + const int GS = (num_classes * num_preds_per_class + BS - 1) / BS; + prepareSortData<<>>( + num, + num_classes, + num_preds_per_class, + background_label_id, + confidence_threshold, + (T_SCORE*)conf_scores_gpu, + (T_SCORE*)temp_scores, + (int*)temp_idx, + (int*)d_offsets); - size_t temp_storage_bytes = cubSortPairsWorkspaceSize(arrayLen, num_segments); - cub::DeviceSegmentedRadixSort::SortPairsDescending( - cubWorkspace, temp_storage_bytes, (const T_SCORE *)(temp_scores), - (T_SCORE *)(conf_scores_gpu), 
(const int *)(temp_idx), (int *)(index_array_gpu), arrayLen, - num_segments, (const int *)d_offsets, (const int *)d_offsets + 1, 0, sizeof(T_SCORE) * 8, - stream); - CSC(cudaGetLastError(), STATUS_FAILURE); - return STATUS_SUCCESS; + size_t temp_storage_bytes = cubSortPairsWorkspaceSize(arrayLen, num_segments); + cub::DeviceSegmentedRadixSort::SortPairsDescending( + cubWorkspace, + temp_storage_bytes, + (const T_SCORE*)(temp_scores), + (T_SCORE*)(conf_scores_gpu), + (const int*)(temp_idx), + (int*)(index_array_gpu), + arrayLen, + num_segments, + (const int*)d_offsets, + (const int*)d_offsets + 1, + 0, + sizeof(T_SCORE) * 8, + stream); + CSC(cudaGetLastError(), STATUS_FAILURE); + return STATUS_SUCCESS; } // sortScoresPerClass LAUNCH CONFIG -typedef pluginStatus_t (*sspcFunc)(cudaStream_t, const int, const int, const int, const int, - const float, void *, void *, void *); -struct sspcLaunchConfig { - DataType t_score; - sspcFunc function; +typedef pluginStatus_t (*sspcFunc)(cudaStream_t, + const int, + const int, + const int, + const int, + const float, + void*, + void*, + void*); +struct sspcLaunchConfig +{ + DataType t_score; + sspcFunc function; - sspcLaunchConfig(DataType t_score) : t_score(t_score) {} - sspcLaunchConfig(DataType t_score, sspcFunc function) : t_score(t_score), function(function) {} - bool operator==(const sspcLaunchConfig &other) { return t_score == other.t_score; } + sspcLaunchConfig(DataType t_score) + : t_score(t_score) + { + } + sspcLaunchConfig(DataType t_score, sspcFunc function) + : t_score(t_score) + , function(function) + { + } + bool operator==(const sspcLaunchConfig& other) + { + return t_score == other.t_score; + } }; static std::vector sspcFuncVec; -bool sspcInit() { - sspcFuncVec.push_back(sspcLaunchConfig(DataType::kFLOAT, sortScoresPerClass_gpu)); - return true; +bool sspcInit() +{ + sspcFuncVec.push_back(sspcLaunchConfig(DataType::kFLOAT, sortScoresPerClass_gpu)); + return true; } -static bool initialized = sspcInit(); +static bool initialized = sspcInit(); -pluginStatus_t sortScoresPerClass(cudaStream_t stream, const int num, const int num_classes, - const int num_preds_per_class, const int background_label_id, - const float confidence_threshold, const DataType DT_SCORE, - void *conf_scores_gpu, void *index_array_gpu, void *workspace) { - sspcLaunchConfig lc = sspcLaunchConfig(DT_SCORE); - for (unsigned i = 0; i < sspcFuncVec.size(); ++i) { - if (lc == sspcFuncVec[i]) { - DEBUG_PRINTF("sortScoresPerClass kernel %d\n", i); - return sspcFuncVec[i].function(stream, num, num_classes, num_preds_per_class, - background_label_id, confidence_threshold, conf_scores_gpu, - index_array_gpu, workspace); +pluginStatus_t sortScoresPerClass(cudaStream_t stream, + const int num, + const int num_classes, + const int num_preds_per_class, + const int background_label_id, + const float confidence_threshold, + const DataType DT_SCORE, + void* conf_scores_gpu, + void* index_array_gpu, + void* workspace) +{ + sspcLaunchConfig lc = sspcLaunchConfig(DT_SCORE); + for (unsigned i = 0; i < sspcFuncVec.size(); ++i) + { + if (lc == sspcFuncVec[i]) + { + DEBUG_PRINTF("sortScoresPerClass kernel %d\n", i); + return sspcFuncVec[i].function(stream, + num, + num_classes, + num_preds_per_class, + background_label_id, + confidence_threshold, + conf_scores_gpu, + index_array_gpu, + workspace); + } } - } - return STATUS_BAD_PARAM; + return STATUS_BAD_PARAM; } -size_t sortScoresPerClassWorkspaceSize(const int num, const int num_classes, - const int num_preds_per_class, const DataType DT_CONF) 
-size_t sortScoresPerClassWorkspaceSize(const int num, const int num_classes,
-                                       const int num_preds_per_class, const DataType DT_CONF) {
-  size_t wss[4];
-  const int arrayLen = num * num_classes * num_preds_per_class;
-  wss[0] = arrayLen * mmdeploy::getElementSize(DT_CONF);  // temp scores
-  wss[1] = arrayLen * sizeof(int);                        // temp indices
-  wss[2] = (num * num_classes + 1) * sizeof(int);         // offsets
-  if (DT_CONF == DataType::kFLOAT) {
-    wss[3] = cubSortPairsWorkspaceSize<float, int>(arrayLen, num * num_classes);  // cub workspace
-  } else {
-    printf("SCORE type not supported\n");
-    return (size_t)-1;
-  }
+size_t sortScoresPerClassWorkspaceSize(const int num,
+                                       const int num_classes,
+                                       const int num_preds_per_class,
+                                       const DataType DT_CONF)
+{
+    size_t wss[4];
+    const int arrayLen = num * num_classes * num_preds_per_class;
+    wss[0] = arrayLen * mmdeploy::getElementSize(DT_CONF);  // temp scores
+    wss[1] = arrayLen * sizeof(int);                        // temp indices
+    wss[2] = (num * num_classes + 1) * sizeof(int);         // offsets
+    if (DT_CONF == DataType::kFLOAT)
+    {
+        wss[3] = cubSortPairsWorkspaceSize<float, int>(arrayLen, num * num_classes);  // cub workspace
+    }
+    else
+    {
+        printf("SCORE type not supported\n");
+        return (size_t)-1;
+    }

-  return calculateTotalWorkspaceSize(wss, 4);
+    return calculateTotalWorkspaceSize(wss, 4);
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu
index a6ad70262d..ab60b5f88a 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu
@@ -7,75 +7,125 @@
 #include "nms/cub_helper.h"
 #include "nms/kernel.h"

-template <typename T_SCORE>
-pluginStatus_t sortScoresPerImage_gpu(cudaStream_t stream, const int num_images,
-                                      const int num_items_per_image, void *unsorted_scores,
-                                      void *unsorted_bbox_indices, void *sorted_scores,
-                                      void *sorted_bbox_indices, void *workspace) {
-  void *d_offsets = workspace;
-  void *cubWorkspace = nextWorkspacePtr((int8_t *)d_offsets, (num_images + 1) * sizeof(int));
+template<typename T_SCORE>
+pluginStatus_t sortScoresPerImage_gpu(cudaStream_t stream,
+                                      const int num_images,
+                                      const int num_items_per_image,
+                                      void* unsorted_scores,
+                                      void* unsorted_bbox_indices,
+                                      void* sorted_scores,
+                                      void* sorted_bbox_indices,
+                                      void* workspace)
+{
+    void* d_offsets = workspace;
+    void* cubWorkspace = nextWorkspacePtr((int8_t*)d_offsets, (num_images + 1) * sizeof(int));

-  setUniformOffsets(stream, num_images, num_items_per_image, (int *)d_offsets);
+    setUniformOffsets(stream, num_images, num_items_per_image, (int*)d_offsets);

-  const int arrayLen = num_images * num_items_per_image;
-  size_t temp_storage_bytes = cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_images);
-  cub::DeviceSegmentedRadixSort::SortPairsDescending(
-      cubWorkspace, temp_storage_bytes, (const T_SCORE *)(unsorted_scores),
-      (T_SCORE *)(sorted_scores), (const int *)(unsorted_bbox_indices),
-      (int *)(sorted_bbox_indices), arrayLen, num_images, (const int *)d_offsets,
-      (const int *)d_offsets + 1, 0, sizeof(T_SCORE) * 8, stream);
-  CSC(cudaGetLastError(), STATUS_FAILURE);
-  return STATUS_SUCCESS;
+    const int arrayLen = num_images * num_items_per_image;
+    size_t temp_storage_bytes = cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_images);
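+    // Added comment: cub::DeviceSegmentedRadixSort treats each image as one
+    // segment bounded by [d_offsets[i], d_offsets[i + 1]); bits [0, 8 * sizeof(T_SCORE))
+    // of the key take part, i.e. a full-key descending sort of (score, index) pairs.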
+    cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        cubWorkspace,
+        temp_storage_bytes,
+        (const T_SCORE*)(unsorted_scores),
+        (T_SCORE*)(sorted_scores),
+        (const int*)(unsorted_bbox_indices),
+        (int*)(sorted_bbox_indices),
+        arrayLen,
+        num_images,
+        (const int*)d_offsets,
+        (const int*)d_offsets + 1,
+        0,
+        sizeof(T_SCORE) * 8,
+        stream);
+    CSC(cudaGetLastError(), STATUS_FAILURE);
+    return STATUS_SUCCESS;
 }

 // sortScoresPerImage LAUNCH CONFIG
-typedef pluginStatus_t (*sspiFunc)(cudaStream_t, const int, const int, void *, void *, void *,
-                                   void *, void *);
-struct sspiLaunchConfig {
-  DataType t_score;
-  sspiFunc function;
+typedef pluginStatus_t (*sspiFunc)(cudaStream_t,
+                                   const int,
+                                   const int,
+                                   void*,
+                                   void*,
+                                   void*,
+                                   void*,
+                                   void*);
+struct sspiLaunchConfig
+{
+    DataType t_score;
+    sspiFunc function;

-  sspiLaunchConfig(DataType t_score) : t_score(t_score) {}
-  sspiLaunchConfig(DataType t_score, sspiFunc function) : t_score(t_score), function(function) {}
-  bool operator==(const sspiLaunchConfig &other) { return t_score == other.t_score; }
+    sspiLaunchConfig(DataType t_score)
+        : t_score(t_score)
+    {
+    }
+    sspiLaunchConfig(DataType t_score, sspiFunc function)
+        : t_score(t_score)
+        , function(function)
+    {
+    }
+    bool operator==(const sspiLaunchConfig& other)
+    {
+        return t_score == other.t_score;
+    }
 };

 static std::vector<sspiLaunchConfig> sspiFuncVec;
-bool sspiInit() {
-  sspiFuncVec.push_back(sspiLaunchConfig(DataType::kFLOAT, sortScoresPerImage_gpu<float>));
-  return true;
+bool sspiInit()
+{
+    sspiFuncVec.push_back(sspiLaunchConfig(DataType::kFLOAT, sortScoresPerImage_gpu<float>));
+    return true;
 }
-static bool initialized = sspiInit();
+static bool initialized = sspiInit();

-pluginStatus_t sortScoresPerImage(cudaStream_t stream, const int num_images,
-                                  const int num_items_per_image, const DataType DT_SCORE,
-                                  void *unsorted_scores, void *unsorted_bbox_indices,
-                                  void *sorted_scores, void *sorted_bbox_indices, void *workspace) {
-  sspiLaunchConfig lc = sspiLaunchConfig(DT_SCORE);
-  for (unsigned i = 0; i < sspiFuncVec.size(); ++i) {
-    if (lc == sspiFuncVec[i]) {
-      DEBUG_PRINTF("sortScoresPerImage kernel %d\n", i);
-      return sspiFuncVec[i].function(stream, num_images, num_items_per_image, unsorted_scores,
-                                     unsorted_bbox_indices, sorted_scores, sorted_bbox_indices,
-                                     workspace);
+pluginStatus_t sortScoresPerImage(cudaStream_t stream,
+                                  const int num_images,
+                                  const int num_items_per_image,
+                                  const DataType DT_SCORE,
+                                  void* unsorted_scores,
+                                  void* unsorted_bbox_indices,
+                                  void* sorted_scores,
+                                  void* sorted_bbox_indices,
+                                  void* workspace)
+{
+    sspiLaunchConfig lc = sspiLaunchConfig(DT_SCORE);
+    for (unsigned i = 0; i < sspiFuncVec.size(); ++i)
+    {
+        if (lc == sspiFuncVec[i])
+        {
+            DEBUG_PRINTF("sortScoresPerImage kernel %d\n", i);
+            return sspiFuncVec[i].function(stream,
+                                           num_images,
+                                           num_items_per_image,
+                                           unsorted_scores,
+                                           unsorted_bbox_indices,
+                                           sorted_scores,
+                                           sorted_bbox_indices,
+                                           workspace);
+        }
     }
-  }
-  return STATUS_BAD_PARAM;
+    return STATUS_BAD_PARAM;
 }

-size_t sortScoresPerImageWorkspaceSize(const int num_images, const int num_items_per_image,
-                                       const DataType DT_SCORE) {
-  const int arrayLen = num_images * num_items_per_image;
-  size_t wss[2];
-  wss[0] = (num_images + 1) * sizeof(int);  // offsets
-  if (DT_SCORE == DataType::kFLOAT) {
-    wss[1] = cubSortPairsWorkspaceSize<float, int>(arrayLen,
-                                                   num_images);  // cub workspace
-  } else {
-    printf("SCORE type not supported.\n");
-    return (size_t)-1;
-  }
+size_t sortScoresPerImageWorkspaceSize(const int num_images,
+                                       const int num_items_per_image,
+                                       const DataType DT_SCORE)
+{
+    const int arrayLen = num_images * num_items_per_image;
+    size_t wss[2];
+    wss[0] = (num_images + 1) * sizeof(int);  // offsets
+    if (DT_SCORE == DataType::kFLOAT)
+    {
+        wss[1] = cubSortPairsWorkspaceSize<float, int>(arrayLen,
+                                                       num_images);  // cub workspace
+    }
+    else
+    {
+        printf("SCORE type not supported.\n");
+        return (size_t)-1;
+    }

-  return calculateTotalWorkspaceSize(wss, 2);
+    return calculateTotalWorkspaceSize(wss, 2);
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
index 47e8ae8615..ad0a1bf6de 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
@@ -4,92 +4,145 @@
 using mmdeploy::TensorDesc;

-template <typename scalar_t>
-__global__ void copy_permute_kernel(scalar_t *__restrict__ dst, const scalar_t *__restrict__ src,
-                                    int n, TensorDesc ts_src_stride, TensorDesc ts_dst_stride,
-                                    TensorDesc ts_permute) {
-  const int src_dim = ts_src_stride.dim;
-  const auto src_stride = ts_src_stride.stride;
-  const auto dst_stride = ts_dst_stride.stride;
-  const auto permute = ts_permute.shape;
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    size_t dst_index = index;
-    size_t src_index = 0;
-    for (int i = 0; i < src_dim; ++i) {
-      int dim_index = dst_index / dst_stride[i];
-      dst_index = dst_index % dst_stride[i];
-      src_index += dim_index * src_stride[permute[i]];
+template<typename scalar_t>
+__global__ void copy_permute_kernel(scalar_t* __restrict__ dst,
+                                    const scalar_t* __restrict__ src,
+                                    int n,
+                                    TensorDesc ts_src_stride,
+                                    TensorDesc ts_dst_stride,
+                                    TensorDesc ts_permute)
+{
+    const int src_dim = ts_src_stride.dim;
+    const auto src_stride = ts_src_stride.stride;
+    const auto dst_stride = ts_dst_stride.stride;
+    const auto permute = ts_permute.shape;
+    CUDA_1D_KERNEL_LOOP(index, n)
+    {
+        size_t dst_index = index;
+        size_t src_index = 0;
+        for (int i = 0; i < src_dim; ++i)
+        {
+            int dim_index = dst_index / dst_stride[i];
+            dst_index = dst_index % dst_stride[i];
+            src_index += dim_index * src_stride[permute[i]];
+        }
+        dst[index] = src[src_index];
     }
-    dst[index] = src[src_index];
-  }
 }

-template <typename scalar_t>
-void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size, int *permute, int src_dim,
-                   cudaStream_t stream) {
-  size_t copy_size = 1;
-  TensorDesc ts_permute;
-  memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));
+template<typename scalar_t>
+void memcpyPermute(scalar_t* dst,
+                   const scalar_t* src,
+                   int* src_size,
+                   int* permute,
+                   int src_dim,
+                   cudaStream_t stream)
+{
+    size_t copy_size = 1;
+    TensorDesc ts_permute;
+    memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));

-  TensorDesc ts_src_stride;
-  TensorDesc ts_dst_stride;
-  ts_src_stride.dim = src_dim;
-  ts_dst_stride.dim = src_dim;
-  int *src_stride = &(ts_src_stride.stride[0]);
-  int *dst_stride = &(ts_dst_stride.stride[0]);
-  int *dst_size = &(ts_dst_stride.shape[0]);
-  src_stride[src_dim - 1] = 1;
-  dst_stride[src_dim - 1] = 1;
+    TensorDesc ts_src_stride;
+    TensorDesc ts_dst_stride;
+    ts_src_stride.dim = src_dim;
+    ts_dst_stride.dim = src_dim;
+    int* src_stride = &(ts_src_stride.stride[0]);
+    int* dst_stride = &(ts_dst_stride.stride[0]);
+    int* dst_size = &(ts_dst_stride.shape[0]);
+    src_stride[src_dim - 1] = 1;
+    dst_stride[src_dim - 1] = 1;

-  for (int i = src_dim - 1; i >= 0; --i) {
-    dst_size[i] = src_size[permute[i]];
-    if (i < src_dim - 1) {
-      src_stride[i] = src_stride[i + 1] * src_size[i + 1];
+    for (int i = src_dim - 1; i >= 0; --i)
+    {
+        dst_size[i] = src_size[permute[i]];
+        if (i < src_dim - 1)
+        {
+            src_stride[i] = src_stride[i + 1] * src_size[i + 1];
+        }
     }
-  }

-  for (int i = src_dim - 1; i >= 0; --i) {
-    copy_size *= dst_size[i];
-    if (i < src_dim - 1) {
-      dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
+    for (int i = src_dim - 1; i >= 0; --i)
+    {
+        copy_size *= dst_size[i];
+        if (i < src_dim - 1)
+        {
+            dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
+        }
     }
-  }

-  copy_permute_kernel<scalar_t><<<GET_BLOCKS(copy_size), THREADS_PER_BLOCK, 0, stream>>>(
-      dst, src, copy_size, ts_src_stride, ts_dst_stride, ts_permute);
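+    // Added comment: each thread unravels its linear destination index via
+    // dst_stride, maps every digit through ts_permute into src_stride, and so
+    // performs the permuted copy in a single pass over dst.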
+    copy_permute_kernel<scalar_t><<<GET_BLOCKS(copy_size), THREADS_PER_BLOCK, 0, stream>>>(
+        dst,
+        src,
+        copy_size,
+        ts_src_stride,
+        ts_dst_stride,
+        ts_permute);
 }

-template void memcpyPermute<float>(float *dst, const float *src, int *src_size, int *permute,
-                                   int src_dim, cudaStream_t stream);
-template void memcpyPermute<half>(half *dst, const half *src, int *src_size, int *permute,
-                                  int src_dim, cudaStream_t stream);
+template void memcpyPermute<float>(float* dst,
+                                   const float* src,
+                                   int* src_size,
+                                   int* permute,
+                                   int src_dim,
+                                   cudaStream_t stream);

-cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t *cudnn_dtype) {
-  switch (trt_dtype) {
-    case nvinfer1::DataType::kFLOAT:
-      *cudnn_dtype = CUDNN_DATA_FLOAT;
-      break;
-    case nvinfer1::DataType::kHALF:
-      *cudnn_dtype = CUDNN_DATA_HALF;
-      break;
-    default:
-      return CUDNN_STATUS_BAD_PARAM;
-  }
-  return CUDNN_STATUS_SUCCESS;
+template void memcpyPermute<half>(half* dst,
+                                  const half* src,
+                                  int* src_size,
+                                  int* permute,
+                                  int src_dim,
+                                  cudaStream_t stream);
+
+cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t* cudnn_dtype)
+{
+    switch (trt_dtype)
+    {
+        case nvinfer1::DataType::kFLOAT:
+            *cudnn_dtype = CUDNN_DATA_FLOAT;
+            break;
+        case nvinfer1::DataType::kHALF:
+            *cudnn_dtype = CUDNN_DATA_HALF;
+            break;
+        default:
+            return CUDNN_STATUS_BAD_PARAM;
+    }
+    return CUDNN_STATUS_SUCCESS;
 }

-template <>
-cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle, cublasOperation_t transa,
-                                     cublasOperation_t transb, int m, int n, int k,
-                                     const float *alpha, const float *A, int lda, const float *B,
-                                     int ldb, const float *beta, float *C, int ldc) {
-  return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+template<>
+cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
+                                     cublasOperation_t transa,
+                                     cublasOperation_t transb,
+                                     int m,
+                                     int n,
+                                     int k,
+                                     const float* alpha,
+                                     const float* A,
+                                     int lda,
+                                     const float* B,
+                                     int ldb,
+                                     const float* beta,
+                                     float* C,
+                                     int ldc)
+{
+    return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
 }

-template <>
-cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle, cublasOperation_t transa,
-                                    cublasOperation_t transb, int m, int n, int k,
-                                    const half *alpha, const half *A, int lda, const half *B,
-                                    int ldb, const half *beta, half *C, int ldc) {
-  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+template<>
+cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
+                                    cublasOperation_t transa,
+                                    cublasOperation_t transb,
+                                    int m,
+                                    int n,
+                                    int k,
+                                    const half* alpha,
+                                    const half* A,
+                                    int lda,
+                                    const half* B,
+                                    int ldb,
+                                    const half* beta,
+                                    half* C,
+                                    int ldc)
+{
+    return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp
index 0d518323d2..247093db2f 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp
@@ -10,254 +10,346 @@
 using namespace nvinfer1;

-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"MMCVDeformConv2d"};
-}  // namespace
-
-DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string &name,
-                                                         const 
nvinfer1::Dims stride, - const nvinfer1::Dims padding, - const nvinfer1::Dims dilation, - const int deformableGroup, const int group) - : TRTPluginBase(name), - mStride(stride), - mPadding(padding), - mDilation(dilation), - mDeformableGroup(deformableGroup), - mGroup(group) {} - -DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, const void *data, - size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mStride); - deserialize_value(&data, &length, &mPadding); - deserialize_value(&data, &length, &mDilation); - deserialize_value(&data, &length, &mDeformableGroup); - deserialize_value(&data, &length, &mGroup); -} -DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {} - -nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const TRT_NOEXCEPT { - DeformableConvPluginDynamic *plugin = new DeformableConvPluginDynamic( - mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - // input[0] == input - // input[1] == offset - // input[2] == weight - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[2].d[0]; - - ret.d[2] = inputs[1].d[2]; - ret.d[3] = inputs[1].d[3]; - - return ret; -} - -bool DeformableConvPluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || - ioDesc[pos].type == nvinfer1::DataType::kHALF) && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -void DeformableConvPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t DeformableConvPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); - - int batch_size = inputs[0].dims.d[0]; - int nInputPlane = inputs[0].dims.d[1]; - int inputHeight = inputs[0].dims.d[2]; - int inputWidth = inputs[0].dims.d[3]; - - int nOutputPlane = outputs[0].dims.d[1]; - int outputHeight = outputs[0].dims.d[2]; - int outputWidth = outputs[0].dims.d[3]; - - int kW = inputs[2].dims.d[2]; - int kH = inputs[2].dims.d[3]; - int im2col_step = std::min(32, batch_size); - - size_t col_size = mmdeploy::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight * - outputWidth * sizeof_dtype); - - size_t out_size = 0; - if (im2col_step != 1) - out_size = mmdeploy::getAlignedSize(batch_size * nOutputPlane * outputHeight * outputWidth * - sizeof_dtype); - - return col_size + out_size; -} - -int DeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - int batch = inputDesc[0].dims.d[0]; - int channels = inputDesc[0].dims.d[1]; - int height = inputDesc[0].dims.d[2]; - int width = inputDesc[0].dims.d[3]; - int 
channels_out = outputDesc[0].dims.d[1]; - int kernel_h = inputDesc[2].dims.d[2]; - int kernel_w = inputDesc[2].dims.d[3]; - - const void *x = inputs[0]; - const void *offset = inputs[1]; - const void *weight = inputs[2]; - void *output = outputs[0]; - int im2col_step = std::min(batch, 32); - - auto data_type = inputDesc[0].type; - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - deform_conv((float *)x, (float *)weight, (float *)offset, (float *)output, workSpace, - batch, channels, height, width, channels_out, kernel_w, kernel_h, - mStride.d[0], mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], - mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle, - stream); - break; - case nvinfer1::DataType::kHALF: - deform_conv((half *)x, (half *)weight, (half *)offset, (half *)output, workSpace, batch, - channels, height, width, channels_out, kernel_w, kernel_h, mStride.d[0], - mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], - mGroup, mDeformableGroup, im2col_step, m_cublas_handle, stream); - break; - default: - return 1; - } - - return 0; -} - -nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *DeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *DeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int DeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t DeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + - serialized_size(mDeformableGroup) + serialized_size(mGroup); -} - -void DeformableConvPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mStride); - serialize_value(&buffer, mPadding); - serialize_value(&buffer, mDilation); - serialize_value(&buffer, mDeformableGroup); - serialize_value(&buffer, mGroup); -} - -void DeformableConvPluginDynamic::attachToContext( - cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT { - m_cublas_handle = cublasContext; -} - -void DeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} - -////////////////////// creator ///////////////////////////// - -DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *DeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *DeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - nvinfer1::Dims stride{2, {1, 1}}; - nvinfer1::Dims padding{2, {0, 0}}; - nvinfer1::Dims dilation{2, {1, 1}}; - int deformableGroup = 1; - int 
group = 1; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVDeformConv2d"}; + } // namespace + + DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string& name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, + const int group) + : TRTPluginBase(name) + , mStride(stride) + , mPadding(padding) + , mDilation(dilation) + , mDeformableGroup(deformableGroup) + , mGroup(group) + { } - std::string field_name(fc->fields[i].name); - if (field_name.compare("deform_groups") == 0) { - deformableGroup = static_cast(fc->fields[i].data)[0]; + DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mStride); + deserialize_value(&data, &length, &mPadding); + deserialize_value(&data, &length, &mDilation); + deserialize_value(&data, &length, &mDeformableGroup); + deserialize_value(&data, &length, &mGroup); + } + DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {} + + nvinfer1::IPluginV2DynamicExt* DeformableConvPluginDynamic::clone() const TRT_NOEXCEPT + { + DeformableConvPluginDynamic* plugin = new DeformableConvPluginDynamic( + mLayerName, + mStride, + mPadding, + mDilation, + mDeformableGroup, + mGroup); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + // input[0] == input + // input[1] == offset + // input[2] == weight + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[2].d[0]; + + ret.d[2] = inputs[1].d[2]; + ret.d[3] = inputs[1].d[3]; + + return ret; + } + + bool DeformableConvPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || + ioDesc[pos].type == nvinfer1::DataType::kHALF) && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + void DeformableConvPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + + size_t DeformableConvPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); + + int batch_size = inputs[0].dims.d[0]; + int nInputPlane = inputs[0].dims.d[1]; + int inputHeight = inputs[0].dims.d[2]; + int inputWidth = inputs[0].dims.d[3]; + + int nOutputPlane = outputs[0].dims.d[1]; + int outputHeight = outputs[0].dims.d[2]; + int outputWidth = outputs[0].dims.d[3]; + + int kW = inputs[2].dims.d[2]; + int kH = inputs[2].dims.d[3]; + int im2col_step = std::min(32, batch_size); + + size_t col_size = mmdeploy::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight * + outputWidth * sizeof_dtype); + + size_t out_size = 0; + if (im2col_step != 
1) + out_size = mmdeploy::getAlignedSize(batch_size * nOutputPlane * outputHeight * outputWidth * + sizeof_dtype); + + return col_size + out_size; + } + + int DeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + int channels_out = outputDesc[0].dims.d[1]; + int kernel_h = inputDesc[2].dims.d[2]; + int kernel_w = inputDesc[2].dims.d[3]; + + const void* x = inputs[0]; + const void* offset = inputs[1]; + const void* weight = inputs[2]; + void* output = outputs[0]; + int im2col_step = std::min(batch, 32); + + auto data_type = inputDesc[0].type; + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + deform_conv((float*)x, + (float*)weight, + (float*)offset, + (float*)output, + workSpace, + batch, + channels, + height, + width, + channels_out, + kernel_w, + kernel_h, + mStride.d[0], + mStride.d[1], + mPadding.d[0], + mPadding.d[1], + mDilation.d[0], + mDilation.d[1], + mGroup, + mDeformableGroup, + im2col_step, + m_cublas_handle, + stream); + break; + case nvinfer1::DataType::kHALF: + deform_conv((half*)x, + (half*)weight, + (half*)offset, + (half*)output, + workSpace, + batch, + channels, + height, + width, + channels_out, + kernel_w, + kernel_h, + mStride.d[0], + mStride.d[1], + mPadding.d[0], + mPadding.d[1], + mDilation.d[0], + mDilation.d[1], + mGroup, + mDeformableGroup, + im2col_step, + m_cublas_handle, + stream); + break; + default: + return 1; + } + + return 0; + } + + nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* DeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* DeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int DeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t DeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + + serialized_size(mDeformableGroup) + serialized_size(mGroup); + } + + void DeformableConvPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mStride); + serialize_value(&buffer, mPadding); + serialize_value(&buffer, mDilation); + serialize_value(&buffer, mDeformableGroup); + serialize_value(&buffer, mGroup); + } + + void DeformableConvPluginDynamic::attachToContext( + cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT + { + m_cublas_handle = cublasContext; + } + + void DeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} + + ////////////////////// creator ///////////////////////////// + + DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); + 
mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); } - if (field_name.compare("groups") == 0) { - group = static_cast(fc->fields[i].data)[0]; + const char* DeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; } - if (field_name.compare("stride") == 0) { - stride.nbDims = 2; - stride.d[0] = static_cast(fc->fields[i].data)[0]; - stride.d[1] = static_cast(fc->fields[i].data)[1]; + const char* DeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; } - if (field_name.compare("padding") == 0) { - padding.nbDims = 2; - padding.d[0] = static_cast(fc->fields[i].data)[0]; - padding.d[1] = static_cast(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* DeformableConvPluginDynamicCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + nvinfer1::Dims stride{2, {1, 1}}; + nvinfer1::Dims padding{2, {0, 0}}; + nvinfer1::Dims dilation{2, {1, 1}}; + int deformableGroup = 1; + int group = 1; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("deform_groups") == 0) + { + deformableGroup = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("groups") == 0) + { + group = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("stride") == 0) + { + stride.nbDims = 2; + stride.d[0] = static_cast(fc->fields[i].data)[0]; + stride.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("padding") == 0) + { + padding.nbDims = 2; + padding.d[0] = static_cast(fc->fields[i].data)[0]; + padding.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("dilation") == 0) + { + dilation.nbDims = 2; + dilation.d[0] = static_cast(fc->fields[i].data)[0]; + dilation.d[1] = static_cast(fc->fields[i].data)[1]; + } + } + + DeformableConvPluginDynamic* plugin = + new DeformableConvPluginDynamic(name, stride, padding, dilation, deformableGroup, group); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - if (field_name.compare("dilation") == 0) { - dilation.nbDims = 2; - dilation.d[0] = static_cast(fc->fields[i].data)[0]; - dilation.d[1] = static_cast(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* DeformableConvPluginDynamicCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - } - - DeformableConvPluginDynamic *plugin = - new DeformableConvPluginDynamic(name, stride, padding, dilation, deformableGroup, group); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); + REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp 
index 3ea0ccbefe..09845327ca 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp @@ -9,73 +9,99 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class DeformableConvPluginDynamic : public TRTPluginBase { - public: - DeformableConvPluginDynamic(const std::string &name, const nvinfer1::Dims stride, - const nvinfer1::Dims padding, const nvinfer1::Dims dilation, - const int deformableGroup, const int group); - - DeformableConvPluginDynamic(const std::string name, const void *data, size_t length); - - DeformableConvPluginDynamic() = delete; - - ~DeformableConvPluginDynamic() TRT_NOEXCEPT override; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - - private: - nvinfer1::Dims mStride; - nvinfer1::Dims mPadding; - nvinfer1::Dims mDilation; - int mDeformableGroup; - int mGroup; - - cublasHandle_t m_cublas_handle; -}; - -class DeformableConvPluginDynamicCreator : public TRTPluginCreatorBase { - public: - DeformableConvPluginDynamicCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class DeformableConvPluginDynamic : public TRTPluginBase + { + public: + DeformableConvPluginDynamic(const std::string& name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, + const int group); + + DeformableConvPluginDynamic(const std::string name, + const void* data, + size_t length); + + DeformableConvPluginDynamic() = delete; + + ~DeformableConvPluginDynamic() TRT_NOEXCEPT override; + + // 
IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + void attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + private: + nvinfer1::Dims mStride; + nvinfer1::Dims mPadding; + nvinfer1::Dims mDilation; + int mDeformableGroup; + int mGroup; + + cublasHandle_t m_cublas_handle; + }; + + class DeformableConvPluginDynamicCreator : public TRTPluginCreatorBase + { + public: + DeformableConvPluginDynamicCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_DEFORM_CONV_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu index 3f401fc9e2..e62bdb0a48 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu @@ -68,105 +68,228 @@ #include "trt_deform_conv_kernel.hpp" #include "trt_plugin_helper.hpp" -template -void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column, - const int channels, const int height, const int width, const int ksize_h, - const int ksize_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int parallel_imgs, const int deformable_group, cudaStream_t stream) { - int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; - int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; - int num_kernels = channels * height_col * width_col * parallel_imgs; - int channel_per_deformable_group = 
channels / deformable_group;
-
-  deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
-      num_kernels, input, offset, height, width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
-      dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs, channels,
-      deformable_group, height_col, width_col, column);
-
-  cudaCheckError();
+template<typename scalar_t>
+void deform_conv_im2col(const scalar_t* input,
+                        const scalar_t* offset,
+                        scalar_t* column,
+                        const int channels,
+                        const int height,
+                        const int width,
+                        const int ksize_h,
+                        const int ksize_w,
+                        const int pad_h,
+                        const int pad_w,
+                        const int stride_h,
+                        const int stride_w,
+                        const int dilation_h,
+                        const int dilation_w,
+                        const int parallel_imgs,
+                        const int deformable_group,
+                        cudaStream_t stream)
+{
+    int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+    int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+    int num_kernels = channels * height_col * width_col * parallel_imgs;
+    int channel_per_deformable_group = channels / deformable_group;
+
+    deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
+        num_kernels,
+        input,
+        offset,
+        height,
+        width,
+        ksize_h,
+        ksize_w,
+        pad_h,
+        pad_w,
+        stride_h,
+        stride_w,
+        dilation_h,
+        dilation_w,
+        channel_per_deformable_group,
+        parallel_imgs,
+        channels,
+        deformable_group,
+        height_col,
+        width_col,
+        column);
+
+    cudaCheckError();
 }

-template <typename scalar_t>
-void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset,
-                 scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight,
-                 int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW,
-                 int padH, int dilationW, int dilationH, int group, int deformable_group,
-                 int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) {
-  size_t word_size = sizeof(scalar_t);
-
-  im2col_step = std::min(int(batchSize), im2col_step);
-  long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
-  long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
-
-  long outputHW = outputHeight * outputWidth;
-  long kHW = kH * kW;
-  long columns_size =
-      mmdeploy::getAlignedSize(nInputPlane * kHW * im2col_step * outputHW * word_size);
-
-  // column buffer for img2col
-  char* workspace_ptr = reinterpret_cast<char*>(workspace);
-  scalar_t* columns = reinterpret_cast<scalar_t*>(workspace_ptr);
-  workspace_ptr = workspace_ptr + columns_size;
-
-  scalar_t* output_buffer;
-  if (im2col_step == 1) {
-    output_buffer = output;
-  } else {
-    // output need permute when im2col_step!=1
-    output_buffer = reinterpret_cast<scalar_t*>(workspace_ptr);
-  }
-
-  long input_elt_step = im2col_step * nInputPlane * inputHeight * inputWidth;
-  long offset_elt_step = im2col_step * deformable_group * 2 * kHW * outputHW;
-  long out_buffer_step = nOutputPlane * im2col_step * outputHW;
-  long col_g_step = nInputPlane * kHW * im2col_step * outputHW / group;
-  long weight_g_step = nOutputPlane * nInputPlane * kHW / (group * group);
-  long out_buffer_g_step = out_buffer_step / group;
-  int m = nOutputPlane / group;
-  int n = im2col_step * outputHW;
-  int k = nInputPlane * kHW / group;
-  scalar_t alpha = 1.f;
-  scalar_t beta = 0.f;
-
-  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
-    const scalar_t* input_start = input + elt * input_elt_step;
-    const scalar_t* offset_start = offset + elt * offset_elt_step;
-
-    deform_conv_im2col<scalar_t>(input_start, offset_start, columns, nInputPlane, inputHeight,
-                                 inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
-                                 im2col_step, deformable_group, 
stream); - - for (int g = 0; g < group; ++g) { - const scalar_t* weight_start = weight + g * weight_g_step; - scalar_t* col_start = columns + g * col_g_step; - scalar_t* out_buffer_start = output_buffer + elt * out_buffer_step + g * out_buffer_g_step; - - cublasGemmWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, col_start, - n, weight_start, k, &beta, out_buffer_start, n); - cudaCheckError(); +template +void deform_conv(const scalar_t* input, + const scalar_t* weight, + const scalar_t* offset, + scalar_t* output, + void* workspace, + int batchSize, + int nInputPlane, + int inputHeight, + int inputWidth, + int nOutputPlane, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream) +{ + size_t word_size = sizeof(scalar_t); + + im2col_step = std::min(int(batchSize), im2col_step); + long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + long outputHW = outputHeight * outputWidth; + long kHW = kH * kW; + long columns_size = mmdeploy::getAlignedSize(nInputPlane * kHW * im2col_step * outputHW * word_size); + + // column buffer for img2col + char* workspace_ptr = reinterpret_cast(workspace); + scalar_t* columns = reinterpret_cast(workspace_ptr); + workspace_ptr = workspace_ptr + columns_size; + + scalar_t* output_buffer; + if (im2col_step == 1) + { + output_buffer = output; + } + else + { + // output need permute when im2col_step!=1 + output_buffer = reinterpret_cast(workspace_ptr); + } + + long input_elt_step = im2col_step * nInputPlane * inputHeight * inputWidth; + long offset_elt_step = im2col_step * deformable_group * 2 * kHW * outputHW; + long out_buffer_step = nOutputPlane * im2col_step * outputHW; + long col_g_step = nInputPlane * kHW * im2col_step * outputHW / group; + long weight_g_step = nOutputPlane * nInputPlane * kHW / (group * group); + long out_buffer_g_step = out_buffer_step / group; + int m = nOutputPlane / group; + int n = im2col_step * outputHW; + int k = nInputPlane * kHW / group; + scalar_t alpha = 1.f; + scalar_t beta = 0.f; + + for (int elt = 0; elt < batchSize / im2col_step; elt++) + { + const scalar_t* input_start = input + elt * input_elt_step; + const scalar_t* offset_start = offset + elt * offset_elt_step; + + deform_conv_im2col(input_start, + offset_start, + columns, + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + stream); + + for (int g = 0; g < group; ++g) + { + const scalar_t* weight_start = weight + g * weight_g_step; + scalar_t* col_start = columns + g * col_g_step; + scalar_t* out_buffer_start = output_buffer + elt * out_buffer_step + g * out_buffer_g_step; + + cublasGemmWrap(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + n, + m, + k, + &alpha, + col_start, + n, + weight_start, + k, + &beta, + out_buffer_start, + n); + cudaCheckError(); + } + } + + if (im2col_step != 1) + { + int output_buffer_shape[5] = {batchSize / im2col_step, + nOutputPlane, + im2col_step, + static_cast(outputHeight), + static_cast(outputWidth)}; + int output_buffer_permute[5] = {0, 2, 1, 3, 4}; + memcpyPermute(output, + output_buffer, + &output_buffer_shape[0], + &output_buffer_permute[0], + 5, + stream); } - } - - if (im2col_step != 1) { - int output_buffer_shape[5] = {batchSize / im2col_step, nOutputPlane, 
im2col_step, - static_cast(outputHeight), static_cast(outputWidth)}; - int output_buffer_permute[5] = {0, 2, 1, 3, 4}; - memcpyPermute(output, output_buffer, &output_buffer_shape[0], - &output_buffer_permute[0], 5, stream); - } } -template void deform_conv(const float* input, const float* weight, const float* offset, - float* output, void* workspace, int batchSize, int nInputPlane, - int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, - int dW, int dH, int padW, int padH, int dilationW, int dilationH, - int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); - -template void deform_conv<__half>(const __half* input, const __half* weight, const __half* offset, - __half* output, void* workspace, int batchSize, int nInputPlane, - int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, - int dW, int dH, int padW, int padH, int dilationW, int dilationH, - int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); +template void deform_conv(const float* input, + const float* weight, + const float* offset, + float* output, + void* workspace, + int batchSize, + int nInputPlane, + int inputHeight, + int inputWidth, + int nOutputPlane, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream); + +template void deform_conv<__half>(const __half* input, + const __half* weight, + const __half* offset, + __half* output, + void* workspace, + int batchSize, + int nInputPlane, + int inputHeight, + int inputWidth, + int nOutputPlane, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh index c91f17ca4a..85e675bf9c 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh +++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh @@ -67,108 +67,133 @@ #include "common_cuda_helper.hpp" -template +template __device__ __forceinline__ scalar_t deformable_im2col_bilinear(const scalar_t* __restrict__ input, - const int height, const int width, - float h, float w) { - if (h <= -1 || height <= h || w <= -1 || width <= w) { - return 0; - } + const int height, + const int width, + float h, + float w) +{ + if (h <= -1 || height <= h || w <= -1 || width <= w) + { + return 0; + } - const int h_low = floorf(h); - const int w_low = floorf(w); + const int h_low = floorf(h); + const int w_low = floorf(w); - input += h_low * width; - const scalar_t v1 = (h_low >= 0 && w_low >= 0) ? input[w_low] : static_cast(0.0f); - const int w_high = w_low + 1; - const scalar_t v2 = - (h_low >= 0 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); - const scalar_t lw = w - w_low; - const scalar_t v_low = fmaf(v2 - v1, lw, v1); - input += width; - const scalar_t v3 = - (h_low <= height - 2 && w_low >= 0) ? input[w_low] : static_cast(0.0f); - const scalar_t v4 = - (h_low <= height - 2 && w_high <= width - 1) ? 
input[w_high] : static_cast(0.0f); - const scalar_t v_high = fmaf(v4 - v3, lw, v3); - const scalar_t lh = h - h_low; - const scalar_t val = fmaf(v_high - v_low, lh, v_low); - return val; + input += h_low * width; + const scalar_t v1 = (h_low >= 0 && w_low >= 0) ? input[w_low] : static_cast(0.0f); + const int w_high = w_low + 1; + const scalar_t v2 = + (h_low >= 0 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); + const scalar_t lw = w - w_low; + const scalar_t v_low = fmaf(v2 - v1, lw, v1); + input += width; + const scalar_t v3 = + (h_low <= height - 2 && w_low >= 0) ? input[w_low] : static_cast(0.0f); + const scalar_t v4 = + (h_low <= height - 2 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); + const scalar_t v_high = fmaf(v4 - v3, lw, v3); + const scalar_t lh = h - h_low; + const scalar_t val = fmaf(v_high - v_low, lh, v_low); + return val; } -template <> +template<> __device__ __forceinline__ __half deformable_im2col_bilinear(const __half* __restrict__ input, - const int height, const int width, - float h, float w) { - if (h <= -1 || height <= h || w <= -1 || width <= w) { - return 0; - } + const int height, + const int width, + float h, + float w) +{ + if (h <= -1 || height <= h || w <= -1 || width <= w) + { + return 0; + } - const int h_low = floorf(h); - const int w_low = floorf(w); + const int h_low = floorf(h); + const int w_low = floorf(w); - input += h_low * width; - const float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; - const int w_high = w_low + 1; - const float v2 = (h_low >= 0 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f; - const float lw = w - w_low; - const float v_low = fmaf(v2 - v1, lw, v1); - input += width; - const float v3 = (h_low <= height - 2 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; - const float v4 = - (h_low <= height - 2 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f; - const float v_high = fmaf(v4 - v3, lw, v3); - const float lh = h - h_low; - const float val = fmaf(v_high - v_low, lh, v_low); - return __float2half(val); + input += h_low * width; + const float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; + const int w_high = w_low + 1; + const float v2 = (h_low >= 0 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f; + const float lw = w - w_low; + const float v_low = fmaf(v2 - v1, lw, v1); + input += width; + const float v3 = (h_low <= height - 2 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; + const float v4 = + (h_low <= height - 2 && w_high <= width - 1) ? 
__half2float(input[w_high]) : 0.0f; + const float v_high = fmaf(v4 - v3, lw, v3); + const float lh = h - h_low; + const float val = fmaf(v_high - v_low, lh, v_low); + return __float2half(val); } -template -__global__ void deformable_im2col_gpu_kernel( - const int n, const scalar_t* __restrict__ data_im, const scalar_t* __restrict__ data_offset, - const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, - const int pad_w, const int stride_h, const int stride_w, const int dilation_h, - const int dilation_w, const int channel_per_deformable_group, const int batch_size, - const int num_channels, const int deformable_group, const int height_col, const int width_col, - scalar_t* __restrict__ data_col) { - const int hw_col = height_col * width_col; - const int data_col_step = batch_size * hw_col; +template +__global__ void deformable_im2col_gpu_kernel(const int n, + const scalar_t* __restrict__ data_im, + const scalar_t* __restrict__ data_offset, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* __restrict__ data_col) +{ + const int hw_col = height_col * width_col; + const int data_col_step = batch_size * hw_col; - CUDA_1D_KERNEL_LOOP(index, n) { - // index index of output matrix - int tmp_index = index; - const int w_col = tmp_index % width_col; - tmp_index /= width_col; - const int h_col = tmp_index % height_col; - tmp_index /= height_col; - const int b_col = tmp_index % batch_size; - const int c_im = tmp_index / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + CUDA_1D_KERNEL_LOOP(index, n) + { + // index index of output matrix + int tmp_index = index; + const int w_col = tmp_index % width_col; + tmp_index /= width_col; + const int h_col = tmp_index % height_col; + tmp_index /= height_col; + const int b_col = tmp_index % batch_size; + const int c_im = tmp_index / batch_size; + const int c_col = c_im * kernel_h * kernel_w; - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - scalar_t* __restrict__ data_col_ptr = data_col + c_col * data_col_step + index % data_col_step; - const scalar_t* __restrict__ data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * width; - const scalar_t* __restrict__ data_offset_ptr = - data_offset + - ((b_col * deformable_group + deformable_group_index) << 1) * kernel_h * kernel_w * hw_col + - h_col * width_col + w_col; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h = (i * kernel_w + j) * hw_col << 1; - const scalar_t offset_h = data_offset_ptr[data_offset_h]; - const int data_offset_w = data_offset_h + hw_col; - const scalar_t offset_w = data_offset_ptr[data_offset_w]; - const scalar_t h_im = h_in + i * dilation_h + (float)offset_h; - const scalar_t w_im = w_in + j * dilation_w + (float)offset_w; - const scalar_t val = deformable_im2col_bilinear(data_im_ptr, height, width, h_im, w_im); - *data_col_ptr = val; - data_col_ptr += data_col_step; - } + const int h_in = 
h_col * stride_h - pad_h;
+        const int w_in = w_col * stride_w - pad_w;
+        scalar_t* __restrict__ data_col_ptr = data_col + c_col * data_col_step + index % data_col_step;
+        const scalar_t* __restrict__ data_im_ptr =
+            data_im + (b_col * num_channels + c_im) * height * width;
+        const scalar_t* __restrict__ data_offset_ptr =
+            data_offset +
+            ((b_col * deformable_group + deformable_group_index) << 1) * kernel_h * kernel_w * hw_col +
+            h_col * width_col + w_col;
+        for (int i = 0; i < kernel_h; ++i)
+        {
+            for (int j = 0; j < kernel_w; ++j)
+            {
+                const int data_offset_h = (i * kernel_w + j) * hw_col << 1;
+                const scalar_t offset_h = data_offset_ptr[data_offset_h];
+                const int data_offset_w = data_offset_h + hw_col;
+                const scalar_t offset_w = data_offset_ptr[data_offset_w];
+                const scalar_t h_im = h_in + i * dilation_h + (float)offset_h;
+                const scalar_t w_im = w_in + j * dilation_w + (float)offset_w;
+                const scalar_t val = deformable_im2col_bilinear(data_im_ptr, height, width, h_im, w_im);
+                *data_col_ptr = val;
+                data_col_ptr += data_col_step;
+            }
+        }
     }
-  }
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp
index 3d8f6dfc45..012dc894f8 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp
@@ -4,17 +4,47 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>

-template <typename scalar_t>
-void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column,
-                        const int channels, const int height, const int width, const int ksize_h,
-                        const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
-                        const int stride_w, const int dilation_h, const int dilation_w,
-                        const int parallel_imgs, const int deformable_group, cudaStream_t stream);
+template<typename scalar_t>
+void deform_conv_im2col(const scalar_t* input,
+                        const scalar_t* offset,
+                        scalar_t* column,
+                        const int channels,
+                        const int height,
+                        const int width,
+                        const int ksize_h,
+                        const int ksize_w,
+                        const int pad_h,
+                        const int pad_w,
+                        const int stride_h,
+                        const int stride_w,
+                        const int dilation_h,
+                        const int dilation_w,
+                        const int parallel_imgs,
+                        const int deformable_group,
+                        cudaStream_t stream);

-template <typename scalar_t>
-void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset,
-                 scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight,
-                 int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW,
-                 int padH, int dilationW, int dilationH, int group, int deformable_group,
-                 int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream);
+template<typename scalar_t>
+void deform_conv(const scalar_t* input,
+                 const scalar_t* weight,
+                 const scalar_t* offset,
+                 scalar_t* output,
+                 void* workspace,
+                 int batchSize,
+                 int nInputPlane,
+                 int inputHeight,
+                 int inputWidth,
+                 int nOutputPlane,
+                 int kW,
+                 int kH,
+                 int dW,
+                 int dH,
+                 int padW,
+                 int padH,
+                 int dilationW,
+                 int dilationH,
+                 int group,
+                 int deformable_group,
+                 int im2col_step,
+                 cublasHandle_t cublas_handle,
+                 cudaStream_t stream);

 #endif  // TRT_DEFORM_CONV_KERNEL_HPP
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp
index b5e6c0b677..2de48da10b 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp
@@ -10,141 +10,203 @@
 #include "gather_topk_kernel.hpp"
 #include 
"trt_serialize.hpp" -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"GatherTopk"}; -} // namespace - -GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {} - -GatherTopk::GatherTopk(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) {} - -nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT { - GatherTopk *plugin = new GatherTopk(mLayerName); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs GatherTopk::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - assert(inputs[0].nbDims >= inputs[1].nbDims); - nvinfer1::DimsExprs ret; - ret.nbDims = inputs[0].nbDims; - for (int i = 0; i < inputs[1].nbDims; ++i) { - ret.d[i] = inputs[1].d[i]; - } - for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) { - ret.d[i] = inputs[0].d[i]; - } - return ret; -} - -bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - switch (pos) { - case 0: - // data - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || - (ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - case 1: - // indices - return ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - case 2: - // output - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - default: - return true; - } - return true; -} - -void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - const int *dims = &(inputDesc[0].dims.d[0]); - const int *indices_dims = &(inputDesc[1].dims.d[0]); - int nbDims = inputDesc[0].dims.nbDims; - int indice_nbDims = inputDesc[1].dims.nbDims; - - const void *data = inputs[0]; - const void *indices = inputs[1]; - void *output = outputs[0]; - - auto data_type = inputDesc[0].type; - - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - gather_topk_impl((float *)data, (int *)indices, dims, nbDims, indices_dims, - indice_nbDims, (float *)output, stream); - break; - - case nvinfer1::DataType::kINT32: - gather_topk_impl((int *)data, (int *)indices, dims, nbDims, indices_dims, indice_nbDims, - (int *)output, stream); - break; - default: - break; - } - - return 0; -} - -nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { 
return 0; } - -void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {} - -GatherTopkCreator::GatherTopkCreator() { - mPluginAttributes.clear(); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - auto *plugin = new GatherTopk(name); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - auto plugin = new GatherTopk(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(GatherTopkCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"GatherTopk"}; + } // namespace + + GatherTopk::GatherTopk(const std::string& name) + : TRTPluginBase(name) + { + } + + GatherTopk::GatherTopk(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + } + + nvinfer1::IPluginV2DynamicExt* GatherTopk::clone() const TRT_NOEXCEPT + { + GatherTopk* plugin = new GatherTopk(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs GatherTopk::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + assert(inputs[0].nbDims >= inputs[1].nbDims); + nvinfer1::DimsExprs ret; + ret.nbDims = inputs[0].nbDims; + for (int i = 0; i < inputs[1].nbDims; ++i) + { + ret.d[i] = inputs[1].d[i]; + } + for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) + { + ret.d[i] = inputs[0].d[i]; + } + return ret; + } + + bool GatherTopk::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + switch (pos) + { + case 0: + // data + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || + (ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + case 1: + // indices + return ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + case 2: + // output + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + default: + return true; + } + return true; + } + + void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + + size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + const int* dims = &(inputDesc[0].dims.d[0]); + const int* indices_dims = &(inputDesc[1].dims.d[0]); + int nbDims = inputDesc[0].dims.nbDims; + int indice_nbDims = 
inputDesc[1].dims.nbDims; + + const void* data = inputs[0]; + const void* indices = inputs[1]; + void* output = outputs[0]; + + auto data_type = inputDesc[0].type; + + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + gather_topk_impl((float*)data, + (int*)indices, + dims, + nbDims, + indices_dims, + indice_nbDims, + (float*)output, + stream); + break; + + case nvinfer1::DataType::kINT32: + gather_topk_impl((int*)data, + (int*)indices, + dims, + nbDims, + indices_dims, + indice_nbDims, + (int*)output, + stream); + break; + default: + break; + } + + return 0; + } + + nvinfer1::DataType GatherTopk::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* GatherTopk::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GatherTopk::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int GatherTopk::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT + { + return 0; + } + + void GatherTopk::serialize(void* buffer) const TRT_NOEXCEPT {} + + GatherTopkCreator::GatherTopkCreator() + { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* GatherTopkCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* GatherTopkCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + auto* plugin = new GatherTopk(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* GatherTopkCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new GatherTopk(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(GatherTopkCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp index 72f76a2678..d1a0df29e3 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp @@ -9,56 +9,75 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class GatherTopk : public TRTPluginBase { - public: - GatherTopk(const std::string &name); - - GatherTopk(const std::string name, const void *data, size_t length); - - GatherTopk() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc 
*inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; -}; - -class GatherTopkCreator : public TRTPluginCreatorBase { - public: - GatherTopkCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class GatherTopk : public TRTPluginBase + { + public: + GatherTopk(const std::string& name); + + GatherTopk(const std::string name, const void* data, size_t length); + + GatherTopk() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + }; + + class GatherTopkCreator : public TRTPluginCreatorBase + { + public: + GatherTopkCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_SCATTERND_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu index 9a5c8ec963..3c1663d499 
100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu
@@ -8,39 +8,67 @@
 #include "gather_topk_kernel.hpp"
 #include "trt_plugin_helper.hpp"

-template <typename scalar_t>
-__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output,
-                                   int batch, int num_input, int num_indices, int channel) {
-  CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) {
-    const int b_id = index / (num_indices * channel);
-    const int n_id = (index / channel) % num_indices;
-    const int c_id = index % channel;
+template<typename scalar_t>
+__global__ void gather_topk_kernel(const scalar_t* input,
+                                   const int* indices,
+                                   scalar_t* output,
+                                   int batch,
+                                   int num_input,
+                                   int num_indices,
+                                   int channel)
+{
+    CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel)
+    {
+        const int b_id = index / (num_indices * channel);
+        const int n_id = (index / channel) % num_indices;
+        const int c_id = index % channel;

-    const int input_n_id = indices[b_id * num_indices + n_id];
-    const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
-    output[b_id * num_indices * channel + n_id * channel + c_id] = value;
-  }
+        const int input_n_id = indices[b_id * num_indices + n_id];
+        const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
+        output[b_id * num_indices * channel + n_id * channel + c_id] = value;
+    }
 }

-template <typename scalar_t>
-void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
-                      const int* indices_dims, int indice_nbDims, scalar_t* output,
-                      cudaStream_t stream) {
-  int batch = 1;
-  for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
-  int num_input = dims[indice_nbDims - 1];
-  int num_indices = indices_dims[indice_nbDims - 1];
-  int channel = 1;
-  for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
-  const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
-  gather_topk_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input, indices, output, batch,
-                                                                            num_input, num_indices, channel);
+template<typename scalar_t>
+void gather_topk_impl(const scalar_t* input,
+                      const int* indices,
+                      const int* dims,
+                      int nbDims,
+                      const int* indices_dims,
+                      int indice_nbDims,
+                      scalar_t* output,
+                      cudaStream_t stream)
+{
+    int batch = 1;
+    for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
+    int num_input = dims[indice_nbDims - 1];
+    int num_indices = indices_dims[indice_nbDims - 1];
+    int channel = 1;
+    for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
+    const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
+    gather_topk_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input,
+                                                                              indices,
+                                                                              output,
+                                                                              batch,
+                                                                              num_input,
+                                                                              num_indices,
+                                                                              channel);
 }

-template void gather_topk_impl<float>(const float* input, const int* indices, const int* dims,
-                                      int nbDims, const int* indices_dims, int indice_nbDims,
-                                      float* output, cudaStream_t stream);
+template void gather_topk_impl<float>(const float* input,
+                                      const int* indices,
+                                      const int* dims,
+                                      int nbDims,
+                                      const int* indices_dims,
+                                      int indice_nbDims,
+                                      float* output,
+                                      cudaStream_t stream);

-template void gather_topk_impl<int32_t>(const int32_t* input, const int* indices, const int* dims,
-                                        int nbDims, const int* indices_dims, int indice_nbDims,
-                                        int32_t* output, cudaStream_t stream);
+template void gather_topk_impl<int32_t>(const int32_t* input,
+                                        const int* indices,
+                                        const int* dims,
+                                        int nbDims,
+                                        const int* indices_dims,
+                                        int indice_nbDims,
+                                        int32_t* output,
+                                        cudaStream_t stream);
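To see how gather_topk_impl flattens arbitrary ranks into a batch/num_input/channel layout, consider data of shape [2, 1000, 4] gathered with indices of shape [2, 100] (nbDims = 3, indice_nbDims = 2). The shapes here are hypothetical; the sketch mirrors the loop logic above and is not part of the patch:

    #include <cstdio>

    int main()
    {
        const int dims[]         = {2, 1000, 4};  // hypothetical data shape
        const int indices_dims[] = {2, 100};      // hypothetical indices shape
        const int nbDims = 3, indice_nbDims = 2;

        int batch = 1;
        for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];      // -> 2
        const int num_input   = dims[indice_nbDims - 1];                   // -> 1000
        const int num_indices = indices_dims[indice_nbDims - 1];           // -> 100
        int channel = 1;
        for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];   // -> 4
        std::printf("batch=%d num_input=%d num_indices=%d channel=%d\n",
                    batch, num_input, num_indices, channel);
        return 0;
    }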
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp
index 1f9b428394..0c5c7e6011 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp
@@ -3,8 +3,13 @@
 #define TRT_GRID_SAMPLER_KERNEL_HPP
 #include <cuda_runtime.h>

-template <typename scalar_t>
-void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
-                      const int* indices_dims, int indice_nbDims, scalar_t* output,
-                      cudaStream_t stream);
+template<typename scalar_t>
+void gather_topk_impl(const scalar_t* input,
+                      const int* indices,
+                      const int* dims,
+                      int nbDims,
+                      const int* indices_dims,
+                      int indice_nbDims,
+                      scalar_t* output,
+                      cudaStream_t stream);
 #endif  // TRT_GRID_SAMPLER_KERNEL_HPP
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp
index 1850fbfc1a..761b61538b 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp
@@ -10,145 +10,202 @@
 using namespace nvinfer1;

-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"GridPriorsTRT"};
-}  // namespace
-
-GridPriorsTRT::GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride)
-    : TRTPluginBase(name), mStride(stride) {}
-
-GridPriorsTRT::GridPriorsTRT(const std::string name, const void *data, size_t length)
-    : TRTPluginBase(name) {
-  deserialize_value(&data, &length, &mStride);
-}
-GridPriorsTRT::~GridPriorsTRT() {}
-
-nvinfer1::IPluginV2DynamicExt *GridPriorsTRT::clone() const TRT_NOEXCEPT {
-  GridPriorsTRT *plugin = new GridPriorsTRT(mLayerName, mStride);
-  plugin->setPluginNamespace(getPluginNamespace());
-
-  return plugin;
-}
-
-nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions(
-    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
-    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
-  // input[0] == base_anchor
-  // input[1] == empty_h
-  // input[2] == empty_w
-
-  nvinfer1::DimsExprs ret;
-  ret.nbDims = 2;
-  auto area =
-      exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]);
-  ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *area, *(inputs[0].d[0]));
-  ret.d[1] = exprBuilder.constant(4);
-
-  return ret;
-}
-
-bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
-                                              int nbInputs, int nbOutputs) TRT_NOEXCEPT {
-  if (pos == 0) {
-    return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
-            ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
-  } else if (pos - nbInputs == 0) {
-    return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
-  } else {
-    return true;
-  }
-}
-
-int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
-                           const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
-                           void *const *outputs, void *workSpace,
-                           cudaStream_t stream) TRT_NOEXCEPT {
-  int num_base_anchors = inputDesc[0].dims.d[0];
-  int feat_h = inputDesc[1].dims.d[0];
-  int feat_w = inputDesc[2].dims.d[0];
-
-  const void *base_anchor = inputs[0];
-  void *output = outputs[0];
-
-  auto data_type = inputDesc[0].type;
-  switch (data_type) {
-    case nvinfer1::DataType::kFLOAT:
-      trt_grid_priors_impl<float>((float *)base_anchor, (float *)output, num_base_anchors, feat_w,
-                                  feat_h, mStride.d[0],
mStride.d[1], stream); - break; - default: - return 1; - } - - return 0; -} - -nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *GridPriorsTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mStride); } - -void GridPriorsTRT::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mStride); - ; -} - -////////////////////// creator ///////////////////////////// - -GridPriorsTRTCreator::GridPriorsTRTCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *GridPriorsTRTCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int stride_w = 1; - int stride_h = 1; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("stride_w") == 0) { - stride_w = static_cast(fc->fields[i].data)[0]; - } - if (field_name.compare("stride_h") == 0) { - stride_h = static_cast(fc->fields[i].data)[0]; - } - } - nvinfer1::Dims stride{2, {stride_w, stride_h}}; - - GridPriorsTRT *plugin = new GridPriorsTRT(name, stride); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *GridPriorsTRTCreator::deserializePlugin(const char *name, - const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - auto plugin = new GridPriorsTRT(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"GridPriorsTRT"}; + } // namespace + + GridPriorsTRT::GridPriorsTRT(const std::string& name, const nvinfer1::Dims stride) + : TRTPluginBase(name) + , mStride(stride) + { + } + + GridPriorsTRT::GridPriorsTRT(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mStride); + } + GridPriorsTRT::~GridPriorsTRT() {} + + nvinfer1::IPluginV2DynamicExt* GridPriorsTRT::clone() const TRT_NOEXCEPT + { + GridPriorsTRT* plugin = new GridPriorsTRT(mLayerName, mStride); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + // input[0] == base_anchor + // input[1] == empty_h + // input[2] == empty_w + + nvinfer1::DimsExprs ret; + ret.nbDims = 2; + auto area = + exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]); + ret.d[0] = 
exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *area, *(inputs[0].d[0])); + ret.d[1] = exprBuilder.constant(4); + + return ret; + } + + bool GridPriorsTRT::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else if (pos - nbInputs == 0) + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + else + { + return true; + } + } + + int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int num_base_anchors = inputDesc[0].dims.d[0]; + int feat_h = inputDesc[1].dims.d[0]; + int feat_w = inputDesc[2].dims.d[0]; + + const void* base_anchor = inputs[0]; + void* output = outputs[0]; + + auto data_type = inputDesc[0].type; + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + trt_grid_priors_impl((float*)base_anchor, + (float*)output, + num_base_anchors, + feat_w, + feat_h, + mStride.d[0], + mStride.d[1], + stream); + break; + default: + return 1; + } + + return 0; + } + + nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* GridPriorsTRT::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mStride); + } + + void GridPriorsTRT::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mStride); + ; + } + + ////////////////////// creator ///////////////////////////// + + GridPriorsTRTCreator::GridPriorsTRTCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* GridPriorsTRTCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int stride_w = 1; + int stride_h = 1; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("stride_w") == 0) + { + stride_w = static_cast(fc->fields[i].data)[0]; + } + if (field_name.compare("stride_h") == 0) + { + stride_h = static_cast(fc->fields[i].data)[0]; + } + } + nvinfer1::Dims stride{2, {stride_w, stride_h}}; + + GridPriorsTRT* plugin = new GridPriorsTRT(name, stride); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* GridPriorsTRTCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new GridPriorsTRT(name, serialData, serialLength); 
+ plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp index 0036f62586..8285ba47ab 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp @@ -9,58 +9,72 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class GridPriorsTRT : public TRTPluginBase { - public: - GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride); +namespace mmdeploy +{ + class GridPriorsTRT : public TRTPluginBase + { + public: + GridPriorsTRT(const std::string& name, const nvinfer1::Dims stride); - GridPriorsTRT(const std::string name, const void *data, size_t length); + GridPriorsTRT(const std::string name, const void* data, size_t length); - GridPriorsTRT() = delete; + GridPriorsTRT() = delete; - ~GridPriorsTRT() TRT_NOEXCEPT override; + ~GridPriorsTRT() TRT_NOEXCEPT override; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; - private: - nvinfer1::Dims mStride; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - cublasHandle_t m_cublas_handle; -}; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; -class GridPriorsTRTCreator : public TRTPluginCreatorBase { - public: - GridPriorsTRTCreator(); + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - const char *getPluginName() const 
TRT_NOEXCEPT override;
+    private:
+        nvinfer1::Dims mStride;

-  const char *getPluginVersion() const TRT_NOEXCEPT override;
+        cublasHandle_t m_cublas_handle;
+    };

-  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
-      TRT_NOEXCEPT override;
+    class GridPriorsTRTCreator : public TRTPluginCreatorBase
+    {
+      public:
+        GridPriorsTRTCreator();

-  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
-                                         size_t serialLength) TRT_NOEXCEPT override;
-};
+        const char* getPluginName() const TRT_NOEXCEPT override;
+
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
+
+        nvinfer1::IPluginV2* createPlugin(const char* name,
+                                          const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
+
+        nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                               const void* serialData,
+                                               size_t serialLength) TRT_NOEXCEPT override;
+    };
 }  // namespace mmdeploy

 #endif  // TRT_GRID_PRIORS_HPP
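The priors kernel that follows lays anchors out anchor-fastest, then grid x, then grid y, and each prior is just its base anchor translated by the cell offset. A host-side sketch of that index decomposition (sizes are hypothetical; this mirrors the kernel's arithmetic and is not part of the patch):

    #include <cstdio>

    int main()
    {
        const int num_base_anchors = 3, feat_w = 4, feat_h = 2;  // hypothetical sizes
        const int stride_w = 16, stride_h = 16;

        for (int index = 0; index < num_base_anchors * feat_w * feat_h; ++index)
        {
            const int a = index % num_base_anchors;             // which base anchor
            const int x = (index / num_base_anchors) % feat_w;  // grid column
            const int y = index / (feat_w * num_base_anchors);  // grid row
            std::printf("prior %2d: anchor %d shifted by (%d, %d)\n",
                        index, a, x * stride_w, y * stride_h);
        }
        return 0;
    }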
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu
index 72c33d179a..f6207eecc1 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu
@@ -5,39 +5,64 @@
 #include "trt_grid_priors_kernel.hpp"
 #include "trt_plugin_helper.hpp"

-template <typename scalar_t>
-__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output,
-                                       int num_base_anchors, int feat_w, int feat_h, int stride_w,
-                                       int stride_h) {
-  // load base anchor into shared memory.
-  extern __shared__ scalar_t shared_base_anchor[];
-  for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) {
-    shared_base_anchor[i] = base_anchor[i];
-  }
-  __syncthreads();
+template<typename scalar_t>
+__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor,
+                                       scalar_t* output,
+                                       int num_base_anchors,
+                                       int feat_w,
+                                       int feat_h,
+                                       int stride_w,
+                                       int stride_h)
+{
+    // load base anchor into shared memory.
+    extern __shared__ scalar_t shared_base_anchor[];
+    for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x)
+    {
+        shared_base_anchor[i] = base_anchor[i];
+    }
+    __syncthreads();

-  CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) {
-    const int a_offset = (index % num_base_anchors) << 2;
-    const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w);
-    const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h);
+    CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h)
+    {
+        const int a_offset = (index % num_base_anchors) << 2;
+        const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w);
+        const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h);

-    auto out_start = output + index * 4;
-    out_start[0] = shared_base_anchor[a_offset] + w;
-    out_start[1] = shared_base_anchor[a_offset + 1] + h;
-    out_start[2] = shared_base_anchor[a_offset + 2] + w;
-    out_start[3] = shared_base_anchor[a_offset + 3] + h;
-  }
+        auto out_start = output + index * 4;
+        out_start[0] = shared_base_anchor[a_offset] + w;
+        out_start[1] = shared_base_anchor[a_offset + 1] + h;
+        out_start[2] = shared_base_anchor[a_offset + 2] + w;
+        out_start[3] = shared_base_anchor[a_offset + 3] + h;
+    }
 }

-template <typename scalar_t>
-void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
-                          int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) {
-  trt_grid_priors_kernel<scalar_t><<<DIVUP(num_base_anchors * feat_w * feat_h, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0, stream>>>(
-      base_anchor, output, (int)num_base_anchors, (int)feat_w, (int)feat_h, (int)stride_w,
-      (int)stride_h);
+template<typename scalar_t>
+void trt_grid_priors_impl(const scalar_t* base_anchor,
+                          scalar_t* output,
+                          int num_base_anchors,
+                          int feat_w,
+                          int feat_h,
+                          int stride_w,
+                          int stride_h,
+                          cudaStream_t stream)
+{
+    trt_grid_priors_kernel<scalar_t><<<DIVUP(num_base_anchors * feat_w * feat_h, THREADS_PER_BLOCK), THREADS_PER_BLOCK, 0, stream>>>(base_anchor,
+                                                                                  output,
+                                                                                  (int)num_base_anchors,
+                                                                                  (int)feat_w,
+                                                                                  (int)feat_h,
+                                                                                  (int)stride_w,
+                                                                                  (int)stride_h);
 }

-template void trt_grid_priors_impl<float>(const float* base_anchor, float* output,
-                                          int num_base_anchors, int feat_w, int feat_h,
-                                          int stride_w, int stride_h, cudaStream_t stream);
+template void trt_grid_priors_impl<float>(const float* base_anchor,
+                                          float* output,
+                                          int num_base_anchors,
+                                          int feat_w,
+                                          int feat_h,
+                                          int stride_w,
+                                          int stride_h,
+                                          cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp
index 77cef58c54..5de3690b30 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp
@@ -3,8 +3,14 @@
 #define TRT_GRID_PRIORS_KERNEL_HPP
 #include <cuda_runtime.h>

-template <typename scalar_t>
-void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
-                          int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream);
+template<typename scalar_t>
+void trt_grid_priors_impl(const scalar_t* base_anchor,
+                          scalar_t* output,
+                          int num_base_anchors,
+                          int feat_w,
+                          int feat_h,
+                          int stride_w,
+                          int stride_h,
+                          cudaStream_t stream);
 #endif
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp
index 7e55686902..9894f7f0b4 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp
@@ -9,194 +9,257 @@
 #include "trt_plugin_helper.hpp"
 #include "trt_serialize.hpp"

-namespace mmdeploy {
-namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"grid_sampler"}; -} // namespace - -TRTGridSampler::TRTGridSampler(const std::string &name, int mode, int paddingMode, - bool alignCorners) - : TRTPluginBase(name), mMode(mode), mPaddingMode(paddingMode), mAlignCorners(alignCorners) {} - -TRTGridSampler::TRTGridSampler(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mMode); - deserialize_value(&data, &length, &mPaddingMode); - deserialize_value(&data, &length, &mAlignCorners); -} - -nvinfer1::IPluginV2DynamicExt *TRTGridSampler::clone() const TRT_NOEXCEPT { - TRTGridSampler *plugin = new TRTGridSampler(mLayerName, mMode, mPaddingMode, mAlignCorners); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs ret; - ret.nbDims = inputs[0].nbDims; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[0].d[1]; - for (int i = 2; i < ret.nbDims; ++i) { - ret.d[i] = inputs[1].d[i - 1]; - } - return ret; -} - -bool TRTGridSampler::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -void TRTGridSampler::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments -} - -size_t TRTGridSampler::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTGridSampler::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - nvinfer1::Dims input_dims = inputDesc[0].dims; - nvinfer1::Dims grid_dims = inputDesc[1].dims; - nvinfer1::Dims output_dims = outputDesc[0].dims; - - GridSamplerInterpolation interp_mode = GridSamplerInterpolation::Bilinear; - switch (mMode) { - case 0: - interp_mode = GridSamplerInterpolation::Bilinear; - break; - case 1: - interp_mode = GridSamplerInterpolation::Nearest; - break; - default: - break; - } - - GridSamplerPadding padding_mode = GridSamplerPadding::Zeros; - switch (mPaddingMode) { - case 0: - padding_mode = GridSamplerPadding::Zeros; - break; - - case 1: - padding_mode = GridSamplerPadding::Border; - break; - - case 2: - padding_mode = GridSamplerPadding::Reflection; - break; - default: - break; - } - - auto data_type = inputDesc[0].type; - - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - grid_sample((float *)outputs[0], (float *)inputs[0], (float *)inputs[1], - &(output_dims.d[0]), &(input_dims.d[0]), &(grid_dims.d[0]), - input_dims.nbDims, interp_mode, padding_mode, mAlignCorners, stream); - break; - default: - return 1; - break; - } - - return 0; -} - -nvinfer1::DataType TRTGridSampler::getOutputDataType(int index, - const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 
Methods -const char *TRTGridSampler::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTGridSampler::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int TRTGridSampler::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTGridSampler::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mMode) + serialized_size(mPaddingMode) + serialized_size(mAlignCorners); -} - -void TRTGridSampler::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mMode); - serialize_value(&buffer, mPaddingMode); - serialize_value(&buffer, mAlignCorners); -} - -////////////////////// creator ///////////////////////////// - -TRTGridSamplerCreator::TRTGridSamplerCreator() { - mPluginAttributes = std::vector( - {nvinfer1::PluginField("interpolation_mode"), nvinfer1::PluginField("padding_mode"), - nvinfer1::PluginField("align_corners")}); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTGridSamplerCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTGridSamplerCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *TRTGridSamplerCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int mode = 0; - int paddingMode = 0; - bool alignCorners = false; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("interpolation_mode") == 0) { - mode = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("padding_mode") == 0) { - paddingMode = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("align_corners") == 0) { - alignCorners = (bool)(static_cast(fc->fields[i].data)[0]); - } - } - - TRTGridSampler *plugin = new TRTGridSampler(name, mode, paddingMode, alignCorners); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"grid_sampler"}; + } // namespace + + TRTGridSampler::TRTGridSampler(const std::string& name, int mode, int paddingMode, bool alignCorners) + : TRTPluginBase(name) + , mMode(mode) + , mPaddingMode(paddingMode) + , mAlignCorners(alignCorners) + { + } + + TRTGridSampler::TRTGridSampler(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mMode); + deserialize_value(&data, &length, &mPaddingMode); + deserialize_value(&data, &length, &mAlignCorners); + } + + nvinfer1::IPluginV2DynamicExt* TRTGridSampler::clone() const TRT_NOEXCEPT + { + TRTGridSampler* plugin = new TRTGridSampler(mLayerName, mMode, mPaddingMode, mAlignCorners); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + nvinfer1::DimsExprs ret; + ret.nbDims = inputs[0].nbDims; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + for (int i = 2; i < ret.nbDims; ++i) + { + ret.d[i] = inputs[1].d[i - 1]; + } + return ret; + } + + bool TRTGridSampler::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return (ioDesc[pos].type == 
nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + void TRTGridSampler::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + } + + size_t TRTGridSampler::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTGridSampler::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + nvinfer1::Dims input_dims = inputDesc[0].dims; + nvinfer1::Dims grid_dims = inputDesc[1].dims; + nvinfer1::Dims output_dims = outputDesc[0].dims; + + GridSamplerInterpolation interp_mode = GridSamplerInterpolation::Bilinear; + switch (mMode) + { + case 0: + interp_mode = GridSamplerInterpolation::Bilinear; + break; + case 1: + interp_mode = GridSamplerInterpolation::Nearest; + break; + default: + break; + } + + GridSamplerPadding padding_mode = GridSamplerPadding::Zeros; + switch (mPaddingMode) + { + case 0: + padding_mode = GridSamplerPadding::Zeros; + break; + + case 1: + padding_mode = GridSamplerPadding::Border; + break; + + case 2: + padding_mode = GridSamplerPadding::Reflection; + break; + default: + break; + } + + auto data_type = inputDesc[0].type; + + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + grid_sample((float*)outputs[0], + (float*)inputs[0], + (float*)inputs[1], + &(output_dims.d[0]), + &(input_dims.d[0]), + &(grid_dims.d[0]), + input_dims.nbDims, + interp_mode, + padding_mode, + mAlignCorners, + stream); + break; + default: + return 1; + break; + } + + return 0; + } + + nvinfer1::DataType TRTGridSampler::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* TRTGridSampler::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTGridSampler::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTGridSampler::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTGridSampler::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mMode) + serialized_size(mPaddingMode) + serialized_size(mAlignCorners); + } + + void TRTGridSampler::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mMode); + serialize_value(&buffer, mPaddingMode); + serialize_value(&buffer, mAlignCorners); + } + + ////////////////////// creator ///////////////////////////// -nvinfer1::IPluginV2 *TRTGridSamplerCreator::deserializePlugin(const char *name, - const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - // This object will be deleted when the network is destroyed, which will - // call FCPluginDynamic::destroy() - auto plugin = new TRTGridSampler(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} + TRTGridSamplerCreator::TRTGridSamplerCreator() + { + mPluginAttributes = std::vector({nvinfer1::PluginField("interpolation_mode"), + nvinfer1::PluginField("padding_mode"), + nvinfer1::PluginField("align_corners")}); + mFC.nbFields = mPluginAttributes.size(); + 
mFC.fields = mPluginAttributes.data(); + } + + const char* TRTGridSamplerCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTGridSamplerCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTGridSamplerCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int mode = 0; + int paddingMode = 0; + bool alignCorners = false; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("interpolation_mode") == 0) + { + mode = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("padding_mode") == 0) + { + paddingMode = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("align_corners") == 0) + { + alignCorners = (bool)(static_cast(fc->fields[i].data)[0]); + } + } + + TRTGridSampler* plugin = new TRTGridSampler(name, mode, paddingMode, alignCorners); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTGridSamplerCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + // This object will be deleted when the network is destroyed, which will + // call FCPluginDynamic::destroy() + auto plugin = new TRTGridSampler(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } -REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator); + REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp index 0f62bce7c8..286b955d6c 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp @@ -9,76 +9,94 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { +namespace mmdeploy +{ -class TRTGridSampler : public TRTPluginBase { - public: - TRTGridSampler(const std::string &name, int mode, int paddingMode, bool alignCorners); + class TRTGridSampler : public TRTPluginBase + { + public: + TRTGridSampler(const std::string& name, + int mode, + int paddingMode, + bool alignCorners); - TRTGridSampler(const std::string name, const void *data, size_t length); + TRTGridSampler(const std::string name, + const void* data, + size_t length); - TRTGridSampler() = delete; + TRTGridSampler() = delete; - ~TRTGridSampler() TRT_NOEXCEPT override = default; + ~TRTGridSampler() TRT_NOEXCEPT override = default; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; - void 
configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - private: - int mMode; - int mPaddingMode; - bool mAlignCorners; -}; + private: + int mMode; + int mPaddingMode; + bool mAlignCorners; + }; -class TRTGridSamplerCreator : public TRTPluginCreatorBase { - public: - TRTGridSamplerCreator(); + class TRTGridSamplerCreator : public TRTPluginCreatorBase + { + public: + TRTGridSamplerCreator(); - ~TRTGridSamplerCreator() TRT_NOEXCEPT override = default; + ~TRTGridSamplerCreator() TRT_NOEXCEPT override = default; - const char *getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_GRID_SAMPLER_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu index 5d83f98d2c..28d581dd66 100644 --- 
a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu
@@ -27,370 +27,470 @@ using mmdeploy::TensorDesc;
 // -1 --> -0.5
 // +1 --> (size - 1) + 0.5 == size - 0.5
 // scale_factor = size / 2
-template <typename scalar_t>
-static __forceinline__ __device__ scalar_t grid_sampler_unnormalize(scalar_t coord, int size,
-                                                                    bool align_corners) {
-  if (align_corners) {
-    // unnormalize coord from [-1, 1] to [0, size - 1]
-    return ((coord + 1.f) / 2) * (size - 1);
-  } else {
-    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
-    return ((coord + 1.f) * size - 1) / 2;
-  }
+template<typename scalar_t>
+static __forceinline__ __device__ scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners)
+{
+    if (align_corners)
+    {
+        // unnormalize coord from [-1, 1] to [0, size - 1]
+        return ((coord + 1.f) / 2) * (size - 1);
+    }
+    else
+    {
+        // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+        return ((coord + 1.f) * size - 1) / 2;
+    }
 }

 // Clips coordinates to between 0 and clip_limit - 1
-template <typename scalar_t>
-static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, int clip_limit) {
-  return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
+template<typename scalar_t>
+static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, int clip_limit)
+{
+    return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
 }

 // Reflects coordinates until they fall between low and high (inclusive).
 // The bounds are passed as twice their value so that half-integer values
 // can be represented as ints.
-template <typename scalar_t>
-static __forceinline__ __device__ scalar_t reflect_coordinates(scalar_t in, int twice_low,
-                                                               int twice_high) {
-  if (twice_low == twice_high) {
-    return static_cast<scalar_t>(0);
-  }
-  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
-  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
-  in = ::fabs(in - min);
-  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
-  scalar_t extra = ::fmod(in, span);
-  int flips = static_cast<int>(::floor(in / span));
-  if (flips % 2 == 0) {
-    return extra + min;
-  } else {
-    return span - extra + min;
-  }
+template<typename scalar_t>
+static __forceinline__ __device__ scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high)
+{
+    if (twice_low == twice_high)
+    {
+        return static_cast<scalar_t>(0);
+    }
+    scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+    scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+    in = ::fabs(in - min);
+    // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
+    scalar_t extra = ::fmod(in, span);
+    int flips = static_cast<int>(::floor(in / span));
+    if (flips % 2 == 0)
+    {
+        return extra + min;
+    }
+    else
+    {
+        return span - extra + min;
+    }
 }
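A quick numeric check of the two unnormalize conventions above, for size = 4: with align_corners, -1/+1 map onto the outermost pixel centers 0 and 3; without it, onto the outer edges -0.5 and 3.5. The sketch is illustrative only and not part of the patch:

    #include <cstdio>

    static float unnormalize(float coord, int size, bool align_corners)
    {
        return align_corners ? ((coord + 1.f) / 2) * (size - 1)
                             : ((coord + 1.f) * size - 1) / 2;
    }

    int main()
    {
        for (float c : {-1.f, 0.f, 1.f})
            std::printf("coord %+.1f -> aligned %.1f, unaligned %.1f\n",
                        c, unnormalize(c, 4, true), unnormalize(c, 4, false));
        return 0;
    }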
+ if (x > INT_MAX - 1 || x < INT_MIN || !::isfinite(static_cast(x))) + return static_cast(-100.0); + return x; } // Computes the pixel source index value for a grid coordinate -template +template static __forceinline__ __device__ scalar_t grid_sampler_compute_source_index( - scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); - if (padding_mode == GridSamplerPadding::Border) { - // clip coordinates to image borders - coord = clip_coordinates(coord, size); - } else if (padding_mode == GridSamplerPadding::Reflection) { - // reflect coordinates by image borders - if (align_corners) { - coord = reflect_coordinates(coord, 0, 2 * (size - 1)); - } else { - coord = reflect_coordinates(coord, -1, 2 * size - 1); + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) +{ + coord = grid_sampler_unnormalize(coord, size, align_corners); + if (padding_mode == GridSamplerPadding::Border) + { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } + else if (padding_mode == GridSamplerPadding::Reflection) + { + // reflect coordinates by image borders + if (align_corners) + { + coord = reflect_coordinates(coord, 0, 2 * (size - 1)); + } + else + { + coord = reflect_coordinates(coord, -1, 2 * size - 1); + } + // clip coordinates to image borders + coord = clip_coordinates(coord, size); } - // clip coordinates to image borders - coord = clip_coordinates(coord, size); - } - coord = safe_downgrade_to_int_range(coord); - return coord; + coord = safe_downgrade_to_int_range(coord); + return coord; } -static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) { - return h >= 0 && h < H && w >= 0 && w < W; +static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) +{ + return h >= 0 && h < H && w >= 0 && w < W; } -static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { - return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) +{ + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } -template -__global__ void grid_sampler_2d_kernel(const int nthreads, const scalar_t *input, - const scalar_t *grid, scalar_t *output, - TensorDesc input_desc, TensorDesc grid_desc, - TensorDesc output_desc, +template +__global__ void grid_sampler_2d_kernel(const int nthreads, + const scalar_t* input, + const scalar_t* grid, + scalar_t* output, + TensorDesc input_desc, + TensorDesc grid_desc, + TensorDesc output_desc, const GridSamplerInterpolation interpolation_mode, - const GridSamplerPadding padding_mode, bool align_corners) { - int C = input_desc.shape[1]; - int inp_H = input_desc.shape[2]; - int inp_W = input_desc.shape[3]; - int out_H = grid_desc.shape[1]; - int out_W = grid_desc.shape[2]; - int inp_sN = input_desc.stride[0]; - int inp_sC = input_desc.stride[1]; - int inp_sH = input_desc.stride[2]; - int inp_sW = input_desc.stride[3]; - int grid_sN = grid_desc.stride[0]; - int grid_sH = grid_desc.stride[1]; - int grid_sW = grid_desc.stride[2]; - int grid_sCoor = grid_desc.stride[3]; - int out_sN = output_desc.stride[0]; - int out_sC = output_desc.stride[1]; - int out_sH = output_desc.stride[2]; - int out_sW = output_desc.stride[3]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - const int w = index % out_W; - const int h = (index / out_W) % out_H; - const int n = index / (out_H * 
out_W); - const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; - - // get the corresponding input x, y coordinates from grid - scalar_t ix = grid[grid_offset]; - scalar_t iy = grid[grid_offset + grid_sCoor]; - - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); - - if (interpolation_mode == GridSamplerInterpolation::Bilinear) { - // get NE, NW, SE, SW pixel values from (x, y) - int ix_nw = static_cast(::floor(ix)); - int iy_nw = static_cast(::floor(iy)); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; - - // get surfaces to each neighbor: - scalar_t nw = (ix_se - ix) * (iy_se - iy); - scalar_t ne = (ix - ix_sw) * (iy_sw - iy); - scalar_t sw = (ix_ne - ix) * (iy - iy_ne); - scalar_t se = (ix - ix_nw) * (iy - iy_nw); - - // calculate bilinear weighted pixel value and set output pixel - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { - *out_ptr_NCHW = static_cast(0); - if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; - } - if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; - } - if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + const GridSamplerPadding padding_mode, + bool align_corners) +{ + int C = input_desc.shape[1]; + int inp_H = input_desc.shape[2]; + int inp_W = input_desc.shape[3]; + int out_H = grid_desc.shape[1]; + int out_W = grid_desc.shape[2]; + int inp_sN = input_desc.stride[0]; + int inp_sC = input_desc.stride[1]; + int inp_sH = input_desc.stride[2]; + int inp_sW = input_desc.stride[3]; + int grid_sN = grid_desc.stride[0]; + int grid_sH = grid_desc.stride[1]; + int grid_sW = grid_desc.stride[2]; + int grid_sCoor = grid_desc.stride[3]; + int out_sN = output_desc.stride[0]; + int out_sC = output_desc.stride[1]; + int out_sH = output_desc.stride[2]; + int out_sW = output_desc.stride[3]; + + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int n = index / (out_H * out_W); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y coordinates from grid + scalar_t ix = grid[grid_offset]; + scalar_t iy = grid[grid_offset + grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get NE, NW, SE, SW pixel values from (x, y) + int ix_nw = static_cast(::floor(ix)); + int iy_nw = static_cast(::floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + 
w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) + { + *out_ptr_NCHW = static_cast(0); + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + } } - if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + else if (interpolation_mode == GridSamplerInterpolation::Nearest) + { + int ix_nearest = static_cast(::round(ix)); + int iy_nearest = static_cast(::round(iy)); + + // assign nearest neighbor pixel value to output pixel + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) + { + if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) + { + *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; + } + else + { + *out_ptr_NCHW = static_cast(0); + } + } } - } - } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { - int ix_nearest = static_cast(::round(ix)); - int iy_nearest = static_cast(::round(iy)); - - // assign nearest neighbor pixel value to output pixel - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { - if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { - *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCHW = static_cast(0); - } - } } - } } -template -__global__ void grid_sampler_3d_kernel(const int nthreads, const scalar_t *input, - const scalar_t *grid, scalar_t *output, - TensorDesc input_desc, TensorDesc grid_desc, - TensorDesc output_desc, +template +__global__ void grid_sampler_3d_kernel(const int nthreads, + const scalar_t* input, + const scalar_t* grid, + scalar_t* output, + TensorDesc input_desc, + TensorDesc grid_desc, + TensorDesc output_desc, const GridSamplerInterpolation interpolation_mode, - const GridSamplerPadding padding_mode, bool align_corners) { - int C = input_desc.shape[1]; - int inp_D = input_desc.shape[2]; - int inp_H = input_desc.shape[3]; - int inp_W = input_desc.shape[4]; - int out_D = grid_desc.shape[1]; - int out_H = grid_desc.shape[2]; - int out_W = grid_desc.shape[3]; - int inp_sN = input_desc.stride[0]; - int inp_sC = input_desc.stride[1]; - int inp_sD = input_desc.stride[2]; - int inp_sH = input_desc.stride[3]; - int inp_sW = input_desc.stride[4]; - int grid_sN = grid_desc.stride[0]; - int grid_sD = grid_desc.stride[1]; - int grid_sH = grid_desc.stride[2]; - int grid_sW = grid_desc.stride[3]; - int grid_sCoor = grid_desc.stride[4]; - int out_sN = output_desc.stride[0]; - int out_sC = output_desc.stride[1]; - int out_sD = output_desc.stride[2]; - int out_sH = output_desc.stride[3]; - int out_sW = output_desc.stride[4]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - const int w = index % out_W; - const int h = (index / out_W) % out_H; - const int d = (index / (out_H * out_W)) % out_D; - const int n = index / (out_D * out_H * out_W); - const int 
grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; - - // get the corresponding input x, y, z coordinates from grid - scalar_t ix = grid[grid_offset]; - scalar_t iy = grid[grid_offset + grid_sCoor]; - scalar_t iz = grid[grid_offset + 2 * grid_sCoor]; - - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); - iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); - - if (interpolation_mode == GridSamplerInterpolation::Bilinear) { - // get corner pixel values from (x, y, z) - // for 4d, we used north-east-south-west - // for 5d, we add top-bottom - int ix_tnw = static_cast(::floor(ix)); - int iy_tnw = static_cast(::floor(iy)); - int iz_tnw = static_cast(::floor(iz)); - - int ix_tne = ix_tnw + 1; - int iy_tne = iy_tnw; - int iz_tne = iz_tnw; - - int ix_tsw = ix_tnw; - int iy_tsw = iy_tnw + 1; - int iz_tsw = iz_tnw; - - int ix_tse = ix_tnw + 1; - int iy_tse = iy_tnw + 1; - int iz_tse = iz_tnw; - - int ix_bnw = ix_tnw; - int iy_bnw = iy_tnw; - int iz_bnw = iz_tnw + 1; - - int ix_bne = ix_tnw + 1; - int iy_bne = iy_tnw; - int iz_bne = iz_tnw + 1; - - int ix_bsw = ix_tnw; - int iy_bsw = iy_tnw + 1; - int iz_bsw = iz_tnw + 1; - - int ix_bse = ix_tnw + 1; - int iy_bse = iy_tnw + 1; - int iz_bse = iz_tnw + 1; - - // get surfaces to each neighbor: - scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); - scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); - scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); - scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); - scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); - scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); - scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); - scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); - - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { - // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * - // tne - // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * - // tse - // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * - // bne - // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * - // bse - *out_ptr_NCDHW = static_cast(0); - if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; - } - if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; - } - if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + const GridSamplerPadding padding_mode, + bool align_corners) +{ + int C = input_desc.shape[1]; + int inp_D = input_desc.shape[2]; + int inp_H = input_desc.shape[3]; + int inp_W = input_desc.shape[4]; + int out_D = grid_desc.shape[1]; + int out_H = grid_desc.shape[2]; + int out_W = grid_desc.shape[3]; + int inp_sN = input_desc.stride[0]; + int inp_sC = input_desc.stride[1]; + int inp_sD = input_desc.stride[2]; + int inp_sH = input_desc.stride[3]; + int inp_sW = input_desc.stride[4]; + int grid_sN = grid_desc.stride[0]; + int grid_sD = grid_desc.stride[1]; + int 
grid_sH = grid_desc.stride[2]; + int grid_sW = grid_desc.stride[3]; + int grid_sCoor = grid_desc.stride[4]; + int out_sN = output_desc.stride[0]; + int out_sC = output_desc.stride[1]; + int out_sD = output_desc.stride[2]; + int out_sH = output_desc.stride[3]; + int out_sW = output_desc.stride[4]; + + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int d = (index / (out_H * out_W)) % out_D; + const int n = index / (out_D * out_H * out_W); + const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z coordinates from grid + scalar_t ix = grid[grid_offset]; + scalar_t iy = grid[grid_offset + grid_sCoor]; + scalar_t iz = grid[grid_offset + 2 * grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast(::floor(ix)); + int iy_tnw = static_cast(::floor(iy)); + int iz_tnw = static_cast(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) + { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * + // tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * + // tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * + // bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * + // bse + *out_ptr_NCDHW = static_cast(0); + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (within_bounds_3d(iz_tse, 
iy_tse, ix_tse, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } } - if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + else if (interpolation_mode == GridSamplerInterpolation::Nearest) + { + int ix_nearest = static_cast(::round(ix)); + int iy_nearest = static_cast(::round(iy)); + int iz_nearest = static_cast(::round(iz)); + + // assign nearest neighbor pixel value to output pixel + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) + { + if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW = + inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } + else + { + *out_ptr_NCDHW = static_cast(0); + } + } } - if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; - } - if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; - } - if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; - } - if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; - } - } - } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { - int ix_nearest = static_cast(::round(ix)); - int iy_nearest = static_cast(::round(iy)); - int iz_nearest = static_cast(::round(iz)); - - // assign nearest neighbor pixel value to output pixel - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { - if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW = - inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCDHW = static_cast(0); - } - } } - } } -void create_desc(const int *dims, int nb_dims, TensorDesc &desc) { - memcpy(&desc.shape[0], dims, sizeof(int) * nb_dims); - desc.stride[nb_dims - 1] = 1; - for (int i = nb_dims - 2; i >= 0; --i) { - desc.stride[i] = desc.stride[i + 1] * desc.shape[i + 1]; - } +void create_desc(const int* dims, int nb_dims, TensorDesc& desc) +{ + memcpy(&desc.shape[0], dims, sizeof(int) * nb_dims); + desc.stride[nb_dims - 1] = 1; + for (int i = nb_dims - 2; i >= 0; --i) + { + desc.stride[i] = 
desc.stride[i + 1] * desc.shape[i + 1];
+    }
 }

-template <typename T>
-void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims,
-                 int *grid_dims, int nb_dims, GridSamplerInterpolation interp,
-                 GridSamplerPadding padding, bool align_corners, cudaStream_t stream) {
-  TensorDesc input_desc;
-  create_desc(input_dims, nb_dims, input_desc);
-
-  TensorDesc output_desc;
-  create_desc(output_dims, nb_dims, output_desc);
-
-  TensorDesc grid_desc;
-  create_desc(grid_dims, nb_dims, grid_desc);
+template <typename T>
+void grid_sample(T* output,
+                 const T* input,
+                 const T* grid,
+                 int* output_dims,
+                 int* input_dims,
+                 int* grid_dims,
+                 int nb_dims,
+                 GridSamplerInterpolation interp,
+                 GridSamplerPadding padding,
+                 bool align_corners,
+                 cudaStream_t stream)
+{
+    TensorDesc input_desc;
+    create_desc(input_dims, nb_dims, input_desc);
+
+    TensorDesc output_desc;
+    create_desc(output_dims, nb_dims, output_desc);
+
+    TensorDesc grid_desc;
+    create_desc(grid_dims, nb_dims, grid_desc);
+
+    int count = 1;
+    for (int i = 0; i < nb_dims; ++i)
+    {
+        if (i == 1)
+        {
+            continue;
+        }
+        count *= output_desc.shape[i];
+    }

-  int count = 1;
-  for (int i = 0; i < nb_dims; ++i) {
-    if (i == 1) {
-      continue;
+    if (nb_dims == 4)
+    {
+        grid_sampler_2d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(count,
+                                                                                       input,
+                                                                                       grid,
+                                                                                       output,
+                                                                                       input_desc,
+                                                                                       grid_desc,
+                                                                                       output_desc,
+                                                                                       interp,
+                                                                                       padding,
+                                                                                       align_corners);
+    }
+    else if (nb_dims == 5)
+    {
+        grid_sampler_3d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(count,
+                                                                                       input,
+                                                                                       grid,
+                                                                                       output,
+                                                                                       input_desc,
+                                                                                       grid_desc,
+                                                                                       output_desc,
+                                                                                       interp,
+                                                                                       padding,
+                                                                                       align_corners);
+    }
+    else
+    {
+        printf("input and grid dims should be 4 or 5\n");
    }
-    count *= output_desc.shape[i];
-  }
-
-  if (nb_dims == 4) {
-    grid_sampler_2d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(
-        count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding,
-        align_corners);
-  } else if (nb_dims == 5) {
-    grid_sampler_3d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(
-        count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding,
-        align_corners);
-  } else {
-    printf("input and grid dims should be 4 or 5\n");
-  }
 }

-template void grid_sample<float>(float *output, const float *input, const float *grid,
-                                 int *output_dims, int *input_dims, int *grid_dims, int nb_dims,
-                                 GridSamplerInterpolation interp, GridSamplerPadding padding,
-                                 bool align_corners, cudaStream_t stream);
+template void grid_sample<float>(float* output,
+                                 const float* input,
+                                 const float* grid,
+                                 int* output_dims,
+                                 int* input_dims,
+                                 int* grid_dims,
+                                 int nb_dims,
+                                 GridSamplerInterpolation interp,
+                                 GridSamplerPadding padding,
+                                 bool align_corners,
+                                 cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp
index e4e50332f4..2da0e3abc5 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp
@@ -3,11 +3,28 @@
 #define TRT_GRID_SAMPLER_KERNEL_HPP
 #include <cuda_runtime.h>

-enum class GridSamplerInterpolation { Bilinear, Nearest };
-enum class GridSamplerPadding { Zeros, Border, Reflection };
+enum class GridSamplerInterpolation
+{
+    Bilinear,
+    Nearest
+};
+enum class GridSamplerPadding
+{
+    Zeros,
+    Border,
+    Reflection
+};

-template <typename T>
-void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims,
-                 int *grid_dims, int nb_dims, GridSamplerInterpolation interp,
-                 GridSamplerPadding padding, bool align_corners, cudaStream_t stream);
+template <typename T>
+void grid_sample(T* output,
+                 const T* input,
+
const T* grid, + int* output_dims, + int* input_dims, + int* grid_dims, + int nb_dims, + GridSamplerInterpolation interp, + GridSamplerPadding padding, + bool align_corners, + cudaStream_t stream); #endif // TRT_GRID_SAMPLER_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp index e6aab92f4c..7b5ed533e5 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp @@ -12,203 +12,259 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -constexpr const char* PLUGIN_VERSION{"1"}; -constexpr const char* PLUGIN_NAME{"TRTInstanceNormalization"}; -} // namespace - -TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, float epsilon) - : TRTPluginBase(name), mEpsilon(epsilon) {} - -TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, void const* serialData, - size_t serialLength) - : TRTPluginBase(name) { - deserialize_value(&serialData, &serialLength, &mEpsilon); -} - -TRTInstanceNormalization::~TRTInstanceNormalization() {} - -// TRTInstanceNormalization returns one output. -int TRTInstanceNormalization::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -DimsExprs TRTInstanceNormalization::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs output(inputs[0]); - return output; -} - -size_t TRTInstanceNormalization::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT { - int n = inputs[0].dims.d[0]; - int c = inputs[0].dims.d[1]; - int elem_size = sizeof(float); - return getAlignedSize(n * c * elem_size) * 2; -} - -int TRTInstanceNormalization::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, - void* workspace, cudaStream_t stream) TRT_NOEXCEPT { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int n = input_dims.d[0]; - int c = input_dims.d[1]; - int h = input_dims.d[2]; - int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1; - int elem_size = sizeof(float); - - void* n_scales = (void*)workspace; - void* n_bias = (void*)((char*)workspace + getAlignedSize(n * c * elem_size)); - - const void* scales = (const void*)inputs[1]; - const void* bias = (const void*)inputs[2]; - - for (int i = 0; i < n; ++i) { - cudaMemcpyAsync((char*)n_scales + i * c * elem_size, scales, c * elem_size, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync((char*)n_bias + i * c * elem_size, bias, c * elem_size, - cudaMemcpyDeviceToDevice, stream); - } - - cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1); - cudnnDataType_t cudnn_dtype{}; - convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype); - cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); - cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); - float alpha = 1; - float beta = 0; - void const* x_ptr = inputs[0]; - void* y_ptr = outputs[0]; - cudnnSetStream(_cudnn_handle, stream); - // Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical - // overflows (NaNs) for fp32 data in some circumstances. 
The lower- - // performance CUDNN_BATCHNORM_SPATIAL should be used if this is not - // acceptable. - cudnnBatchNormalizationForwardTraining(_cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, - &beta, _x_desc, x_ptr, _y_desc, y_ptr, _b_desc, n_scales, - n_bias, 1., nullptr, nullptr, mEpsilon, nullptr, nullptr); - return 0; -} - -size_t TRTInstanceNormalization::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mEpsilon); -} - -void TRTInstanceNormalization::serialize(void* buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mEpsilon); -} - -bool TRTInstanceNormalization::supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - switch (pos) { - case 0: - case 3: - return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || - ioDesc[pos].type == nvinfer1::DataType::kHALF) && - ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR && - ioDesc[pos].type == ioDesc[0].type); - case 1: - case 2: - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR; - default: - return false; - } - return false; -} - -const char* TRTInstanceNormalization::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char* TRTInstanceNormalization::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -IPluginV2DynamicExt* TRTInstanceNormalization::clone() const TRT_NOEXCEPT { - auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon}; - plugin->setPluginNamespace(mPluginNamespace.c_str()); - return plugin; -} - -nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(int index, - const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// Attach the plugin object to an execution context and grant the plugin the -// access to some context resource. -void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext, - cublasContext* cublasContext, - IGpuAllocator* gpuAllocator) TRT_NOEXCEPT { - _cudnn_handle = cudnnContext; - cudnnCreateTensorDescriptor(&_b_desc); - cudnnCreateTensorDescriptor(&_x_desc); - cudnnCreateTensorDescriptor(&_y_desc); -} - -// Detach the plugin object from its execution context. 
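The enqueue path above turns cuDNN batch normalization into instance normalization by folding batch into channels: the (N, C, H, W) input is described to cuDNN as (1, N*C, H, W), so spatial batch norm computes a separate mean and variance for every (n, c) pair, and the per-channel scale and bias are tiled N times to cover the N*C pseudo-channels. A minimal CPU sketch of what each pseudo-channel then receives (hypothetical standalone helper, not the plugin's API):

#include <cmath>

// Normalize one (n, c) plane of hw values in place: the same statistics
// the folded cuDNN spatial batch-norm call computes per pseudo-channel.
void instance_norm_plane(float* x, int hw, float gamma, float beta, float eps)
{
    double mean = 0.0, var = 0.0;
    for (int i = 0; i < hw; ++i) mean += x[i];
    mean /= hw;
    for (int i = 0; i < hw; ++i) var += (x[i] - mean) * (x[i] - mean);
    var /= hw;
    const double inv_std = 1.0 / std::sqrt(var + eps);
    for (int i = 0; i < hw; ++i)
        x[i] = static_cast<float>(gamma * (x[i] - mean) * inv_std + beta);
}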
-void TRTInstanceNormalization::detachFromContext() TRT_NOEXCEPT { - if (_y_desc) { - cudnnDestroyTensorDescriptor(_y_desc); - _y_desc = nullptr; - } - if (_x_desc) { - cudnnDestroyTensorDescriptor(_x_desc); - _x_desc = nullptr; - } - if (_b_desc) { - cudnnDestroyTensorDescriptor(_b_desc); - _b_desc = nullptr; - } -} - -void TRTInstanceNormalization::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) TRT_NOEXCEPT {} - -// TRTInstanceNormalizationCreator methods -TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1)); - - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char* TRTInstanceNormalizationCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char* TRTInstanceNormalizationCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin( - const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT { - float epsilon = 1e-5; - const PluginField* fields = fc->fields; - for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; - if (!strcmp(attrName, "epsilon")) { - epsilon = *(static_cast(fields[i].data)); - } - } - - TRTInstanceNormalization* obj = new TRTInstanceNormalization(name, epsilon); - obj->setPluginNamespace(mNamespace.c_str()); - return obj; -} - -IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT { - TRTInstanceNormalization* obj = new TRTInstanceNormalization{name, serialData, serialLength}; - obj->setPluginNamespace(mNamespace.c_str()); - return obj; -} -REGISTER_TENSORRT_PLUGIN(TRTInstanceNormalizationCreator); +namespace mmdeploy +{ + namespace + { + constexpr const char* PLUGIN_VERSION{"1"}; + constexpr const char* PLUGIN_NAME{"TRTInstanceNormalization"}; + } // namespace + + TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, + float epsilon) + : TRTPluginBase(name) + , mEpsilon(epsilon) + { + } + + TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, + void const* serialData, + size_t serialLength) + : TRTPluginBase(name) + { + deserialize_value(&serialData, &serialLength, &mEpsilon); + } + + TRTInstanceNormalization::~TRTInstanceNormalization() {} + + // TRTInstanceNormalization returns one output. 
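createPlugin above scans the PluginFieldCollection by field name; each attribute arrives as an untyped pointer, so it must be cast to the type it was registered with (epsilon is a single kFLOAT32). A self-contained sketch of that lookup pattern, using a stand-in struct rather than the real nvinfer1::PluginField:

#include <cstring>

// Stand-in for nvinfer1::PluginField: a name plus an untyped payload.
struct Field
{
    const char* name;
    const void* data;
};

// Mirrors the epsilon lookup in createPlugin: scan by name, cast the
// payload, fall back to the plugin default when the field is absent.
float read_epsilon(const Field* fields, int nb_fields, float fallback = 1e-5f)
{
    for (int i = 0; i < nb_fields; ++i)
    {
        if (fields[i].data != nullptr && std::strcmp(fields[i].name, "epsilon") == 0)
        {
            return *static_cast<const float*>(fields[i].data);
        }
    }
    return fallback;
}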
+ int TRTInstanceNormalization::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + DimsExprs TRTInstanceNormalization::getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + nvinfer1::DimsExprs output(inputs[0]); + return output; + } + + size_t TRTInstanceNormalization::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + int n = inputs[0].dims.d[0]; + int c = inputs[0].dims.d[1]; + int elem_size = sizeof(float); + return getAlignedSize(n * c * elem_size) * 2; + } + + int TRTInstanceNormalization::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT + { + nvinfer1::Dims input_dims = inputDesc[0].dims; + int n = input_dims.d[0]; + int c = input_dims.d[1]; + int h = input_dims.d[2]; + int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1; + int elem_size = sizeof(float); + + void* n_scales = (void*)workspace; + void* n_bias = (void*)((char*)workspace + getAlignedSize(n * c * elem_size)); + + const void* scales = (const void*)inputs[1]; + const void* bias = (const void*)inputs[2]; + + for (int i = 0; i < n; ++i) + { + cudaMemcpyAsync((char*)n_scales + i * c * elem_size, scales, c * elem_size, cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync((char*)n_bias + i * c * elem_size, bias, c * elem_size, cudaMemcpyDeviceToDevice, stream); + } + + cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1); + cudnnDataType_t cudnn_dtype{}; + convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype); + cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); + cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); + float alpha = 1; + float beta = 0; + void const* x_ptr = inputs[0]; + void* y_ptr = outputs[0]; + cudnnSetStream(_cudnn_handle, stream); + // Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical + // overflows (NaNs) for fp32 data in some circumstances. The lower- + // performance CUDNN_BATCHNORM_SPATIAL should be used if this is not + // acceptable. 
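getWorkspaceSize above reserves two equally sized, aligned buffers of n * c floats; enqueue then carves them out of the single workspace pointer, filling the first with the scales repeated once per batch element and the second with the biases. A sketch of the layout arithmetic (the 16-byte alignment is an assumption for illustration; the actual constant comes from the plugin's getAlignedSize helper):

#include <cstddef>

constexpr size_t kAlign = 16;  // assumed; see getAlignedSize in the common code

size_t aligned_size(size_t bytes)
{
    return (bytes + kAlign - 1) / kAlign * kAlign;
}

// Workspace layout consumed by enqueue, with A = aligned_size(n * c * sizeof(float)):
//   [0, A)     tiled scales (N copies of the C per-channel scales)
//   [A, 2 * A) tiled biases
size_t workspace_bytes(int n, int c)
{
    return aligned_size(n * c * sizeof(float)) * 2;
}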
+ cudnnBatchNormalizationForwardTraining(_cudnn_handle, + CUDNN_BATCHNORM_SPATIAL_PERSISTENT, + &alpha, + &beta, + _x_desc, + x_ptr, + _y_desc, + y_ptr, + _b_desc, + n_scales, + n_bias, + 1., + nullptr, + nullptr, + mEpsilon, + nullptr, + nullptr); + return 0; + } + + size_t TRTInstanceNormalization::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mEpsilon); + } + + void TRTInstanceNormalization::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mEpsilon); + } + + bool TRTInstanceNormalization::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + switch (pos) + { + case 0: + case 3: + return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || + ioDesc[pos].type == nvinfer1::DataType::kHALF) && + ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR && + ioDesc[pos].type == ioDesc[0].type); + case 1: + case 2: + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR; + default: + return false; + } + return false; + } + + const char* TRTInstanceNormalization::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTInstanceNormalization::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + IPluginV2DynamicExt* TRTInstanceNormalization::clone() const TRT_NOEXCEPT + { + auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon}; + plugin->setPluginNamespace(mPluginNamespace.c_str()); + return plugin; + } + + nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // Attach the plugin object to an execution context and grant the plugin the + // access to some context resource. + void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + IGpuAllocator* gpuAllocator) TRT_NOEXCEPT + { + _cudnn_handle = cudnnContext; + cudnnCreateTensorDescriptor(&_b_desc); + cudnnCreateTensorDescriptor(&_x_desc); + cudnnCreateTensorDescriptor(&_y_desc); + } + + // Detach the plugin object from its execution context. 
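getSerializationSize, serialize, and the deserializing constructor above must agree on both the byte count and the field order; with a single field (mEpsilon) that is trivial, but the same discipline applies as more state is added. A self-contained sketch of the pattern with stand-in helpers (the plugin's serialize_value/deserialize_value behave analogously, advancing the buffer past each value):

#include <cstring>

// Stand-ins for the serialization helpers used above.
template <typename T>
void write_value(char** buffer, const T& value)
{
    std::memcpy(*buffer, &value, sizeof(T));
    *buffer += sizeof(T);  // advance so the next write lands behind this one
}

template <typename T>
void read_value(const char** data, T* value)
{
    std::memcpy(value, *data, sizeof(T));
    *data += sizeof(T);  // reads must consume fields in exactly the write order
}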
+ void TRTInstanceNormalization::detachFromContext() TRT_NOEXCEPT + { + if (_y_desc) + { + cudnnDestroyTensorDescriptor(_y_desc); + _y_desc = nullptr; + } + if (_x_desc) + { + cudnnDestroyTensorDescriptor(_x_desc); + _x_desc = nullptr; + } + if (_b_desc) + { + cudnnDestroyTensorDescriptor(_b_desc); + _b_desc = nullptr; + } + } + + void TRTInstanceNormalization::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT {} + + // TRTInstanceNormalizationCreator methods + TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTInstanceNormalizationCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTInstanceNormalizationCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + float epsilon = 1e-5; + const PluginField* fields = fc->fields; + for (int i = 0; i < fc->nbFields; ++i) + { + const char* attrName = fields[i].name; + if (!strcmp(attrName, "epsilon")) + { + epsilon = *(static_cast(fields[i].data)); + } + } + + TRTInstanceNormalization* obj = new TRTInstanceNormalization(name, epsilon); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + + IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + TRTInstanceNormalization* obj = new TRTInstanceNormalization{name, serialData, serialLength}; + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + REGISTER_TENSORRT_PLUGIN(TRTInstanceNormalizationCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp index 2df04a5f6d..d8119d355b 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp @@ -14,80 +14,97 @@ typedef unsigned short half_type; -namespace mmdeploy { -class TRTInstanceNormalization final : public TRTPluginBase { - public: - TRTInstanceNormalization(const std::string& name, float epsilon); +namespace mmdeploy +{ + class TRTInstanceNormalization final : public TRTPluginBase + { + public: + TRTInstanceNormalization(const std::string& name, + float epsilon); - TRTInstanceNormalization(const std::string& name, void const* serialData, size_t serialLength); + TRTInstanceNormalization(const std::string& name, + void const* serialData, + size_t serialLength); - TRTInstanceNormalization() = delete; + TRTInstanceNormalization() = delete; - ~TRTInstanceNormalization() TRT_NOEXCEPT override; + ~TRTInstanceNormalization() TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; - // DynamicExt plugins returns DimsExprs class instead of Dims - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, - int nbInputs, nvinfer1::IExprBuilder& exprBuilder) - TRT_NOEXCEPT override; + // DynamicExt plugins returns DimsExprs class instead 
of Dims + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, - void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void* buffer) const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - // DynamicExt plugin supportsFormat update. - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; + // DynamicExt plugin supportsFormat update. + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; - const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginType() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; - void attachToContext(cudnnContext* cudnn, cublasContext* cublas, - nvinfer1::IGpuAllocator* allocator) TRT_NOEXCEPT override; + void attachToContext(cudnnContext* cudnn, + cublasContext* cublas, + nvinfer1::IGpuAllocator* allocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; + void detachFromContext() TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; - private: - float mEpsilon{}; - cudnnHandle_t _cudnn_handle{}; - cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{}; - std::string mPluginNamespace{}; -}; + private: + float mEpsilon{}; + cudnnHandle_t _cudnn_handle{}; + cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{}; + std::string mPluginNamespace{}; + }; -class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase { - public: - TRTInstanceNormalizationCreator(); + class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase + { + public: + TRTInstanceNormalizationCreator(); - ~TRTInstanceNormalizationCreator() override = 
default; + ~TRTInstanceNormalizationCreator() override = default; - const char* getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* createPlugin( - const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_INSTANCE_NORMALIZATION_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp index 692000b740..c3540002fa 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp @@ -10,297 +10,397 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"MMCVModulatedDeformConv2d"}; -} // namespace - -ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic( - const std::string &name, const nvinfer1::Dims stride, const nvinfer1::Dims padding, - const nvinfer1::Dims dilation, const int deformableGroup, const int group) - : TRTPluginBase(name), - mStride(stride), - mPadding(padding), - mDilation(dilation), - mDeformableGroup(deformableGroup), - mGroup(group) { - mWithBias = false; -} - -ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(const std::string name, - const void *data, - size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mStride); - deserialize_value(&data, &length, &mPadding); - deserialize_value(&data, &length, &mDilation); - deserialize_value(&data, &length, &mDeformableGroup); - deserialize_value(&data, &length, &mGroup); - mWithBias = false; -} -ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {} - -nvinfer1::IPluginV2DynamicExt *ModulatedDeformableConvPluginDynamic::clone() const TRT_NOEXCEPT { - ModulatedDeformableConvPluginDynamic *plugin = new ModulatedDeformableConvPluginDynamic( - mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -static const nvinfer1::IDimensionExpr *get_hw(const nvinfer1::IDimensionExpr *input, - const nvinfer1::IDimensionExpr *weight, - const nvinfer1::IDimensionExpr *stride, - const nvinfer1::IDimensionExpr *pad, - const nvinfer1::IDimensionExpr *dilation, - nvinfer1::IExprBuilder &exprBuilder) { - using DimOp = nvinfer1::DimensionOperation; - auto expr_1 = exprBuilder.constant(1); - - // d*(w-1)+1 - auto kernel_0 = exprBuilder.operation(DimOp::kSUB, *weight, *expr_1); - auto kernel_1 = exprBuilder.operation(DimOp::kPROD, *dilation, *kernel_0); - auto kernel = exprBuilder.operation(DimOp::kSUM, *kernel_1, *expr_1); - - // (1+2*p-k)//stride -1 - auto out_0 = exprBuilder.operation(DimOp::kSUM, *pad, *pad); - auto out_1 = 
exprBuilder.operation(DimOp::kSUM, *input, *out_0); - auto out_2 = exprBuilder.operation(DimOp::kSUB, *out_1, *kernel); - auto out_3 = exprBuilder.operation(DimOp::kFLOOR_DIV, *out_2, *stride); - auto out = exprBuilder.operation(DimOp::kSUM, *out_3, *expr_1); - - return out; -} - -nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - using DimOp = nvinfer1::DimensionOperation; - auto weight_dim = inputs[3].d; - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[3].d[0]; - - auto input_h = inputs[0].d[2]; - auto input_w = inputs[0].d[3]; - auto weight_h = weight_dim[2]; - auto weight_w = weight_dim[3]; - auto dilation_w = exprBuilder.constant(mDilation.d[0]); - auto dilation_h = exprBuilder.constant(mDilation.d[1]); - auto pad_w = exprBuilder.constant(mPadding.d[0]); - auto pad_h = exprBuilder.constant(mPadding.d[1]); - auto stride_w = exprBuilder.constant(mStride.d[0]); - auto stride_h = exprBuilder.constant(mStride.d[1]); - auto expr_1 = exprBuilder.constant(1); - auto expr_2 = exprBuilder.constant(2); - - ret.d[2] = get_hw(input_h, weight_h, stride_h, pad_h, dilation_h, exprBuilder); - ret.d[3] = get_hw(input_w, weight_w, stride_w, pad_w, dilation_w, exprBuilder); - - return ret; -} - -bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || - ioDesc[pos].type == nvinfer1::DataType::kHALF) && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -void ModulatedDeformableConvPluginDynamic::configurePlugin( - const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT { - if (nbInputs == 5) { - mWithBias = true; - } -} - -size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize( - const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT { - int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); - - int batch_size = inputs[0].dims.d[0]; - int nInputPlane = inputs[0].dims.d[1]; - int inputHeight = inputs[0].dims.d[2]; - int inputWidth = inputs[0].dims.d[3]; - - int nOutputPlane = outputs[0].dims.d[1]; - int outputHeight = outputs[0].dims.d[2]; - int outputWidth = outputs[0].dims.d[3]; - - int kW = inputs[3].dims.d[2]; - int kH = inputs[3].dims.d[3]; - int im2col_step = std::min(32, batch_size); - - size_t col_size = - mmdeploy::getAlignedSize(nInputPlane * kW * kH * outputHeight * outputWidth * sizeof_dtype); - - return col_size; -} - -int ModulatedDeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - int batch = inputDesc[0].dims.d[0]; - int channels = inputDesc[0].dims.d[1]; - int height = inputDesc[0].dims.d[2]; - int width = inputDesc[0].dims.d[3]; - int channels_out = outputDesc[0].dims.d[1]; - int kernel_h = inputDesc[3].dims.d[2]; - int kernel_w = inputDesc[3].dims.d[3]; - - const void *x = inputs[0]; - const void *offset = inputs[1]; - const void 
*mask = inputs[2]; - const void *weight = inputs[3]; - const void *bias = mWithBias ? inputs[4] : nullptr; - void *output = outputs[0]; - int im2col_step = std::min(batch, 32); - - // TODO: add fp16 support - auto data_type = inputDesc[0].type; - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - ModulatedDeformConvForwardCUDAKernelLauncher( - (float *)x, (float *)weight, (float *)bias, (float *)offset, (float *)mask, - (float *)output, workSpace, batch, channels, height, width, channels_out, kernel_w, - kernel_h, mStride.d[0], mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], - mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle, stream); - break; - case nvinfer1::DataType::kHALF: - ModulatedDeformConvForwardCUDAKernelLauncher( - (half *)x, (half *)weight, (half *)bias, (half *)offset, (half *)mask, (half *)output, - workSpace, batch, channels, height, width, channels_out, kernel_w, kernel_h, mStride.d[0], - mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup, - mDeformableGroup, im2col_step, m_cublas_handle, stream); - break; - default: - return 1; - break; - } - - return 0; -} - -nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *ModulatedDeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *ModulatedDeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int ModulatedDeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + - serialized_size(mDeformableGroup) + serialized_size(mGroup); -} - -void ModulatedDeformableConvPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mStride); - serialize_value(&buffer, mPadding); - serialize_value(&buffer, mDilation); - serialize_value(&buffer, mDeformableGroup); - serialize_value(&buffer, mGroup); -} - -void ModulatedDeformableConvPluginDynamic::attachToContext( - cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT { - m_cublas_handle = cublasContext; -} - -void ModulatedDeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} - -////////////////////// creator ///////////////////////////// - -ModulatedDeformableConvPluginDynamicCreator::ModulatedDeformableConvPluginDynamicCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *ModulatedDeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *ModulatedDeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin( - const char *name, const 
nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - nvinfer1::Dims stride{2, {1, 1}}; - nvinfer1::Dims padding{2, {0, 0}}; - nvinfer1::Dims dilation{2, {1, 1}}; - int deformableGroup = 1; - int group = 1; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVModulatedDeformConv2d"}; + } // namespace + + ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(const std::string& name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, + const int group) + : TRTPluginBase(name) + , mStride(stride) + , mPadding(padding) + , mDilation(dilation) + , mDeformableGroup(deformableGroup) + , mGroup(group) + { + mWithBias = false; } - std::string field_name(fc->fields[i].name); - if (field_name.compare("deform_groups") == 0) { - deformableGroup = static_cast(fc->fields[i].data)[0]; + ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(const std::string name, + const void* data, + size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mStride); + deserialize_value(&data, &length, &mPadding); + deserialize_value(&data, &length, &mDilation); + deserialize_value(&data, &length, &mDeformableGroup); + deserialize_value(&data, &length, &mGroup); + mWithBias = false; + } + ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {} + + nvinfer1::IPluginV2DynamicExt* ModulatedDeformableConvPluginDynamic::clone() const TRT_NOEXCEPT + { + ModulatedDeformableConvPluginDynamic* plugin = new ModulatedDeformableConvPluginDynamic(mLayerName, + mStride, + mPadding, + mDilation, + mDeformableGroup, + mGroup); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + static const nvinfer1::IDimensionExpr* get_hw(const nvinfer1::IDimensionExpr* input, + const nvinfer1::IDimensionExpr* weight, + const nvinfer1::IDimensionExpr* stride, + const nvinfer1::IDimensionExpr* pad, + const nvinfer1::IDimensionExpr* dilation, + nvinfer1::IExprBuilder& exprBuilder) + { + using DimOp = nvinfer1::DimensionOperation; + auto expr_1 = exprBuilder.constant(1); + + // d*(w-1)+1 + auto kernel_0 = exprBuilder.operation(DimOp::kSUB, *weight, *expr_1); + auto kernel_1 = exprBuilder.operation(DimOp::kPROD, *dilation, *kernel_0); + auto kernel = exprBuilder.operation(DimOp::kSUM, *kernel_1, *expr_1); + + // (1+2*p-k)//stride -1 + auto out_0 = exprBuilder.operation(DimOp::kSUM, *pad, *pad); + auto out_1 = exprBuilder.operation(DimOp::kSUM, *input, *out_0); + auto out_2 = exprBuilder.operation(DimOp::kSUB, *out_1, *kernel); + auto out_3 = exprBuilder.operation(DimOp::kFLOOR_DIV, *out_2, *stride); + auto out = exprBuilder.operation(DimOp::kSUM, *out_3, *expr_1); + + return out; + } + + nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + using DimOp = nvinfer1::DimensionOperation; + auto weight_dim = inputs[3].d; + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[3].d[0]; + + auto input_h = inputs[0].d[2]; + auto input_w = inputs[0].d[3]; + auto weight_h = weight_dim[2]; + auto weight_w = weight_dim[3]; + auto dilation_w = exprBuilder.constant(mDilation.d[0]); + auto dilation_h = exprBuilder.constant(mDilation.d[1]); 
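get_hw above assembles the standard convolution output-size formula, out = floor((in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride) + 1, out of IDimensionExpr operations; note that the shorthand comment "(1+2*p-k)//stride -1" in the code carries the wrong final sign, since the expression actually ends by adding 1. The plain-integer equivalent:

// Plain-int version of the expression get_hw builds symbolically.
int conv_out_size(int in, int kernel, int stride, int pad, int dilation)
{
    const int k_eff = dilation * (kernel - 1) + 1;  // d*(k-1)+1, the dilated kernel extent
    return (in + 2 * pad - k_eff) / stride + 1;     // floor division, then + 1
}

// e.g. conv_out_size(7, /*kernel=*/3, /*stride=*/2, /*pad=*/1, /*dilation=*/1) == 4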
+ auto pad_w = exprBuilder.constant(mPadding.d[0]); + auto pad_h = exprBuilder.constant(mPadding.d[1]); + auto stride_w = exprBuilder.constant(mStride.d[0]); + auto stride_h = exprBuilder.constant(mStride.d[1]); + auto expr_1 = exprBuilder.constant(1); + auto expr_2 = exprBuilder.constant(2); + + ret.d[2] = get_hw(input_h, weight_h, stride_h, pad_h, dilation_h, exprBuilder); + ret.d[3] = get_hw(input_w, weight_w, stride_w, pad_w, dilation_w, exprBuilder); + + return ret; + } + + bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || + ioDesc[pos].type == nvinfer1::DataType::kHALF) && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + void ModulatedDeformableConvPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + if (nbInputs == 5) + { + mWithBias = true; + } + } + + size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); + + int batch_size = inputs[0].dims.d[0]; + int nInputPlane = inputs[0].dims.d[1]; + int inputHeight = inputs[0].dims.d[2]; + int inputWidth = inputs[0].dims.d[3]; + + int nOutputPlane = outputs[0].dims.d[1]; + int outputHeight = outputs[0].dims.d[2]; + int outputWidth = outputs[0].dims.d[3]; + + int kW = inputs[3].dims.d[2]; + int kH = inputs[3].dims.d[3]; + int im2col_step = std::min(32, batch_size); + + size_t col_size = + mmdeploy::getAlignedSize(nInputPlane * kW * kH * outputHeight * outputWidth * sizeof_dtype); + + return col_size; + } + + int ModulatedDeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + int channels_out = outputDesc[0].dims.d[1]; + int kernel_h = inputDesc[3].dims.d[2]; + int kernel_w = inputDesc[3].dims.d[3]; + + const void* x = inputs[0]; + const void* offset = inputs[1]; + const void* mask = inputs[2]; + const void* weight = inputs[3]; + const void* bias = mWithBias ? 
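// getWorkspaceSize() above reserves a single im2col column buffer. A sketch of
// the sizing rule, assuming getAlignedSize() rounds a byte count up to a fixed
// alignment boundary (16 is an assumed granularity, not confirmed here):
static size_t col_buffer_bytes(size_t channels_in, size_t kernel_h, size_t kernel_w,
                               size_t out_h, size_t out_w, size_t elem_size)
{
    size_t raw = channels_in * kernel_h * kernel_w * out_h * out_w * elem_size;
    const size_t align = 16;                   // assumed alignment granularity
    return (raw + align - 1) / align * align;  // round up to the boundary
}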
inputs[4] : nullptr; + void* output = outputs[0]; + int im2col_step = std::min(batch, 32); + + // TODO: add fp16 support + auto data_type = inputDesc[0].type; + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + ModulatedDeformConvForwardCUDAKernelLauncher((float*)x, + (float*)weight, + (float*)bias, + (float*)offset, + (float*)mask, + (float*)output, + workSpace, + batch, + channels, + height, + width, + channels_out, + kernel_w, + kernel_h, + mStride.d[0], + mStride.d[1], + mPadding.d[0], + mPadding.d[1], + mDilation.d[0], + mDilation.d[1], + mGroup, + mDeformableGroup, + im2col_step, + m_cublas_handle, + stream); + break; + case nvinfer1::DataType::kHALF: + ModulatedDeformConvForwardCUDAKernelLauncher((half*)x, + (half*)weight, + (half*)bias, + (half*)offset, + (half*)mask, + (half*)output, + workSpace, + batch, + channels, + height, + width, + channels_out, + kernel_w, + kernel_h, + mStride.d[0], + mStride.d[1], + mPadding.d[0], + mPadding.d[1], + mDilation.d[0], + mDilation.d[1], + mGroup, + mDeformableGroup, + im2col_step, + m_cublas_handle, + stream); + break; + default: + return 1; + break; + } + + return 0; + } + + nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* ModulatedDeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* ModulatedDeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int ModulatedDeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + + serialized_size(mDeformableGroup) + serialized_size(mGroup); + } + + void ModulatedDeformableConvPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mStride); + serialize_value(&buffer, mPadding); + serialize_value(&buffer, mDilation); + serialize_value(&buffer, mDeformableGroup); + serialize_value(&buffer, mGroup); + } + + void ModulatedDeformableConvPluginDynamic::attachToContext( + cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT + { + m_cublas_handle = cublasContext; + } + + void ModulatedDeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} + + ////////////////////// creator ///////////////////////////// + + ModulatedDeformableConvPluginDynamicCreator::ModulatedDeformableConvPluginDynamicCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); } - if (field_name.compare("groups") == 0) { - group = static_cast(fc->fields[i].data)[0]; + const char* ModulatedDeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; } - if (field_name.compare("stride") == 0) { - stride.nbDims = 2; - stride.d[0] = static_cast(fc->fields[i].data)[0]; - stride.d[1] = 
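// The creator registered above advertises the attributes "stride", "padding",
// "dilation", "groups" and "deform_groups". A hedged sketch of how a caller
// could fill the matching PluginFieldCollection before calling createPlugin();
// the variable names and values below are illustrative only:
int32_t stride_hw[2] = {1, 1};
int32_t deform_groups = 1;
std::vector<nvinfer1::PluginField> fields{
    {"stride", stride_hw, nvinfer1::PluginFieldType::kINT32, 2},
    {"deform_groups", &deform_groups, nvinfer1::PluginFieldType::kINT32, 1}};
nvinfer1::PluginFieldCollection field_collection{static_cast<int32_t>(fields.size()),
                                                 fields.data()};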
static_cast<const int *>(fc->fields[i].data)[1]; + const char* ModulatedDeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; } - if (field_name.compare("padding") == 0) { - padding.nbDims = 2; - padding.d[0] = static_cast<const int *>(fc->fields[i].data)[0]; - padding.d[1] = static_cast<const int *>(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* ModulatedDeformableConvPluginDynamicCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + nvinfer1::Dims stride{2, {1, 1}}; + nvinfer1::Dims padding{2, {0, 0}}; + nvinfer1::Dims dilation{2, {1, 1}}; + int deformableGroup = 1; + int group = 1; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("deform_groups") == 0) + { + deformableGroup = static_cast<const int*>(fc->fields[i].data)[0]; + } + + if (field_name.compare("groups") == 0) + { + group = static_cast<const int*>(fc->fields[i].data)[0]; + } + + if (field_name.compare("stride") == 0) + { + stride.nbDims = 2; + stride.d[0] = static_cast<const int*>(fc->fields[i].data)[0]; + stride.d[1] = static_cast<const int*>(fc->fields[i].data)[1]; + } + + if (field_name.compare("padding") == 0) + { + padding.nbDims = 2; + padding.d[0] = static_cast<const int*>(fc->fields[i].data)[0]; + padding.d[1] = static_cast<const int*>(fc->fields[i].data)[1]; + } + + if (field_name.compare("dilation") == 0) + { + dilation.nbDims = 2; + dilation.d[0] = static_cast<const int*>(fc->fields[i].data)[0]; + dilation.d[1] = static_cast<const int*>(fc->fields[i].data)[1]; + } + } + + ModulatedDeformableConvPluginDynamic* plugin = new ModulatedDeformableConvPluginDynamic( + name, + stride, + padding, + dilation, + deformableGroup, + group); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - if (field_name.compare("dilation") == 0) { - dilation.nbDims = 2; - dilation.d[0] = static_cast<const int *>(fc->fields[i].data)[0]; - dilation.d[1] = static_cast<const int *>(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* ModulatedDeformableConvPluginDynamicCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - } - - ModulatedDeformableConvPluginDynamic *plugin = new ModulatedDeformableConvPluginDynamic( - name, stride, padding, dilation, deformableGroup, group); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator); + REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp index 2dc6ed2f20..1bfbc17735 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp @@ -9,74 +9,101 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class 
ModulatedDeformableConvPluginDynamic : public TRTPluginBase { - public: - ModulatedDeformableConvPluginDynamic(const std::string &name, const nvinfer1::Dims stride, - const nvinfer1::Dims padding, const nvinfer1::Dims dilation, - const int deformableGroup, const int group); - - ModulatedDeformableConvPluginDynamic(const std::string name, const void *data, size_t length); - - ModulatedDeformableConvPluginDynamic() = delete; - - ~ModulatedDeformableConvPluginDynamic() TRT_NOEXCEPT override; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - - private: - nvinfer1::Dims mStride; - nvinfer1::Dims mPadding; - nvinfer1::Dims mDilation; - int mDeformableGroup; - int mGroup; - bool mWithBias; - - cublasHandle_t m_cublas_handle; -}; - -class ModulatedDeformableConvPluginDynamicCreator : public TRTPluginCreatorBase { - public: - ModulatedDeformableConvPluginDynamicCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class ModulatedDeformableConvPluginDynamic : public TRTPluginBase + { + public: + ModulatedDeformableConvPluginDynamic(const std::string& name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, + const int group); + + ModulatedDeformableConvPluginDynamic(const std::string name, + const void* data, + size_t length); + + ModulatedDeformableConvPluginDynamic() = delete; + + ~ModulatedDeformableConvPluginDynamic() TRT_NOEXCEPT override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int 
outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + void attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + private: + nvinfer1::Dims mStride; + nvinfer1::Dims mPadding; + nvinfer1::Dims mDilation; + int mDeformableGroup; + int mGroup; + bool mWithBias; + + cublasHandle_t m_cublas_handle; + }; + + class ModulatedDeformableConvPluginDynamicCreator : public TRTPluginCreatorBase + { + public: + ModulatedDeformableConvPluginDynamicCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_MODULATED_DEFORM_CONV_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu index 1e1f99d5ff..1b8884c7dc 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu @@ -7,132 +7,284 @@ #include "trt_modulated_deform_conv_kernel.hpp" #include "trt_plugin_helper.hpp" -template -void trt_modulated_deformable_im2col(const T* data_im_, const T* data_offset_, const T* data_mask_, - const int batch_size, const int channels, const int height_im, - const int width_im, const int height_col, const int width_col, - const int kernel_h, const int kenerl_w, const int pad_h, - const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, T* data_col_, - cudaStream_t stream) { - // num_axes should be smaller than block size - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * batch_size * height_col * width_col; - - 
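// trt_modulated_deformable_im2col below expands one batch of the input into a
// column matrix so that deformable convolution reduces to a GEMM. A small
// sketch of the buffer-shape bookkeeping (hypothetical helper, as a reading
// aid for the launch that follows):
struct Im2colShape
{
    int rows;  // channels * kernel_h * kernel_w, one row per channel/tap pair
    int cols;  // height_col * width_col, one column per output position
};
static Im2colShape im2col_shape(int channels, int kernel_h, int kernel_w,
                                int height_col, int width_col)
{
    return {channels * kernel_h * kernel_w, height_col * width_col};
}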
modulated_deformable_im2col_gpu_kernel - <<>>( - num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, - batch_size, channels, deformable_group, height_col, width_col, data_col_); - - cudaCheckError(); +template +void trt_modulated_deformable_im2col(const T* data_im_, + const T* data_offset_, + const T* data_mask_, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + T* data_col_, + cudaStream_t stream) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + modulated_deformable_im2col_gpu_kernel + <<>>(num_kernels, + data_im_, + data_offset_, + data_mask_, + height_im, + width_im, + kernel_h, + kenerl_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + channels, + deformable_group, + height_col, + width_col, + data_col_); + + cudaCheckError(); } -template -__global__ void output_add_bias_kernel(scalar_t* output, const scalar_t* bias, size_t step_batch, - size_t step_channel, size_t n) { - CUDA_1D_KERNEL_LOOP(index, n) { output[index] += bias[(index % step_batch) / step_channel]; } +template +__global__ void output_add_bias_kernel(scalar_t* output, + const scalar_t* bias, + size_t step_batch, + size_t step_channel, + size_t n) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + output[index] += bias[(index % step_batch) / step_channel]; + } } #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -template <> -__global__ void output_add_bias_kernel<__half>(__half* output, const __half* bias, - size_t step_batch, size_t step_channel, size_t n) { - CUDA_1D_KERNEL_LOOP(index, n) { - const __half b = bias[(index % step_batch) / step_channel]; - const __half o = output[index]; - output[index] = __hadd(o, b); - } +template<> +__global__ void output_add_bias_kernel<__half>(__half* output, + const __half* bias, + size_t step_batch, + size_t step_channel, + size_t n) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + const __half b = bias[(index % step_batch) / step_channel]; + const __half o = output[index]; + output[index] = __hadd(o, b); + } } #else -template <> -__global__ void output_add_bias_kernel<__half>(__half* output, const __half* bias, - size_t step_batch, size_t step_channel, size_t n) { - CUDA_1D_KERNEL_LOOP(index, n) { - const __half b = bias[(index % step_batch) / step_channel]; - const __half o = output[index]; - output[index] = __float2half(__half2float(o) + __half2float(b)); - } +template<> +__global__ void output_add_bias_kernel<__half>(__half* output, + const __half* bias, + size_t step_batch, + size_t step_channel, + size_t n) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + const __half b = bias[(index % step_batch) / step_channel]; + const __half o = output[index]; + output[index] = __float2half(__half2float(o) + __half2float(b)); + } } #endif -template -static void output_add_bias(scalar_t* output, const scalar_t* bias, size_t batch, size_t channel, - size_t height, size_t width, cudaStream_t stream) { - size_t step_channel = height * width; - size_t step_batch = step_channel * 
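// output_add_bias_kernel above maps a flat NCHW offset back to its channel:
// with step_channel = H*W and step_batch = C*H*W, the bias index is
// (index % step_batch) / step_channel. A tiny sketch of that arithmetic:
static size_t channel_of(size_t index, size_t step_batch, size_t step_channel)
{
    return (index % step_batch) / step_channel;  // result lies in [0, C)
}
// e.g. with C = 2 and H = W = 2: step_channel = 4, step_batch = 8, and
// channel_of(5, 8, 4) == 1, so element 5 receives bias[1].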
channel; - size_t n = step_batch * batch; - output_add_bias_kernel<<>>(output, bias, step_batch, - step_channel, n); +template +static void output_add_bias(scalar_t* output, + const scalar_t* bias, + size_t batch, + size_t channel, + size_t height, + size_t width, + cudaStream_t stream) +{ + size_t step_channel = height * width; + size_t step_batch = step_channel * channel; + size_t n = step_batch * batch; + output_add_bias_kernel<<>>(output, + bias, + step_batch, + step_channel, + n); } -template -void ModulatedDeformConvForwardCUDAKernelLauncher( - const scalar_t* input, const scalar_t* weight, const scalar_t* bias, const scalar_t* offset, - const scalar_t* mask, scalar_t* output, void* workspace, int batch, int channels, int height, - int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, - int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream) { - bool with_bias = (bias != nullptr); - - im2col_step = std::min(int(batch), im2col_step); - assert(batch % im2col_step == 0); - - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - scalar_t* columns = (scalar_t*)workspace; - - const size_t input_step = channels * height * width; - const size_t offset_step = deformable_group * kernel_h * kernel_w * 2 * height_out * width_out; - const size_t mask_step = deformable_group * kernel_h * kernel_w * height_out * width_out; - const size_t out_step = channels_out * height_out * width_out; - const size_t out_group_step = out_step / group; - const size_t col_g_step = channels * kernel_w * kernel_h / group * height_out * width_out; - const size_t weight_g_step = channels_out / group * channels / group * kernel_h * kernel_w; - - const int m = channels_out / group; - const int n = height_out * width_out; - const int k = channels / group * kernel_h * kernel_w; - scalar_t alpha = 1.; - scalar_t beta = 0.; - - for (int b = 0; b < batch; b++) { - const scalar_t* input_start = input + b * input_step; - const scalar_t* offset_start = offset + b * offset_step; - const scalar_t* mask_start = mask + b * mask_step; - trt_modulated_deformable_im2col( - input_start, offset_start, mask_start, 1, channels, height, width, height_out, width_out, - kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - deformable_group, columns, stream); - - for (int g = 0; g < group; g++) { - const scalar_t* weight_start = weight + g * weight_g_step; - scalar_t* col_start = columns + g * col_g_step; - scalar_t* out_buffer_start = output + b * out_step + g * out_group_step; - - cublasGemmWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, col_start, - n, weight_start, k, &beta, out_buffer_start, n); - cudaCheckError(); +template +void ModulatedDeformConvForwardCUDAKernelLauncher(const scalar_t* input, + const scalar_t* weight, + const scalar_t* bias, + const scalar_t* offset, + const scalar_t* mask, + scalar_t* output, + void* workspace, + int batch, + int channels, + int height, + int width, + int channels_out, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int pad_w, + int pad_h, + int dilation_w, + int dilation_h, + int group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream) +{ + bool with_bias = (bias != nullptr); + + im2col_step = std::min(int(batch), 
im2col_step); + assert(batch % im2col_step == 0); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + scalar_t* columns = (scalar_t*)workspace; + + const size_t input_step = channels * height * width; + const size_t offset_step = deformable_group * kernel_h * kernel_w * 2 * height_out * width_out; + const size_t mask_step = deformable_group * kernel_h * kernel_w * height_out * width_out; + const size_t out_step = channels_out * height_out * width_out; + const size_t out_group_step = out_step / group; + const size_t col_g_step = channels * kernel_w * kernel_h / group * height_out * width_out; + const size_t weight_g_step = channels_out / group * channels / group * kernel_h * kernel_w; + + const int m = channels_out / group; + const int n = height_out * width_out; + const int k = channels / group * kernel_h * kernel_w; + scalar_t alpha = 1.; + scalar_t beta = 0.; + + for (int b = 0; b < batch; b++) + { + const scalar_t* input_start = input + b * input_step; + const scalar_t* offset_start = offset + b * offset_step; + const scalar_t* mask_start = mask + b * mask_step; + trt_modulated_deformable_im2col( + input_start, + offset_start, + mask_start, + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + columns, + stream); + + for (int g = 0; g < group; g++) + { + const scalar_t* weight_start = weight + g * weight_g_step; + scalar_t* col_start = columns + g * col_g_step; + scalar_t* out_buffer_start = output + b * out_step + g * out_group_step; + + cublasGemmWrap(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + n, + m, + k, + &alpha, + col_start, + n, + weight_start, + k, + &beta, + out_buffer_start, + n); + cudaCheckError(); + } } - } - if (with_bias) { - output_add_bias(output, bias, batch, channels_out, height_out, width_out, stream); - } + if (with_bias) + { + output_add_bias(output, + bias, + batch, + channels_out, + height_out, + width_out, + stream); + } } -template void ModulatedDeformConvForwardCUDAKernelLauncher( - const float* input, const float* weight, const float* bias, const float* offset, - const float* mask, float* output, void* workspace, int batch, int channels, int height, - int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, - int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); - -template void ModulatedDeformConvForwardCUDAKernelLauncher<__half>( - const __half* input, const __half* weight, const __half* bias, const __half* offset, - const __half* mask, __half* output, void* workspace, int batch, int channels, int height, - int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, - int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); +template void ModulatedDeformConvForwardCUDAKernelLauncher(const float* input, + const float* weight, + const float* bias, + const float* offset, + const float* mask, + float* output, + void* workspace, + int batch, + int channels, + int height, + int width, + int channels_out, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int pad_w, + int pad_h, + int dilation_w, + int dilation_h, + int 
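// The cublasGemmWrap call above evaluates out = weight * columns per group
// with row-major buffers on a column-major BLAS by swapping the operands:
// C_rm(m x n) = A_rm(m x k) * B_rm(k x n) is computed as
// C_cm(n x m) = B_cm(n x k) * A_cm(k x m). A float-only sketch with plain
// cuBLAS (gemm_row_major is a hypothetical helper; error handling elided):
#include <cublas_v2.h>
static void gemm_row_major(cublasHandle_t handle, int m, int n, int k,
                           const float* A, const float* B, float* C)
{
    const float alpha = 1.f, beta = 0.f;
    // leading dimensions follow the column-major view of each operand
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, B, n, A, k, &beta, C, n);
}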
group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream); + +template void ModulatedDeformConvForwardCUDAKernelLauncher<__half>(const __half* input, + const __half* weight, + const __half* bias, + const __half* offset, + const __half* mask, + __half* output, + void* workspace, + int batch, + int channels, + int height, + int width, + int channels_out, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int pad_w, + int pad_h, + int dilation_w, + int dilation_h, + int group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp index 4cdec4fb38..4d928b16c5 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp @@ -4,12 +4,31 @@ #include #include -template -void ModulatedDeformConvForwardCUDAKernelLauncher( - const scalar_t* input, const scalar_t* weight, const scalar_t* bias, const scalar_t* offset, - const scalar_t* mask, scalar_t* output, void* workspace, int batch, int channels, int height, - int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, - int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); +template +void ModulatedDeformConvForwardCUDAKernelLauncher(const scalar_t* input, + const scalar_t* weight, + const scalar_t* bias, + const scalar_t* offset, + const scalar_t* mask, + scalar_t* output, + void* workspace, + int batch, + int channels, + int height, + int width, + int channels_out, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int pad_w, + int pad_h, + int dilation_w, + int dilation_h, + int group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp index ad9a518da7..456acca9b4 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp @@ -9,219 +9,263 @@ #include "trt_multi_level_roi_align_kernel.hpp" #include "trt_plugin_helper.hpp" #include "trt_serialize.hpp" -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"MMCVMultiLevelRoiAlign"}; -} // namespace - -TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, - int alignedWidth, int poolMode, int sampleNum, - const std::vector &featmapStrides, - float roiScaleFactor, int finestScale, bool aligned) - : TRTPluginBase(name), - mAlignedHeight(alignedHeight), - mAlignedWidth(alignedWidth), - mPoolMode(poolMode), - mSampleNum(sampleNum), - mFeatmapStrides(featmapStrides), - mRoiScaleFactor(roiScaleFactor), - mFinestScale(finestScale), - mAligned(aligned) {} - -TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void *data, - size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mAlignedHeight); - 
deserialize_value(&data, &length, &mAlignedWidth); - deserialize_value(&data, &length, &mPoolMode); - deserialize_value(&data, &length, &mSampleNum); - deserialize_value(&data, &length, &mRoiScaleFactor); - deserialize_value(&data, &length, &mFinestScale); - deserialize_value(&data, &length, &mAligned); - deserialize_value(&data, &length, &mFeatmapStrides); -} - -nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const TRT_NOEXCEPT { - TRTMultiLevelRoiAlign *plugin = - new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, - mFeatmapStrides, mRoiScaleFactor, mFinestScale, mAligned); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - // warning, nbInputs should equal to mFeatmapStrides.size() + 1 - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[1].d[1]; - ret.d[2] = exprBuilder.constant(mAlignedHeight); - ret.d[3] = exprBuilder.constant(mAlignedWidth); - - return ret; -} - -bool TRTMultiLevelRoiAlign::supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -void TRTMultiLevelRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments - ASSERT(nbOutputs == 1); - ASSERT(nbInputs >= 1); - mFeatmapStrides = - std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + (nbInputs - 1)); -} - -size_t TRTMultiLevelRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - int num_rois = inputDesc[0].dims.d[0]; - int batch_size = inputDesc[1].dims.d[0]; - int channels = inputDesc[1].dims.d[1]; - - const int kMaxFeatMap = 10; - int heights[kMaxFeatMap]; - int widths[kMaxFeatMap]; - float strides[kMaxFeatMap]; - - int num_feats = mFeatmapStrides.size(); - for (int i = 0; i < num_feats; ++i) { - heights[i] = inputDesc[i + 1].dims.d[2]; - widths[i] = inputDesc[i + 1].dims.d[3]; - strides[i] = mFeatmapStrides[i]; - } - - const void *rois = inputs[0]; - const void *const *feats = inputs + 1; - - multi_level_roi_align((float *)outputs[0], (const float *)rois, num_rois, feats, num_feats, - batch_size, channels, &heights[0], &widths[0], &strides[0], - mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, - mRoiScaleFactor, mFinestScale, mAligned, stream); - - return 0; -} - -nvinfer1::DataType TRTMultiLevelRoiAlign::getOutputDataType(int index, - const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return nvinfer1::DataType::kFLOAT; -} - -// IPluginV2 Methods -const char *TRTMultiLevelRoiAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTMultiLevelRoiAlign::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int 
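// Shape contract encoded by getOutputDimensions above: the rois input is
// [num_rois, 5] (batch index plus x1, y1, x2, y2), each feature level is
// [N, C, H_i, W_i], and the pooled output is
// [num_rois, C, output_height, output_width]; only the last two dimensions
// are build-time constants. A sketch of that bookkeeping (hypothetical helper):
static void roi_align_out_shape(int num_rois, int channels, int aligned_h,
                                int aligned_w, int out_shape[4])
{
    out_shape[0] = num_rois;   // one pooled feature per RoI
    out_shape[1] = channels;   // channels taken from the feature maps
    out_shape[2] = aligned_h;  // mAlignedHeight
    out_shape[3] = aligned_w;  // mAlignedWidth
}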
TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + - serialized_size(mAlignedWidth) + serialized_size(mPoolMode) + serialized_size(mSampleNum) + - serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) + - serialized_size(mAligned); -} - -void TRTMultiLevelRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mAlignedHeight); - serialize_value(&buffer, mAlignedWidth); - serialize_value(&buffer, mPoolMode); - serialize_value(&buffer, mSampleNum); - serialize_value(&buffer, mRoiScaleFactor); - serialize_value(&buffer, mFinestScale); - serialize_value(&buffer, mAligned); - serialize_value(&buffer, mFeatmapStrides); -} - -TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() { - mPluginAttributes = std::vector( - {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), - nvinfer1::PluginField("pool_mode"), nvinfer1::PluginField("sampling_ratio"), - nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), - nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTMultiLevelRoiAlignCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTMultiLevelRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int alignedHeight = 7; - int alignedWidth = 7; - int poolMode = 0; - int sampleNum = 2; - std::vector featmapStrides; - float roiScaleFactor = -1; - int finestScale = 56; - bool aligned = false; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("output_height") == 0) { - alignedHeight = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("output_width") == 0) { - alignedWidth = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("pool_mode") == 0) { - poolMode = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("sampling_ratio") == 0) { - sampleNum = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("roi_scale_factor") == 0) { - roiScaleFactor = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("finest_scale") == 0) { - finestScale = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("featmap_strides") == 0) { - int data_size = (fc->fields[i].length); - const float *data_start = static_cast(fc->fields[i].data); - featmapStrides = std::vector(data_start, data_start + data_size); - } else if (field_name.compare("aligned") == 0) { - int aligned_int = static_cast(fc->fields[i].data)[0]; - aligned = aligned_int != 0; - } - } - - ASSERT(featmapStrides.size() != 0); - - TRTMultiLevelRoiAlign *plugin = - new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, poolMode, sampleNum, - featmapStrides, roiScaleFactor, finestScale, aligned); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new 
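// serialize() and the deserializing constructor above must visit the fields in
// the same order, with the variable-length mFeatmapStrides written last. A
// reduced sketch of that round-trip contract, assuming the serialize_value /
// deserialize_value helpers advance a cursor symmetrically on both sides:
#include <cstring>
struct Cursor
{
    char* p;  // current position inside the serialization buffer
};
template <typename T>
static void put(Cursor& c, const T& v)
{
    std::memcpy(c.p, &v, sizeof(v));  // write, then advance
    c.p += sizeof(v);
}
template <typename T>
static void get(Cursor& c, T& v)
{
    std::memcpy(&v, c.p, sizeof(v));  // read in the same order it was written
    c.p += sizeof(v);
}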
TRTMultiLevelRoiAlign(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRoiAlignCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVMultiLevelRoiAlign"}; + } // namespace + + TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string& name, int alignedHeight, int alignedWidth, int poolMode, int sampleNum, const std::vector& featmapStrides, float roiScaleFactor, int finestScale, bool aligned) + : TRTPluginBase(name) + , mAlignedHeight(alignedHeight) + , mAlignedWidth(alignedWidth) + , mPoolMode(poolMode) + , mSampleNum(sampleNum) + , mFeatmapStrides(featmapStrides) + , mRoiScaleFactor(roiScaleFactor) + , mFinestScale(finestScale) + , mAligned(aligned) + { + } + + TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mAlignedHeight); + deserialize_value(&data, &length, &mAlignedWidth); + deserialize_value(&data, &length, &mPoolMode); + deserialize_value(&data, &length, &mSampleNum); + deserialize_value(&data, &length, &mRoiScaleFactor); + deserialize_value(&data, &length, &mFinestScale); + deserialize_value(&data, &length, &mAligned); + deserialize_value(&data, &length, &mFeatmapStrides); + } + + nvinfer1::IPluginV2DynamicExt* TRTMultiLevelRoiAlign::clone() const TRT_NOEXCEPT + { + TRTMultiLevelRoiAlign* plugin = + new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, mFeatmapStrides, mRoiScaleFactor, mFinestScale, mAligned); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + // warning, nbInputs should equal to mFeatmapStrides.size() + 1 + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[1].d[1]; + ret.d[2] = exprBuilder.constant(mAlignedHeight); + ret.d[3] = exprBuilder.constant(mAlignedWidth); + + return ret; + } + + bool TRTMultiLevelRoiAlign::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + + void TRTMultiLevelRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + ASSERT(nbOutputs == 1); + ASSERT(nbInputs >= 1); + mFeatmapStrides = + std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + (nbInputs - 1)); + } + + size_t TRTMultiLevelRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int num_rois = inputDesc[0].dims.d[0]; + int batch_size = inputDesc[1].dims.d[0]; + int channels = inputDesc[1].dims.d[1]; + + const int kMaxFeatMap = 10; + int heights[kMaxFeatMap]; + int 
widths[kMaxFeatMap]; + float strides[kMaxFeatMap]; + + int num_feats = mFeatmapStrides.size(); + for (int i = 0; i < num_feats; ++i) + { + heights[i] = inputDesc[i + 1].dims.d[2]; + widths[i] = inputDesc[i + 1].dims.d[3]; + strides[i] = mFeatmapStrides[i]; + } + + const void* rois = inputs[0]; + const void* const* feats = inputs + 1; + + multi_level_roi_align((float*)outputs[0], (const float*)rois, num_rois, feats, num_feats, batch_size, channels, &heights[0], &widths[0], &strides[0], mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, mRoiScaleFactor, mFinestScale, mAligned, stream); + + return 0; + } + + nvinfer1::DataType TRTMultiLevelRoiAlign::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return nvinfer1::DataType::kFLOAT; + } + + // IPluginV2 Methods + const char* TRTMultiLevelRoiAlign::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRoiAlign::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + + serialized_size(mAlignedWidth) + serialized_size(mPoolMode) + serialized_size(mSampleNum) + + serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) + + serialized_size(mAligned); + } + + void TRTMultiLevelRoiAlign::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mAlignedHeight); + serialize_value(&buffer, mAlignedWidth); + serialize_value(&buffer, mPoolMode); + serialize_value(&buffer, mSampleNum); + serialize_value(&buffer, mRoiScaleFactor); + serialize_value(&buffer, mFinestScale); + serialize_value(&buffer, mAligned); + serialize_value(&buffer, mFeatmapStrides); + } + + TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() + { + mPluginAttributes = std::vector<nvinfer1::PluginField>( + {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), nvinfer1::PluginField("pool_mode"), nvinfer1::PluginField("sampling_ratio"), nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTMultiLevelRoiAlignCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTMultiLevelRoiAlignCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int alignedHeight = 7; + int alignedWidth = 7; + int poolMode = 0; + int sampleNum = 2; + std::vector<float> featmapStrides; + float roiScaleFactor = -1; + int finestScale = 56; + bool aligned = false; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("output_height") == 0) + { + alignedHeight = static_cast<const int*>(fc->fields[i].data)[0]; + } + else if (field_name.compare("output_width") == 0) + { + alignedWidth = static_cast<const int*>(fc->fields[i].data)[0]; + } + else if (field_name.compare("pool_mode") == 0) + { + poolMode = static_cast<const int*>(fc->fields[i].data)[0]; + } + else if (field_name.compare("sampling_ratio") == 0) + { + sampleNum = static_cast<const int*>(fc->fields[i].data)[0]; + } + else if (field_name.compare("roi_scale_factor") == 0) + { + roiScaleFactor = static_cast<const float*>(fc->fields[i].data)[0]; + } + else if (field_name.compare("finest_scale") == 0) + { + finestScale = static_cast<const int*>(fc->fields[i].data)[0]; + } + else if (field_name.compare("featmap_strides") == 0) + { + int data_size = (fc->fields[i].length); + const float* data_start = static_cast<const float*>(fc->fields[i].data); + featmapStrides = std::vector<float>(data_start, data_start + data_size); + } + else if (field_name.compare("aligned") == 0) + { + int aligned_int = static_cast<const int*>(fc->fields[i].data)[0]; + aligned = aligned_int != 0; + } + } + + ASSERT(featmapStrides.size() != 0); + + TRTMultiLevelRoiAlign* plugin = + new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, poolMode, sampleNum, featmapStrides, roiScaleFactor, finestScale, aligned); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTMultiLevelRoiAlignCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new TRTMultiLevelRoiAlign(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRoiAlignCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp index a9a06236e0..814118d29b 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp @@ -10,69 +10,65 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTMultiLevelRoiAlign : public TRTPluginBase { - public: - TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, int poolMode, - int sampleNum, const std::vector<float> &featmapStrides, - float roiScaleFactor = -1, int finestScale = 56, bool aligned = false); +namespace mmdeploy +{ + class TRTMultiLevelRoiAlign : public TRTPluginBase + { + public: + TRTMultiLevelRoiAlign(const std::string& name, int alignedHeight, int alignedWidth, int poolMode, int sampleNum, const std::vector<float>& featmapStrides, float roiScaleFactor = -1, int finestScale = 56, bool aligned = false); - TRTMultiLevelRoiAlign(const std::string name, const void *data, size_t length); + TRTMultiLevelRoiAlign(const std::string name, const void* data, size_t length); - TRTMultiLevelRoiAlign() = delete; + TRTMultiLevelRoiAlign() = delete; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void 
*const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - private: - int mAlignedHeight; - int mAlignedWidth; - int mPoolMode; - int mSampleNum; - std::vector mFeatmapStrides; - float mRoiScaleFactor; - int mFinestScale; - bool mAligned; -}; + private: + int mAlignedHeight; + int mAlignedWidth; + int mPoolMode; + int mSampleNum; + std::vector mFeatmapStrides; + float mRoiScaleFactor; + int mFinestScale; + bool mAligned; + }; -class TRTMultiLevelRoiAlignCreator : public TRTPluginCreatorBase { - public: - TRTMultiLevelRoiAlignCreator(); + class TRTMultiLevelRoiAlignCreator : public TRTPluginCreatorBase + { + public: + TRTMultiLevelRoiAlignCreator(); - const char *getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_ROI_ALIGN_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu 
b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu index 9eefbe3f32..260086b511 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu @@ -10,167 +10,264 @@ #include "trt_plugin_helper.hpp" const int kMAX_FEATMAP_SIZE = 10; -struct FeatData { - const void *data[kMAX_FEATMAP_SIZE]; - int batch_size; - int channels; - int h[kMAX_FEATMAP_SIZE]; - int w[kMAX_FEATMAP_SIZE]; - float spatial_scale[kMAX_FEATMAP_SIZE]; - int num_featmap; +struct FeatData +{ + const void* data[kMAX_FEATMAP_SIZE]; + int batch_size; + int channels; + int h[kMAX_FEATMAP_SIZE]; + int w[kMAX_FEATMAP_SIZE]; + float spatial_scale[kMAX_FEATMAP_SIZE]; + int num_featmap; }; -template -__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data, - const int roi_batch_ind, const scalar_t roi_start_w, - const scalar_t roi_start_h, const scalar_t roi_end_w, - const scalar_t roi_end_h, const scalar_t spatial_scale, - const int pw, const int ph, const int c, const int sample_num, - const int channels, const int height, const int width, - const int pooled_height, const int pooled_width) { - // Force malformed ROIs to be 1x1 - scalar_t roi_width = max(roi_end_w - roi_start_w, (scalar_t)(aligned ? 0. : 1.)); - scalar_t roi_height = max(roi_end_h - roi_start_h, (scalar_t)(aligned ? 0. : 1.)); - - const scalar_t bin_size_h = roi_height / pooled_height; - const scalar_t bin_size_w = roi_width / pooled_width; - - const scalar_t *offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; - - const int sample_num_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); - const int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); - - scalar_t output_val = (pool_mode == 0) ? -FLT_MAX : 0; - const scalar_t y_offset = roi_start_h + ph * bin_size_h; - const scalar_t y_scale = bin_size_h / (scalar_t)(sample_num_h); - const scalar_t x_offset = roi_start_w + pw * bin_size_w; - const scalar_t x_scale = bin_size_w / (scalar_t)(sample_num_w); - for (int iy = 0; iy < sample_num_h; iy++) { - const scalar_t y = fma(scalar_t(iy) + scalar_t(.5f), y_scale, y_offset); - for (int ix = 0; ix < sample_num_w; ix++) { - const scalar_t x = fma(scalar_t(ix) + scalar_t(.5f), x_scale, x_offset); - scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); - if (pool_mode == 0) { - output_val = max(output_val, val); - } else { - output_val += val; - } +template +__device__ scalar_t roi_align_single(const scalar_t* __restrict__ bottom_data, + const int roi_batch_ind, + const scalar_t roi_start_w, + const scalar_t roi_start_h, + const scalar_t roi_end_w, + const scalar_t roi_end_h, + const scalar_t spatial_scale, + const int pw, + const int ph, + const int c, + const int sample_num, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width) +{ + // Force malformed ROIs to be 1x1 + scalar_t roi_width = max(roi_end_w - roi_start_w, (scalar_t)(aligned ? 0. : 1.)); + scalar_t roi_height = max(roi_end_h - roi_start_h, (scalar_t)(aligned ? 0. : 1.)); + + const scalar_t bin_size_h = roi_height / pooled_height; + const scalar_t bin_size_w = roi_width / pooled_width; + + const scalar_t* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + + const int sample_num_h = (sample_num > 0) ? 
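// roi_align_single samples each pooled bin on a sample_num_h x sample_num_w
// grid; when sample_num <= 0 the grid adapts to the bin size (ceil of the bin
// extent). The sample coordinate for grid cell i of a bin is
//   roi_start + bin_index * bin_size + (i + 0.5) * bin_size / samples,
// which is exactly the fma(i + .5f, scale, offset) form used in the loop
// below. A scalar sketch of that coordinate (hypothetical helper):
static float sample_coord(float roi_start, int bin_index, float bin_size,
                          int i, int samples)
{
    return roi_start + bin_index * bin_size + (i + 0.5f) * bin_size / samples;
}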
sample_num : ceil(roi_height / pooled_height); + const int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); + + scalar_t output_val = (pool_mode == 0) ? -FLT_MAX : 0; + const scalar_t y_offset = roi_start_h + ph * bin_size_h; + const scalar_t y_scale = bin_size_h / (scalar_t)(sample_num_h); + const scalar_t x_offset = roi_start_w + pw * bin_size_w; + const scalar_t x_scale = bin_size_w / (scalar_t)(sample_num_w); + for (int iy = 0; iy < sample_num_h; iy++) + { + const scalar_t y = fma(scalar_t(iy) + scalar_t(.5f), y_scale, y_offset); + for (int ix = 0; ix < sample_num_w; ix++) + { + const scalar_t x = fma(scalar_t(ix) + scalar_t(.5f), x_scale, x_offset); + scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); + if (pool_mode == 0) + { + output_val = max(output_val, val); + } + else + { + output_val += val; + } + } + } + if (pool_mode != 0) + { + output_val /= max(sample_num_h * sample_num_w, 1); } - } - if (pool_mode != 0) { - output_val /= max(sample_num_h * sample_num_w, 1); - } - return output_val; + return output_val; } -template -__global__ void roi_extractor_kernel(scalar_t *__restrict__ output, - const scalar_t *__restrict__ bottom_rois, FeatData feat_data, - const int pool_mode, const int sample_num, - const float roi_scale_factor, const int finest_scale, - const int pooled_height, const int pooled_width, - int nThreads) { - CUDA_1D_KERNEL_LOOP(index, nThreads) { - const int channels = feat_data.channels; - int tmp_index = index; - const int pw = tmp_index % pooled_width; - tmp_index /= pooled_width; - const int ph = tmp_index % pooled_height; - tmp_index /= pooled_height; - const int c = tmp_index % channels; - const int n = tmp_index / channels; - - const scalar_t *offset_bottom_rois = bottom_rois + n * 5; - - scalar_t roi_offset_x0 = offset_bottom_rois[1]; - scalar_t roi_offset_y0 = offset_bottom_rois[2]; - scalar_t roi_offset_x1 = offset_bottom_rois[3]; - scalar_t roi_offset_y1 = offset_bottom_rois[4]; - - const scalar_t scale = sqrtf((roi_offset_y1 - roi_offset_y0) * (roi_offset_x1 - roi_offset_x0)); - - const int target_lvls = - min(feat_data.num_featmap - 1, - max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); - - if (roi_scale_factor > 0.) 
{ - const scalar_t roi_off_cx = (roi_offset_x0 + roi_offset_x1) * 0.5; - const scalar_t roi_off_cy = (roi_offset_y0 + roi_offset_y1) * 0.5; - const scalar_t half_scale_factor = roi_scale_factor * 0.5; - const scalar_t half_roi_off_w = - fma(roi_offset_x1 - roi_offset_x0 + 1, half_scale_factor, scalar_t(-0.5)); - const scalar_t half_roi_off_h = - fma(roi_offset_y1 - roi_offset_y0 + 1, half_scale_factor, scalar_t(-0.5)); - - roi_offset_x0 = roi_off_cx - half_roi_off_w; - roi_offset_x1 = roi_off_cx + half_roi_off_w; - roi_offset_y0 = roi_off_cy - half_roi_off_h; - roi_offset_y1 = roi_off_cy + half_roi_off_h; - } +template +__global__ void roi_extractor_kernel(scalar_t* __restrict__ output, + const scalar_t* __restrict__ bottom_rois, + FeatData feat_data, + const int pool_mode, + const int sample_num, + const float roi_scale_factor, + const int finest_scale, + const int pooled_height, + const int pooled_width, + int nThreads) +{ + CUDA_1D_KERNEL_LOOP(index, nThreads) + { + const int channels = feat_data.channels; + int tmp_index = index; + const int pw = tmp_index % pooled_width; + tmp_index /= pooled_width; + const int ph = tmp_index % pooled_height; + tmp_index /= pooled_height; + const int c = tmp_index % channels; + const int n = tmp_index / channels; + + const scalar_t* offset_bottom_rois = bottom_rois + n * 5; + + scalar_t roi_offset_x0 = offset_bottom_rois[1]; + scalar_t roi_offset_y0 = offset_bottom_rois[2]; + scalar_t roi_offset_x1 = offset_bottom_rois[3]; + scalar_t roi_offset_y1 = offset_bottom_rois[4]; + + const scalar_t scale = sqrtf((roi_offset_y1 - roi_offset_y0) * (roi_offset_x1 - roi_offset_x0)); - const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls]; - const int height = feat_data.h[target_lvls]; - const int width = feat_data.w[target_lvls]; - const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls]; - - const int roi_batch_ind = offset_bottom_rois[0]; - const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0; - const scalar_t roi_start_w = - fma(roi_offset_x0, spatial_scale, offset); // roi_offset_x0 * spatial_scale + offset; - const scalar_t roi_start_h = - fma(roi_offset_y0, spatial_scale, offset); // roi_offset_y0 * spatial_scale + offset; - const scalar_t roi_end_w = - fma(roi_offset_x1, spatial_scale, offset); // (roi_offset_x1) * spatial_scale - offset; - const scalar_t roi_end_h = - fma(roi_offset_y1, spatial_scale, offset); // (roi_offset_y1)*spatial_scale - offset; - - if (pool_mode == 0) { - const scalar_t output_val = roi_align_single( - bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale, - pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); - output[index] = output_val; - } else { - const scalar_t output_val = roi_align_single( - bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale, - pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); - output[index] = output_val; + const int target_lvls = + min(feat_data.num_featmap - 1, + max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); + + if (roi_scale_factor > 0.) 
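Annotation: the target_lvls expression just above is the FPN RoI-to-level rule used by mmdetection's SingleRoIExtractor: a box of size w x h is routed to level floor(log2(sqrt(w * h) / finest_scale + 1e-6)), clamped to [0, num_featmap - 1]. Worked example: with finest_scale = 56 and a 224 x 112 RoI, sqrt(224 * 112) ≈ 158.4, 158.4 / 56 ≈ 2.83, log2(2.83) ≈ 1.5, and the floor gives 1, so the box is pooled from level 1 (stride 8 when featmap_strides is [4, 8, 16, 32]). The 1e-6 term only guards log2f against a zero-area box.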
+        {
+            const scalar_t roi_off_cx = (roi_offset_x0 + roi_offset_x1) * 0.5;
+            const scalar_t roi_off_cy = (roi_offset_y0 + roi_offset_y1) * 0.5;
+            const scalar_t half_scale_factor = roi_scale_factor * 0.5;
+            const scalar_t half_roi_off_w =
+                fma(roi_offset_x1 - roi_offset_x0 + 1, half_scale_factor, scalar_t(-0.5));
+            const scalar_t half_roi_off_h =
+                fma(roi_offset_y1 - roi_offset_y0 + 1, half_scale_factor, scalar_t(-0.5));
+
+            roi_offset_x0 = roi_off_cx - half_roi_off_w;
+            roi_offset_x1 = roi_off_cx + half_roi_off_w;
+            roi_offset_y0 = roi_off_cy - half_roi_off_h;
+            roi_offset_y1 = roi_off_cy + half_roi_off_h;
+        }
+
+        const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls];
+        const int height = feat_data.h[target_lvls];
+        const int width = feat_data.w[target_lvls];
+        const scalar_t* bottom_data = (scalar_t*)feat_data.data[target_lvls];
+
+        const int roi_batch_ind = offset_bottom_rois[0];
+        const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0;
+        const scalar_t roi_start_w =
+            fma(roi_offset_x0, spatial_scale, offset);  // roi_offset_x0 * spatial_scale + offset;
+        const scalar_t roi_start_h =
+            fma(roi_offset_y0, spatial_scale, offset);  // roi_offset_y0 * spatial_scale + offset;
+        const scalar_t roi_end_w =
+            fma(roi_offset_x1, spatial_scale, offset);  // (roi_offset_x1) * spatial_scale - offset;
+        const scalar_t roi_end_h =
+            fma(roi_offset_y1, spatial_scale, offset);  // (roi_offset_y1)*spatial_scale - offset;
+
+        if (pool_mode == 0)
+        {
+            const scalar_t output_val = roi_align_single<scalar_t, 0, aligned>(bottom_data,
+                                                                               roi_batch_ind,
+                                                                               roi_start_w,
+                                                                               roi_start_h,
+                                                                               roi_end_w,
+                                                                               roi_end_h,
+                                                                               spatial_scale,
+                                                                               pw,
+                                                                               ph,
+                                                                               c,
+                                                                               sample_num,
+                                                                               channels,
+                                                                               height,
+                                                                               width,
+                                                                               pooled_height,
+                                                                               pooled_width);
+            output[index] = output_val;
+        }
+        else
+        {
+            const scalar_t output_val = roi_align_single<scalar_t, 1, aligned>(bottom_data,
+                                                                               roi_batch_ind,
+                                                                               roi_start_w,
+                                                                               roi_start_h,
+                                                                               roi_end_w,
+                                                                               roi_end_h,
+                                                                               spatial_scale,
+                                                                               pw,
+                                                                               ph,
+                                                                               c,
+                                                                               sample_num,
+                                                                               channels,
+                                                                               height,
+                                                                               width,
+                                                                               pooled_height,
+                                                                               pooled_width);
+            output[index] = output_val;
+        }
+    }
 }
-template <typename T>
-void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
-                           int num_feats, int n, int c, int *h, int *w, float *strides,
-                           int aligned_height, int aligned_width, int pool_mode, int sample_num,
-                           float roi_scale_factor, int finest_scale, bool aligned,
-                           cudaStream_t stream) {
-  FeatData feat_data;
-  feat_data.batch_size = n;
-  feat_data.channels = c;
-  feat_data.num_featmap = num_feats;
-  for (int i = 0; i < num_feats; ++i) {
-    feat_data.data[i] = feats[i];
-    feat_data.h[i] = h[i];
-    feat_data.w[i] = w[i];
-    feat_data.spatial_scale[i] = 1. / float(strides[i]);
-  }
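Annotation: CUDA_1D_KERNEL_LOOP above and the GET_BLOCKS/THREADS_PER_BLOCK pair at the launch sites below come from the shared CUDA helper header. A sketch of the usual mmcv-style definitions this code relies on; the 512-thread block and the 4096-block cap are the conventional values and should be treated as assumptions here:

#define THREADS_PER_BLOCK 512

// Grid-stride loop: thread k covers indices k, k + gridDim.x * blockDim.x, ...
#define CUDA_1D_KERNEL_LOOP(i, n)                                \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
         i += blockDim.x * gridDim.x)

inline int GET_BLOCKS(const int N)
{
    int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    int max_block_num = 4096;
    return optimal_block_num < max_block_num ? optimal_block_num : max_block_num;
}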
-  int nThreads = num_rois * c * aligned_height * aligned_width;
-  if (aligned) {
-    roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  } else {
-    roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  }
+template <typename T>
+void multi_level_roi_align(T* output,
+                           const T* rois,
+                           int num_rois,
+                           const void* const* feats,
+                           int num_feats,
+                           int n,
+                           int c,
+                           int* h,
+                           int* w,
+                           float* strides,
+                           int aligned_height,
+                           int aligned_width,
+                           int pool_mode,
+                           int sample_num,
+                           float roi_scale_factor,
+                           int finest_scale,
+                           bool aligned,
+                           cudaStream_t stream)
+{
+    FeatData feat_data;
+    feat_data.batch_size = n;
+    feat_data.channels = c;
+    feat_data.num_featmap = num_feats;
+    for (int i = 0; i < num_feats; ++i)
+    {
+        feat_data.data[i] = feats[i];
+        feat_data.h[i] = h[i];
+        feat_data.w[i] = w[i];
+        feat_data.spatial_scale[i] = 1. / float(strides[i]);
+    }
+    int nThreads = num_rois * c * aligned_height * aligned_width;
+    if (aligned)
+    {
+        roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(output,
+                                                                                              rois,
+                                                                                              feat_data,
+                                                                                              pool_mode,
+                                                                                              sample_num,
+                                                                                              roi_scale_factor,
+                                                                                              finest_scale,
+                                                                                              aligned_height,
+                                                                                              aligned_width,
+                                                                                              nThreads);
+    }
+    else
+    {
+        roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(output,
+                                                                                               rois,
+                                                                                               feat_data,
+                                                                                               pool_mode,
+                                                                                               sample_num,
+                                                                                               roi_scale_factor,
+                                                                                               finest_scale,
+                                                                                               aligned_height,
+                                                                                               aligned_width,
+                                                                                               nThreads);
+    }
 }
-template void multi_level_roi_align<float>(float *output, const float *rois, int num_rois,
-                                           const void *const *feats, int num_feats, int n, int c,
-                                           int *h, int *w, float *strides, int aligned_height,
-                                           int aligned_width, int pool_mode, int sample_num,
-                                           float roi_scale_factor, int finest_scale, bool aligned,
-                                           cudaStream_t stream);
+template void multi_level_roi_align<float>(float* output,
+                                           const float* rois,
+                                           int num_rois,
+                                           const void* const* feats,
+                                           int num_feats,
+                                           int n,
+                                           int c,
+                                           int* h,
+                                           int* w,
+                                           float* strides,
+                                           int aligned_height,
+                                           int aligned_width,
+                                           int pool_mode,
+                                           int sample_num,
+                                           float roi_scale_factor,
+                                           int finest_scale,
+                                           bool aligned,
+                                           cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
index 5f7220dbf0..efd5564a27 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
@@ -3,11 +3,7 @@
 #define TRT_MULTI_LEVEL_ROI_ALIGN_KERNEL_HPP
 #include <cuda_runtime.h>
-template <typename T>
-void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
-                           int num_feats, int n, int c, int *h, int *w, float *strides,
-                           int aligned_height, int aligned_width, int pool_mode, int sample_num,
-                           float roi_scale_factor, int finest_scale, bool aligned,
-                           cudaStream_t stream);
+template <typename T>
+void multi_level_roi_align(T* output, const T* rois, int num_rois, const void* const* feats, int num_feats, int n, int c, int* h, int* w, float* strides, int aligned_height, int aligned_width, int pool_mode, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream);
 #endif  // TRT_MULTI_LEVEL_ROI_ALIGN_KERNEL_HPP
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp
b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp index 6637603128..ec3c282ffe 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp @@ -9,220 +9,309 @@ #include "trt_multi_level_rotated_roi_align_kernel.hpp" #include "trt_plugin_helper.hpp" #include "trt_serialize.hpp" -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"MMCVMultiLevelRotatedRoiAlign"}; -} // namespace - -TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign( - const std::string &name, int alignedHeight, int alignedWidth, int clockwise, int sampleNum, - const std::vector &featmapStrides, float roiScaleFactor, int finestScale, bool aligned) - : TRTPluginBase(name), - mAlignedHeight(alignedHeight), - mAlignedWidth(alignedWidth), - mClockwise(clockwise), - mSampleNum(sampleNum), - mFeatmapStrides(featmapStrides), - mRoiScaleFactor(roiScaleFactor), - mFinestScale(finestScale), - mAligned(aligned) {} - -TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(const std::string name, const void *data, - size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mAlignedHeight); - deserialize_value(&data, &length, &mAlignedWidth); - deserialize_value(&data, &length, &mClockwise); - deserialize_value(&data, &length, &mSampleNum); - deserialize_value(&data, &length, &mRoiScaleFactor); - deserialize_value(&data, &length, &mFinestScale); - deserialize_value(&data, &length, &mAligned); - deserialize_value(&data, &length, &mFeatmapStrides); -} - -nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRotatedRoiAlign::clone() const TRT_NOEXCEPT { - TRTMultiLevelRotatedRoiAlign *plugin = new TRTMultiLevelRotatedRoiAlign( - mLayerName, mAlignedHeight, mAlignedWidth, mClockwise, mSampleNum, mFeatmapStrides, - mRoiScaleFactor, mFinestScale, mAligned); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTMultiLevelRotatedRoiAlign::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - // warning, nbInputs should equal to mFeatmapStrides.size() + 1 - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[1].d[1]; - ret.d[2] = exprBuilder.constant(mAlignedHeight); - ret.d[3] = exprBuilder.constant(mAlignedWidth); - - return ret; -} - -bool TRTMultiLevelRotatedRoiAlign::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -void TRTMultiLevelRotatedRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments - ASSERT(nbOutputs == 1); - ASSERT(nbInputs >= 1); - mFeatmapStrides = - std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + nbInputs - 1); -} - -size_t TRTMultiLevelRotatedRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTMultiLevelRotatedRoiAlign::enqueue(const nvinfer1::PluginTensorDesc 
*inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - int num_rois = inputDesc[0].dims.d[0]; - int batch_size = inputDesc[1].dims.d[0]; - int channels = inputDesc[1].dims.d[1]; - - const int kMaxFeatMap = 10; - int heights[kMaxFeatMap]; - int widths[kMaxFeatMap]; - float strides[kMaxFeatMap]; - - int num_feats = mFeatmapStrides.size(); - for (int i = 0; i < num_feats; ++i) { - heights[i] = inputDesc[i + 1].dims.d[2]; - widths[i] = inputDesc[i + 1].dims.d[3]; - strides[i] = mFeatmapStrides[i]; - } - - const void *rois = inputs[0]; - const void *const *feats = inputs + 1; - - multi_level_rotated_roi_align((float *)outputs[0], (const float *)rois, num_rois, feats, - num_feats, batch_size, channels, &heights[0], &widths[0], - &strides[0], mAlignedHeight, mAlignedWidth, mClockwise, - mSampleNum, mRoiScaleFactor, mFinestScale, mAligned, stream); - - return 0; -} - -nvinfer1::DataType TRTMultiLevelRotatedRoiAlign::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return nvinfer1::DataType::kFLOAT; -} - -// IPluginV2 Methods -const char *TRTMultiLevelRotatedRoiAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTMultiLevelRotatedRoiAlign::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int TRTMultiLevelRotatedRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTMultiLevelRotatedRoiAlign::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + - serialized_size(mAlignedWidth) + serialized_size(mClockwise) + - serialized_size(mSampleNum) + serialized_size(mRoiScaleFactor) + - serialized_size(mFinestScale) + serialized_size(mAligned); -} - -void TRTMultiLevelRotatedRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mAlignedHeight); - serialize_value(&buffer, mAlignedWidth); - serialize_value(&buffer, mClockwise); - serialize_value(&buffer, mSampleNum); - serialize_value(&buffer, mRoiScaleFactor); - serialize_value(&buffer, mFinestScale); - serialize_value(&buffer, mAligned); - serialize_value(&buffer, mFeatmapStrides); -} - -TRTMultiLevelRotatedRoiAlignCreator::TRTMultiLevelRotatedRoiAlignCreator() { - mPluginAttributes = std::vector( - {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), - nvinfer1::PluginField("clockwise"), nvinfer1::PluginField("sampling_ratio"), - nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), - nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int alignedHeight = 7; - int alignedWidth = 7; - int clockwise = 0; - int sampleNum = 2; - std::vector featmapStrides; - float roiScaleFactor = -1; - int finestScale = 56; - bool aligned = false; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if 
(field_name.compare("output_height") == 0) { - alignedHeight = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("output_width") == 0) { - alignedWidth = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("clockwise") == 0) { - clockwise = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("sampling_ratio") == 0) { - sampleNum = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("roi_scale_factor") == 0) { - roiScaleFactor = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("finest_scale") == 0) { - finestScale = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("featmap_strides") == 0) { - int data_size = (fc->fields[i].length); - const float *data_start = static_cast(fc->fields[i].data); - featmapStrides = std::vector(data_start, data_start + data_size); - } else if (field_name.compare("aligned") == 0) { - int aligned_int = static_cast(fc->fields[i].data)[0]; - aligned = aligned_int != 0; - } - } - - ASSERT(featmapStrides.size() != 0); - - TRTMultiLevelRotatedRoiAlign *plugin = - new TRTMultiLevelRotatedRoiAlign(name, alignedHeight, alignedWidth, clockwise, sampleNum, - featmapStrides, roiScaleFactor, finestScale, aligned); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new TRTMultiLevelRotatedRoiAlign(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRotatedRoiAlignCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVMultiLevelRotatedRoiAlign"}; + } // namespace + + TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(const std::string& name, + int alignedHeight, + int alignedWidth, + int clockwise, + int sampleNum, + const std::vector& featmapStrides, + float roiScaleFactor, + int finestScale, + bool aligned) + : TRTPluginBase(name) + , mAlignedHeight(alignedHeight) + , mAlignedWidth(alignedWidth) + , mClockwise(clockwise) + , mSampleNum(sampleNum) + , mFeatmapStrides(featmapStrides) + , mRoiScaleFactor(roiScaleFactor) + , mFinestScale(finestScale) + , mAligned(aligned) + { + } + + TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(const std::string name, + const void* data, + size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mAlignedHeight); + deserialize_value(&data, &length, &mAlignedWidth); + deserialize_value(&data, &length, &mClockwise); + deserialize_value(&data, &length, &mSampleNum); + deserialize_value(&data, &length, &mRoiScaleFactor); + deserialize_value(&data, &length, &mFinestScale); + deserialize_value(&data, &length, &mAligned); + deserialize_value(&data, &length, &mFeatmapStrides); + } + + nvinfer1::IPluginV2DynamicExt* TRTMultiLevelRotatedRoiAlign::clone() const TRT_NOEXCEPT + { + TRTMultiLevelRotatedRoiAlign* plugin = new TRTMultiLevelRotatedRoiAlign(mLayerName, + mAlignedHeight, + mAlignedWidth, + mClockwise, + mSampleNum, + mFeatmapStrides, + mRoiScaleFactor, + mFinestScale, + mAligned); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTMultiLevelRotatedRoiAlign::getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT 
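Annotation on the shape logic above: getOutputDimensions returns [num_rois, channels, mAlignedHeight, mAlignedWidth], taking the RoI count from input 0 and the channel count from the first feature map (input 1). The "nbInputs should equal mFeatmapStrides.size() + 1" comment reflects that the op receives the rois tensor plus one feature-map tensor per pyramid level, which is also why configurePlugin trims mFeatmapStrides to nbInputs - 1 entries.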
+ { + // warning, nbInputs should equal to mFeatmapStrides.size() + 1 + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[1].d[1]; + ret.d[2] = exprBuilder.constant(mAlignedHeight); + ret.d[3] = exprBuilder.constant(mAlignedWidth); + + return ret; + } + + bool TRTMultiLevelRotatedRoiAlign::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + + void TRTMultiLevelRotatedRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + ASSERT(nbOutputs == 1); + ASSERT(nbInputs >= 1); + mFeatmapStrides = + std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + nbInputs - 1); + } + + size_t TRTMultiLevelRotatedRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTMultiLevelRotatedRoiAlign::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int num_rois = inputDesc[0].dims.d[0]; + int batch_size = inputDesc[1].dims.d[0]; + int channels = inputDesc[1].dims.d[1]; + + const int kMaxFeatMap = 10; + int heights[kMaxFeatMap]; + int widths[kMaxFeatMap]; + float strides[kMaxFeatMap]; + + int num_feats = mFeatmapStrides.size(); + for (int i = 0; i < num_feats; ++i) + { + heights[i] = inputDesc[i + 1].dims.d[2]; + widths[i] = inputDesc[i + 1].dims.d[3]; + strides[i] = mFeatmapStrides[i]; + } + + const void* rois = inputs[0]; + const void* const* feats = inputs + 1; + + multi_level_rotated_roi_align((float*)outputs[0], + (const float*)rois, + num_rois, + feats, + num_feats, + batch_size, + channels, + &heights[0], + &widths[0], + &strides[0], + mAlignedHeight, + mAlignedWidth, + mClockwise, + mSampleNum, + mRoiScaleFactor, + mFinestScale, + mAligned, + stream); + + return 0; + } + + nvinfer1::DataType TRTMultiLevelRotatedRoiAlign::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return nvinfer1::DataType::kFLOAT; + } + + // IPluginV2 Methods + const char* TRTMultiLevelRotatedRoiAlign::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRotatedRoiAlign::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTMultiLevelRotatedRoiAlign::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTMultiLevelRotatedRoiAlign::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + + serialized_size(mAlignedWidth) + serialized_size(mClockwise) + + serialized_size(mSampleNum) + serialized_size(mRoiScaleFactor) + + serialized_size(mFinestScale) + serialized_size(mAligned); + } + + void TRTMultiLevelRotatedRoiAlign::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mAlignedHeight); + serialize_value(&buffer, mAlignedWidth); + serialize_value(&buffer, mClockwise); + serialize_value(&buffer, mSampleNum); + serialize_value(&buffer, mRoiScaleFactor); + serialize_value(&buffer, mFinestScale); + 
serialize_value(&buffer, mAligned); + serialize_value(&buffer, mFeatmapStrides); + } + + TRTMultiLevelRotatedRoiAlignCreator::TRTMultiLevelRotatedRoiAlignCreator() + { + mPluginAttributes = std::vector({nvinfer1::PluginField("output_height"), + nvinfer1::PluginField("output_width"), + nvinfer1::PluginField("clockwise"), + nvinfer1::PluginField("sampling_ratio"), + nvinfer1::PluginField("featmap_strides"), + nvinfer1::PluginField("roi_scale_factor"), + nvinfer1::PluginField("finest_scale"), + nvinfer1::PluginField("aligned")}); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTMultiLevelRotatedRoiAlignCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRotatedRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTMultiLevelRotatedRoiAlignCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int alignedHeight = 7; + int alignedWidth = 7; + int clockwise = 0; + int sampleNum = 2; + std::vector featmapStrides; + float roiScaleFactor = -1; + int finestScale = 56; + bool aligned = false; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("output_height") == 0) + { + alignedHeight = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("output_width") == 0) + { + alignedWidth = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("clockwise") == 0) + { + clockwise = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("sampling_ratio") == 0) + { + sampleNum = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("roi_scale_factor") == 0) + { + roiScaleFactor = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("finest_scale") == 0) + { + finestScale = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("featmap_strides") == 0) + { + int data_size = (fc->fields[i].length); + const float* data_start = static_cast(fc->fields[i].data); + featmapStrides = std::vector(data_start, data_start + data_size); + } + else if (field_name.compare("aligned") == 0) + { + int aligned_int = static_cast(fc->fields[i].data)[0]; + aligned = aligned_int != 0; + } + } + + ASSERT(featmapStrides.size() != 0); + + TRTMultiLevelRotatedRoiAlign* plugin = new TRTMultiLevelRotatedRoiAlign(name, + alignedHeight, + alignedWidth, + clockwise, + sampleNum, + featmapStrides, + roiScaleFactor, + finestScale, + aligned); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTMultiLevelRotatedRoiAlignCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new TRTMultiLevelRotatedRoiAlign(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRotatedRoiAlignCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp index cf0bab7584..906a429f6e 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp +++ 
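Annotation: the write order in serialize() above must mirror, byte for byte, the read order in the deserializing constructor (alignedHeight, alignedWidth, clockwise, sampleNum, roiScaleFactor, finestScale, aligned, then the strides vector). The helpers in trt_serialize.hpp follow the usual cursor-advancing pattern; a minimal sketch for trivially copyable types (mmdeploy's real implementation also handles std::vector, so this is illustrative only):

#include <cassert>
#include <cstring>

template <typename T>
void serialize_value(void** buffer, const T& value)
{
    // Copy the raw bytes, then advance the write cursor.
    std::memcpy(*buffer, &value, sizeof(T));
    *buffer = static_cast<char*>(*buffer) + sizeof(T);
}

template <typename T>
void deserialize_value(const void** data, size_t* length, T* value)
{
    // Copy the raw bytes back out, advance the read cursor, shrink the budget.
    assert(*length >= sizeof(T));
    std::memcpy(value, *data, sizeof(T));
    *data = static_cast<const char*>(*data) + sizeof(T);
    *length -= sizeof(T);
}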
b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp @@ -10,70 +10,95 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTMultiLevelRotatedRoiAlign : public TRTPluginBase { - public: - TRTMultiLevelRotatedRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, - int clockwise, int sampleNum, - const std::vector &featmapStrides, float roiScaleFactor = -1, - int finestScale = 56, bool aligned = false); - - TRTMultiLevelRotatedRoiAlign(const std::string name, const void *data, size_t length); - - TRTMultiLevelRotatedRoiAlign() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - - private: - int mAlignedHeight; - int mAlignedWidth; - int mClockwise; - int mSampleNum; - std::vector mFeatmapStrides; - float mRoiScaleFactor; - int mFinestScale; - bool mAligned; -}; - -class TRTMultiLevelRotatedRoiAlignCreator : public TRTPluginCreatorBase { - public: - TRTMultiLevelRotatedRoiAlignCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class TRTMultiLevelRotatedRoiAlign : public TRTPluginBase + { + public: + TRTMultiLevelRotatedRoiAlign(const std::string& name, + int alignedHeight, + int alignedWidth, + int clockwise, + int sampleNum, + const std::vector& featmapStrides, + float roiScaleFactor = -1, + int finestScale = 56, + bool aligned = false); + + TRTMultiLevelRotatedRoiAlign(const std::string name, + const void* data, + size_t length); + + TRTMultiLevelRotatedRoiAlign() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool 
supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + private: + int mAlignedHeight; + int mAlignedWidth; + int mClockwise; + int mSampleNum; + std::vector mFeatmapStrides; + float mRoiScaleFactor; + int mFinestScale; + bool mAligned; + }; + + class TRTMultiLevelRotatedRoiAlignCreator : public TRTPluginCreatorBase + { + public: + TRTMultiLevelRotatedRoiAlignCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu index 1c6f292bae..3b09215547 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu @@ -10,155 +10,236 @@ #include "trt_plugin_helper.hpp" const int kMAX_FEATMAP_SIZE = 10; -struct FeatData { - const void *data[kMAX_FEATMAP_SIZE]; - int batch_size; - int channels; - int h[kMAX_FEATMAP_SIZE]; - int w[kMAX_FEATMAP_SIZE]; - float spatial_scale[kMAX_FEATMAP_SIZE]; - int num_featmap; +struct FeatData +{ + const void* data[kMAX_FEATMAP_SIZE]; + int batch_size; + int channels; + int h[kMAX_FEATMAP_SIZE]; + int w[kMAX_FEATMAP_SIZE]; + float spatial_scale[kMAX_FEATMAP_SIZE]; + int num_featmap; }; -template -__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data, - const int roi_batch_ind, scalar_t roi_center_w, - scalar_t roi_center_h, scalar_t roi_width, scalar_t roi_height, - scalar_t theta, const scalar_t spatial_scale, const int pw, - const int ph, const int c, const int sample_num, - const int channels, const int height, const int width, - const int pooled_height, const int pooled_width) { - // Force malformed ROIs to be 1x1 - - roi_width = max(roi_width, (scalar_t)1.); - roi_height = max(roi_height, (scalar_t)1.); - - const 
scalar_t bin_size_h = roi_height / scalar_t(pooled_height); - const scalar_t bin_size_w = roi_width / scalar_t(pooled_width); - - const scalar_t *offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; - - const int roi_bin_grid_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); - const int roi_bin_grid_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); - - const scalar_t roi_start_h = -roi_height / scalar_t(2.0); - const scalar_t roi_start_w = -roi_width / scalar_t(2.0); - const scalar_t cosscalar_theta = cos(theta); - const scalar_t sinscalar_theta = sin(theta); - - // We do average (integral) pooling inside a bin - const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 - - scalar_t output_val = 0.; - - for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 - const scalar_t yy = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const scalar_t xx = - roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - - // Rotate by theta (counterclockwise) around the center and translate - scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h; - scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w; - - scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); - output_val += val; +template +__device__ scalar_t roi_align_single(const scalar_t* __restrict__ bottom_data, + const int roi_batch_ind, + scalar_t roi_center_w, + scalar_t roi_center_h, + scalar_t roi_width, + scalar_t roi_height, + scalar_t theta, + const scalar_t spatial_scale, + const int pw, + const int ph, + const int c, + const int sample_num, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width) +{ + // Force malformed ROIs to be 1x1 + + roi_width = max(roi_width, (scalar_t)1.); + roi_height = max(roi_height, (scalar_t)1.); + + const scalar_t bin_size_h = roi_height / scalar_t(pooled_height); + const scalar_t bin_size_w = roi_width / scalar_t(pooled_width); + + const scalar_t* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + + const int roi_bin_grid_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); + const int roi_bin_grid_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); + + const scalar_t roi_start_h = -roi_height / scalar_t(2.0); + const scalar_t roi_start_w = -roi_width / scalar_t(2.0); + const scalar_t cosscalar_theta = cos(theta); + const scalar_t sinscalar_theta = sin(theta); + + // We do average (integral) pooling inside a bin + const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + scalar_t output_val = 0.; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) + { // e.g., iy = 0, 1 + const scalar_t yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + const scalar_t xx = + roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + // Rotate by theta (counterclockwise) around the center and translate + scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h; + scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w; + + scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); + output_val += val; + } } - } - return output_val / count; + return output_val / count; } -template -__global__ void rotated_roi_extractor_kernel(scalar_t *__restrict__ output, - const scalar_t *__restrict__ bottom_rois, - FeatData feat_data, const int clockwise, - const int sample_num, const float roi_scale_factor, - const int finest_scale, const int pooled_height, - const int pooled_width, int nThreads) { - CUDA_1D_KERNEL_LOOP(index, nThreads) { - const int channels = feat_data.channels; - int tmp_index = index; - const int pw = tmp_index % pooled_width; - tmp_index /= pooled_width; - const int ph = tmp_index % pooled_height; - tmp_index /= pooled_height; - const int c = tmp_index % channels; - const int n = tmp_index / channels; - - const scalar_t *offset_bottom_rois = bottom_rois + n * 6; - - scalar_t roi_offset_x0 = offset_bottom_rois[1]; - scalar_t roi_offset_y0 = offset_bottom_rois[2]; - scalar_t roi_offset_width = offset_bottom_rois[3]; - scalar_t roi_offset_height = offset_bottom_rois[4]; - scalar_t theta = offset_bottom_rois[5]; - - const scalar_t scale = sqrtf(roi_offset_width * roi_offset_height); - - const int target_lvls = - min(feat_data.num_featmap - 1, - max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); - - if (roi_scale_factor > 0.) { - roi_offset_width = roi_offset_width * roi_scale_factor; - roi_offset_height = roi_offset_height * roi_scale_factor; +template +__global__ void rotated_roi_extractor_kernel(scalar_t* __restrict__ output, + const scalar_t* __restrict__ bottom_rois, + FeatData feat_data, + const int clockwise, + const int sample_num, + const float roi_scale_factor, + const int finest_scale, + const int pooled_height, + const int pooled_width, + int nThreads) +{ + CUDA_1D_KERNEL_LOOP(index, nThreads) + { + const int channels = feat_data.channels; + int tmp_index = index; + const int pw = tmp_index % pooled_width; + tmp_index /= pooled_width; + const int ph = tmp_index % pooled_height; + tmp_index /= pooled_height; + const int c = tmp_index % channels; + const int n = tmp_index / channels; + + const scalar_t* offset_bottom_rois = bottom_rois + n * 6; + + scalar_t roi_offset_x0 = offset_bottom_rois[1]; + scalar_t roi_offset_y0 = offset_bottom_rois[2]; + scalar_t roi_offset_width = offset_bottom_rois[3]; + scalar_t roi_offset_height = offset_bottom_rois[4]; + scalar_t theta = offset_bottom_rois[5]; + + const scalar_t scale = sqrtf(roi_offset_width * roi_offset_height); + + const int target_lvls = + min(feat_data.num_featmap - 1, + max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); + + if (roi_scale_factor > 0.) 
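Annotation: the coordinate transform inside the sampling loop above is the heart of the rotated variant. Each sub-bin grid point (xx, yy) is expressed relative to the box centre, rotated counter-clockwise by theta, then translated into feature-map coordinates; the clockwise flag in the extractor kernel below simply negates theta before this step. The same transform restated host-side (names are mine, not from the kernel):

#include <cmath>

// Map a grid point (xx, yy), given relative to the RoI centre, to feature-map
// coordinates (x, y) after a counter-clockwise rotation by theta.
inline void rotate_sample_point(float xx, float yy, float theta, float center_x, float center_y, float& x, float& y)
{
    const float cos_t = std::cos(theta);
    const float sin_t = std::sin(theta);
    y = yy * cos_t - xx * sin_t + center_y;
    x = yy * sin_t + xx * cos_t + center_x;
}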
+ { + roi_offset_width = roi_offset_width * roi_scale_factor; + roi_offset_height = roi_offset_height * roi_scale_factor; + } + + const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls]; + const int height = feat_data.h[target_lvls]; + const int width = feat_data.w[target_lvls]; + const scalar_t* bottom_data = (scalar_t*)feat_data.data[target_lvls]; + + const int roi_batch_ind = offset_bottom_rois[0]; + const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0; + const scalar_t roi_center_w = fma(roi_offset_x0, spatial_scale, offset); + const scalar_t roi_center_h = fma(roi_offset_y0, spatial_scale, offset); + const scalar_t roi_width = roi_offset_width * spatial_scale; + const scalar_t roi_height = roi_offset_height * spatial_scale; + + theta = clockwise > 0 ? -theta : theta; + + const scalar_t output_val = roi_align_single(bottom_data, + roi_batch_ind, + roi_center_w, + roi_center_h, + roi_width, + roi_height, + theta, + spatial_scale, + pw, + ph, + c, + sample_num, + channels, + height, + width, + pooled_height, + pooled_width); + output[index] = output_val; } - - const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls]; - const int height = feat_data.h[target_lvls]; - const int width = feat_data.w[target_lvls]; - const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls]; - - const int roi_batch_ind = offset_bottom_rois[0]; - const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0; - const scalar_t roi_center_w = fma(roi_offset_x0, spatial_scale, offset); - const scalar_t roi_center_h = fma(roi_offset_y0, spatial_scale, offset); - const scalar_t roi_width = roi_offset_width * spatial_scale; - const scalar_t roi_height = roi_offset_height * spatial_scale; - - theta = clockwise > 0 ? -theta : theta; - - const scalar_t output_val = roi_align_single( - bottom_data, roi_batch_ind, roi_center_w, roi_center_h, roi_width, roi_height, theta, - spatial_scale, pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); - output[index] = output_val; - } } -template -void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois, const void *const *feats, - int num_feats, int n, int c, int *h, int *w, float *strides, - int aligned_height, int aligned_width, int clockwise, - int sample_num, float roi_scale_factor, int finest_scale, - bool aligned, cudaStream_t stream) { - FeatData feat_data; - feat_data.batch_size = n; - feat_data.channels = c; - feat_data.num_featmap = num_feats; - for (int i = 0; i < num_feats; ++i) { - feat_data.data[i] = feats[i]; - feat_data.h[i] = h[i]; - feat_data.w[i] = w[i]; - feat_data.spatial_scale[i] = 1. 
/ float(strides[i]);
-  }
-  int nThreads = num_rois * c * aligned_height * aligned_width;
-  if (aligned) {
-    rotated_roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  } else {
-    rotated_roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  }
+template <typename T>
+void multi_level_rotated_roi_align(T* output,
+                                   const T* rois,
+                                   int num_rois,
+                                   const void* const* feats,
+                                   int num_feats,
+                                   int n,
+                                   int c,
+                                   int* h,
+                                   int* w,
+                                   float* strides,
+                                   int aligned_height,
+                                   int aligned_width,
+                                   int clockwise,
+                                   int sample_num,
+                                   float roi_scale_factor,
+                                   int finest_scale,
+                                   bool aligned,
+                                   cudaStream_t stream)
+{
+    FeatData feat_data;
+    feat_data.batch_size = n;
+    feat_data.channels = c;
+    feat_data.num_featmap = num_feats;
+    for (int i = 0; i < num_feats; ++i)
+    {
+        feat_data.data[i] = feats[i];
+        feat_data.h[i] = h[i];
+        feat_data.w[i] = w[i];
+        feat_data.spatial_scale[i] = 1. / float(strides[i]);
+    }
+    int nThreads = num_rois * c * aligned_height * aligned_width;
+    if (aligned)
+    {
+        rotated_roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(output,
+                                                                                                      rois,
+                                                                                                      feat_data,
+                                                                                                      clockwise,
+                                                                                                      sample_num,
+                                                                                                      roi_scale_factor,
+                                                                                                      finest_scale,
+                                                                                                      aligned_height,
+                                                                                                      aligned_width,
+                                                                                                      nThreads);
+    }
+    else
+    {
+        rotated_roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(output,
+                                                                                                       rois,
+                                                                                                       feat_data,
+                                                                                                       clockwise,
+                                                                                                       sample_num,
+                                                                                                       roi_scale_factor,
+                                                                                                       finest_scale,
+                                                                                                       aligned_height,
+                                                                                                       aligned_width,
+                                                                                                       nThreads);
+    }
 }
-template void multi_level_rotated_roi_align<float>(
-    float *output, const float *rois, int num_rois, const void *const *feats, int num_feats, int n,
-    int c, int *h, int *w, float *strides, int aligned_height, int aligned_width, int clockwise,
-    int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream);
+template void multi_level_rotated_roi_align<float>(float* output,
+                                                   const float* rois,
+                                                   int num_rois,
+                                                   const void* const* feats,
+                                                   int num_feats,
+                                                   int n,
+                                                   int c,
+                                                   int* h,
+                                                   int* w,
+                                                   float* strides,
+                                                   int aligned_height,
+                                                   int aligned_width,
+                                                   int clockwise,
+                                                   int sample_num,
+                                                   float roi_scale_factor,
+                                                   int finest_scale,
+                                                   bool aligned,
+                                                   cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
index fc3700df3b..b3f7fc0f94 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
@@ -3,11 +3,24 @@
 #define TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP
 #include <cuda_runtime.h>
-template <typename T>
-void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
-                                   int num_feats, int n, int c, int *h, int *w, float *strides,
-                                   int aligned_height, int aligned_width, int clockwise,
-                                   int sample_num, float roi_scale_factor, int finest_scale,
-                                   bool aligned, cudaStream_t stream);
+template <typename T>
+void multi_level_rotated_roi_align(T* output,
+                                   const T* rois,
+                                   int num_rois,
+                                   const void* const* feats,
+                                   int num_feats,
+                                   int n,
+                                   int c,
+                                   int* h,
+                                   int* w,
+                                   float* strides,
+                                   int aligned_height,
+                                   int aligned_width,
+                                   int clockwise,
+                                   int sample_num,
+                                   float roi_scale_factor,
+                                   int finest_scale,
+                                   bool aligned,
+                                   cudaStream_t stream);
 #endif  //
TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp index d14a25e929..73a3a8e6b9 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp @@ -10,164 +10,226 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"}; -} // namespace - -MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic( - const std::string &name) - : TRTPluginBase(name) {} - -MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name, - const void *data, - size_t length) - : TRTPluginBase(name) {} -MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {} - -nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT { - MultiScaleDeformableAttnPluginDynamic *plugin = - new MultiScaleDeformableAttnPluginDynamic(mLayerName); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs ret; - ret.nbDims = 3; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[3].d[1]; - - ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *inputs[0].d[3]); - - return ret; -} - -bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) { - if ((pos == 1) || (pos == 2)) { - return (ioDesc[pos].type == nvinfer1::DataType::kINT32); - } else { - return ((ioDesc[pos].type == ioDesc[0].type) && - ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) || - (ioDesc[pos].type == nvinfer1::DataType::kHALF))); - } - } else { - return false; - } -} - -void MultiScaleDeformableAttnPluginDynamic::configurePlugin( - const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT {} - -size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize( - const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - int32_t const batch = inputDesc[0].dims.d[0]; - int32_t spatial_size = inputDesc[0].dims.d[1]; - int32_t num_heads = inputDesc[0].dims.d[2]; - int32_t channels = inputDesc[0].dims.d[3]; - int32_t num_levels = inputDesc[1].dims.d[0]; - int32_t num_query = inputDesc[3].dims.d[1]; - int32_t num_point = inputDesc[3].dims.d[4]; - int32_t rc = 0; - if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) { - float const *value = static_cast(inputs[0]); - int32_t const *spatialShapes = static_cast(inputs[1]); - int32_t const *levelStartIndex = static_cast(inputs[2]); - float const 
*samplingLoc = static_cast(inputs[3]); - float const *attnWeight = static_cast(inputs[4]); - float *output = static_cast(outputs[0]); - - rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, - output, batch, spatial_size, num_heads, channels, num_levels, - num_query, num_point, stream); - } else if (inputDesc[0].type == nvinfer1::DataType::kHALF) { - const __half *value = static_cast(inputs[0]); - int32_t const *spatialShapes = static_cast(inputs[1]); - int32_t const *levelStartIndex = static_cast(inputs[2]); - const __half *samplingLoc = static_cast(inputs[3]); - const __half *attnWeight = static_cast(inputs[4]); - __half *output = static_cast<__half *>(outputs[0]); - - rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, - output, batch, spatial_size, num_heads, channels, num_levels, - num_query, num_point, stream); - } - - return rc; -} - -nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { - return 0; -} - -void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {} - -void MultiScaleDeformableAttnPluginDynamic::attachToContext( - cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {} - -void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {} - -////////////////////// creator ///////////////////////////// - -MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() { - mPluginAttributes.clear(); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(name); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"}; + } // namespace + + MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic( + const std::string& name) + : TRTPluginBase(name) + { + } + + 
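Annotation: two details of this plugin are worth noting. In supportsFormatCombination, positions 1 and 2 (spatial_shapes and level_start_index) are pinned to kINT32, while every other tensor must match input 0's type as either kFLOAT or kHALF in linear format; that is what lets one plugin serve both FP32 and FP16 engines. And because the op carries no configured state, getSerializationSize() returns 0 and serialize() plus the deserializing constructor are intentionally empty.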
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name, + const void* data, + size_t length) + : TRTPluginBase(name) + { + } + MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {} + + nvinfer1::IPluginV2DynamicExt* MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT + { + MultiScaleDeformableAttnPluginDynamic* plugin = + new MultiScaleDeformableAttnPluginDynamic(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[3].d[1]; + + ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *inputs[0].d[3]); + + return ret; + } + + bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) + { + if ((pos == 1) || (pos == 2)) + { + return (ioDesc[pos].type == nvinfer1::DataType::kINT32); + } + else + { + return ((ioDesc[pos].type == ioDesc[0].type) && + ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) || + (ioDesc[pos].type == nvinfer1::DataType::kHALF))); + } + } + else + { + return false; + } + } + + void MultiScaleDeformableAttnPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + + size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int32_t const batch = inputDesc[0].dims.d[0]; + int32_t spatial_size = inputDesc[0].dims.d[1]; + int32_t num_heads = inputDesc[0].dims.d[2]; + int32_t channels = inputDesc[0].dims.d[3]; + int32_t num_levels = inputDesc[1].dims.d[0]; + int32_t num_query = inputDesc[3].dims.d[1]; + int32_t num_point = inputDesc[3].dims.d[4]; + int32_t rc = 0; + if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) + { + float const* value = static_cast(inputs[0]); + int32_t const* spatialShapes = static_cast(inputs[1]); + int32_t const* levelStartIndex = static_cast(inputs[2]); + float const* samplingLoc = static_cast(inputs[3]); + float const* attnWeight = static_cast(inputs[4]); + float* output = static_cast(outputs[0]); + + rc = ms_deform_attn_cuda_forward(value, + spatialShapes, + levelStartIndex, + samplingLoc, + attnWeight, + output, + batch, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + stream); + } + else if (inputDesc[0].type == nvinfer1::DataType::kHALF) + { + const __half* value = static_cast(inputs[0]); + int32_t const* spatialShapes = static_cast(inputs[1]); + int32_t const* levelStartIndex = static_cast(inputs[2]); + const __half* samplingLoc = static_cast(inputs[3]); + const __half* attnWeight = static_cast(inputs[4]); + __half* output = static_cast<__half*>(outputs[0]); + 
+ rc = ms_deform_attn_cuda_forward(value, + spatialShapes, + levelStartIndex, + samplingLoc, + attnWeight, + output, + batch, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + stream); + } + + return rc; + } + + nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT + { + return 0; + } + + void MultiScaleDeformableAttnPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {} + + void MultiScaleDeformableAttnPluginDynamic::attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + + void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {} + + ////////////////////// creator ///////////////////////////// + + MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() + { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* MultiScaleDeformableAttnPluginDynamicCreator::createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + MultiScaleDeformableAttnPluginDynamic* plugin = new MultiScaleDeformableAttnPluginDynamic(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp index 7e66e9e54d..62821e27ed 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp @@ -9,62 +9,85 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase { - public: - MultiScaleDeformableAttnPluginDynamic(const std::string &name); - - MultiScaleDeformableAttnPluginDynamic(const std::string name, const void *data, size_t length); - - MultiScaleDeformableAttnPluginDynamic(); - - ~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs 
getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; -}; - -class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase { - public: - MultiScaleDeformableAttnPluginDynamicCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase + { + public: + MultiScaleDeformableAttnPluginDynamic(const std::string& name); + + MultiScaleDeformableAttnPluginDynamic(const std::string name, + const void* data, + size_t length); + + MultiScaleDeformableAttnPluginDynamic(); + + ~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + void attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + 
nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
+
+        void detachFromContext() TRT_NOEXCEPT override;
+
+        // IPluginV2Ext Methods
+        nvinfer1::DataType getOutputDataType(int index,
+                                             const nvinfer1::DataType* inputTypes,
+                                             int nbInputs) const TRT_NOEXCEPT override;
+
+        // IPluginV2 Methods
+        const char* getPluginType() const TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
+        int         getNbOutputs() const TRT_NOEXCEPT override;
+        size_t      getSerializationSize() const TRT_NOEXCEPT override;
+        void        serialize(void* buffer) const TRT_NOEXCEPT override;
+    };
+
+    class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase
+    {
+      public:
+        MultiScaleDeformableAttnPluginDynamicCreator();
+
+        const char* getPluginName() const TRT_NOEXCEPT override;
+
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
+
+        nvinfer1::IPluginV2* createPlugin(const char* name,
+                                          const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override;
+
+        nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                               const void* serialData,
+                                               size_t serialLength) TRT_NOEXCEPT override;
+    };
 }  // namespace mmdeploy

 #endif  // TRT_MS_DEFORM_ATTN_HPP
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
index 6b7588eae0..2d10a1ee9f 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
@@ -7,58 +7,113 @@
 #include "trt_ms_deform_attn_kernel.hpp"
 #include "trt_plugin_helper.hpp"

-template <typename scalar_t>
-void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue,
-                               int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex,
-                               scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight,
-                               int32_t const batchSize, int32_t const spatialSize,
-                               int32_t const numHeads, int32_t const channels,
-                               int32_t const numLevels, int32_t const numQuery,
-                               int32_t const numPoint, scalar_t* dataCol) {
-  int32_t const numKernels = batchSize * numQuery * numHeads * channels;
-  int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;
+template<typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+                               scalar_t const* dataValue,
+                               int32_t const*  dataSpatialShapes,
+                               int32_t const*  dataLevelStartIndex,
+                               scalar_t const* dataSamplingLoc,
+                               scalar_t const* dataAttnWeight,
+                               int32_t const   batchSize,
+                               int32_t const   spatialSize,
+                               int32_t const   numHeads,
+                               int32_t const   channels,
+                               int32_t const   numLevels,
+                               int32_t const   numQuery,
+                               int32_t const   numPoint,
+                               scalar_t*       dataCol)
+{
+    int32_t const numKernels       = batchSize * numQuery * numHeads * channels;
+    int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;

-  ms_deformable_im2col_gpu_kernel<scalar_t>
-      <<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>(
-          numKernels, dataValue, dataSpatialShapes, dataLevelStartIndex, dataSamplingLoc,
-          dataAttnWeight, batchSize, spatialSize, numHeads, channels, numLevels, numQuery, numPoint,
-          dataCol);
+    ms_deformable_im2col_gpu_kernel<scalar_t>
+        <<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>(numKernels,
+                                                                         dataValue,
+                                                                         dataSpatialShapes,
+                                                                         dataLevelStartIndex,
+                                                                         dataSamplingLoc,
+                                                                         dataAttnWeight,
+                                                                         batchSize,
+                                                                         spatialSize,
+                                                                         numHeads,
+                                                                         channels,
+                                                                         numLevels,
+                                                                         numQuery,
+                                                                         numPoint,
+                                                                         dataCol);
 }

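+// Launch-configuration note: one thread handles one (batch, query, head, channel)
+// output element. GET_BLOCKS and THREADS_PER_BLOCK are assumed to be the shared
+// 1-D grid helpers from the common CUDA helper headers, so the grid covers at
+// least numActualKernels threads in total.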
-template <typename scalar_t>
-int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
-                                    const int32_t* levelStartIndex, const scalar_t* samplingLoc,
-                                    const scalar_t* attnWeight, scalar_t* output, int32_t batch,
-                                    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels,
-                                    int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
-                                    cudaStream_t stream) {
-  auto perValueSize = mSpatialSize * mNumHeads * mChannels;
-  auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
-  auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
-  auto perOutputSize = mNumQuery * mNumHeads * mChannels;
+template<typename scalar_t>
+int32_t ms_deform_attn_cuda_forward(const scalar_t* value,
+                                    const int32_t*  spatialShapes,
+                                    const int32_t*  levelStartIndex,
+                                    const scalar_t* samplingLoc,
+                                    const scalar_t* attnWeight,
+                                    scalar_t*       output,
+                                    int32_t         batch,
+                                    int32_t         mSpatialSize,
+                                    int32_t         mNumHeads,
+                                    int32_t         mChannels,
+                                    int32_t         mNumLevels,
+                                    int32_t         mNumQuery,
+                                    int32_t         mNumPoint,
+                                    cudaStream_t    stream)
+{
+    auto perValueSize      = mSpatialSize * mNumHeads * mChannels;
+    auto perSampleLocSize  = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
+    auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
+    auto perOutputSize     = mNumQuery * mNumHeads * mChannels;

-  int32_t mIm2colStep = batch;
+    int32_t mIm2colStep = batch;

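+    // With mIm2colStep equal to the full batch size, the loop below executes
+    // exactly once and processes all samples in a single kernel launch; a
+    // smaller step would split the batch across several launches instead.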
-  for (int32_t n = 0; n < batch / mIm2colStep; ++n) {
-    auto columns = output + n * mIm2colStep * perOutputSize;
-    ms_deformable_im2col_cuda(
-        stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
-        samplingLoc + n * mIm2colStep * perSampleLocSize,
-        attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep, mSpatialSize, mNumHeads,
-        mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
-  }
+    for (int32_t n = 0; n < batch / mIm2colStep; ++n)
+    {
+        auto columns = output + n * mIm2colStep * perOutputSize;
+        ms_deformable_im2col_cuda(stream,
+                                  value + n * mIm2colStep * perValueSize,
+                                  spatialShapes,
+                                  levelStartIndex,
+                                  samplingLoc + n * mIm2colStep * perSampleLocSize,
+                                  attnWeight + n * mIm2colStep * perAttnWeightSize,
+                                  mIm2colStep,
+                                  mSpatialSize,
+                                  mNumHeads,
+                                  mChannels,
+                                  mNumLevels,
+                                  mNumQuery,
+                                  mNumPoint,
+                                  columns);
+    }

-  return 0;
+    return 0;
 }

-template int32_t ms_deform_attn_cuda_forward<float>(
-    const float* value, const int32_t* spatialShapes, const int32_t* levelStartIndex,
-    const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
-    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels,
-    int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream);
+template int32_t ms_deform_attn_cuda_forward<float>(const float* value,
+                                                    const int32_t* spatialShapes,
+                                                    const int32_t* levelStartIndex,
+                                                    const float* samplingLoc,
+                                                    const float* attnWeight,
+                                                    float* output,
+                                                    int32_t batch,
+                                                    int32_t mSpatialSize,
+                                                    int32_t mNumHeads,
+                                                    int32_t mChannels,
+                                                    int32_t mNumLevels,
+                                                    int32_t mNumQuery,
+                                                    int32_t mNumPoint,
+                                                    cudaStream_t stream);

-template int32_t ms_deform_attn_cuda_forward<__half>(
-    const __half* value, const int32_t* spatialShapes, const int32_t* levelStartIndex,
-    const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
-    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels,
-    int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream);
+template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value,
+                                                     const int32_t* spatialShapes,
+                                                     const int32_t* levelStartIndex,
+                                                     const __half* samplingLoc,
+                                                     const __half* attnWeight,
+                                                     __half* output,
+                                                     int32_t batch,
+                                                     int32_t mSpatialSize,
+                                                     int32_t mNumHeads,
+                                                     int32_t mChannels,
+                                                     int32_t mNumLevels,
+                                                     int32_t mNumQuery,
+                                                     int32_t mNumPoint,
+                                                     cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
index cee34cfe65..0bef6ed98c 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
@@ -4,254 +4,323 @@

 #include "common_cuda_helper.hpp"

-template <typename scalar_t>
-__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t*& bottom_data, const int& height,
-                                                   const int& width, const int& nheads,
-                                                   const int& channels, const scalar_t& h,
-                                                   const scalar_t& w, const int& m, const int& c) {
-  const int h_low = floorf(h);
-  const int w_low = floorf(w);
-  const int h_high = h_low + 1;
-  const int w_high = w_low + 1;
+template<typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t*& bottom_data,
+                                                   const int& height,
+                                                   const int& width,
+                                                   const int& nheads,
+                                                   const int& channels,
+                                                   const scalar_t& h,
+                                                   const scalar_t& w,
+                                                   const int& m,
+                                                   const int& c)
+{
+    const int h_low  = floorf(h);
+    const int w_low  = floorf(w);
+    const int h_high = h_low + 1;
+    const int w_high = w_low + 1;

-  const scalar_t lh = h - h_low;
-  const scalar_t lw = w - w_low;
-  const scalar_t hh = 1 - lh, hw = 1 - lw;
+    const scalar_t lh = h - h_low;
+    const scalar_t lw = w - w_low;
+    const scalar_t hh = 1 - lh, hw = 1 - lw;

-  const int w_stride = nheads * channels;
-  const int h_stride = width * w_stride;
-  const int h_low_ptr_offset = h_low * h_stride;
-  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
-  const int w_low_ptr_offset = w_low * w_stride;
-  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
-  const int base_ptr = m * channels + c;
+    const int w_stride          = nheads * channels;
+    const int h_stride          = width * w_stride;
+    const int h_low_ptr_offset  = h_low * h_stride;
+    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+    const int w_low_ptr_offset  = w_low * w_stride;
+    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+    const int base_ptr          = m * channels + c;

-  scalar_t v1 = 0;
-  if (h_low >= 0 && w_low >= 0) {
-    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
-    v1 = bottom_data[ptr1];
-  }
-  scalar_t v2 = 0;
-  if (h_low >= 0 && w_high <= width - 1) {
-    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
-    v2 = bottom_data[ptr2];
-  }
-  scalar_t v3 = 0;
-  if (h_high <= height - 1 && w_low >= 0) {
-    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
-    v3 = bottom_data[ptr3];
-  }
-  scalar_t v4 = 0;
-  if (h_high <= height - 1 && w_high <= width - 1) {
-    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
-    v4 = bottom_data[ptr4];
-  }
+    scalar_t v1 = 0;
+    if (h_low >= 0 && w_low >= 0)
+    {
+        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+        v1             = bottom_data[ptr1];
+    }
+    scalar_t v2 = 0;
+    if (h_low >= 0 && w_high <= width - 1)
+    {
+        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+        v2             = bottom_data[ptr2];
+    }
+    scalar_t v3 = 0;
+    if (h_high <= height - 1 && w_low >= 0)
+    {
+        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+        v3             = bottom_data[ptr3];
+    }
+    scalar_t v4 = 0;
+    if (h_high <= height - 1 && w_high <= width - 1)
+    {
+        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+        v4             = bottom_data[ptr4];
+    }

-  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+    const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh
* hw, w4 = lh * lw; - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; } -template <> -__device__ __half ms_deform_attn_im2col_bilinear<__half>( - const __half*& bottomData, int32_t const& height, int32_t const& width, int32_t const& nHeads, - int32_t const& channels, const __half& h, const __half& w, int32_t const& m, int32_t const& c) { - int32_t const hLow = __half2int_rd(h); - int32_t const wLow = __half2int_rd(w); - int32_t const hHigh = hLow + 1; - int32_t const wHigh = wLow + 1; +template<> +__device__ __half ms_deform_attn_im2col_bilinear<__half>(const __half*& bottomData, + int32_t const& height, + int32_t const& width, + int32_t const& nHeads, + int32_t const& channels, + const __half& h, + const __half& w, + int32_t const& m, + int32_t const& c) +{ + int32_t const hLow = __half2int_rd(h); + int32_t const wLow = __half2int_rd(w); + int32_t const hHigh = hLow + 1; + int32_t const wHigh = wLow + 1; - const __half kZERO = __int2half_rz(0); - const __half one = __int2half_rz(1); + const __half kZERO = __int2half_rz(0); + const __half one = __int2half_rz(1); #if __CUDA_ARCH__ >= 530 - const __half lh = __hsub(h, __int2half_rd(hLow)); - const __half lw = __hsub(w, __int2half_rd(wLow)); - const __half hh = __hsub(one, lh), hw = __hsub(one, lw); + const __half lh = __hsub(h, __int2half_rd(hLow)); + const __half lw = __hsub(w, __int2half_rd(wLow)); + const __half hh = __hsub(one, lh), hw = __hsub(one, lw); #else - const __half lh = __float2half(__half2float(h) - hLow); - const __half lw = __float2half(__half2float(w) - wLow); - const __half hh = __float2half(__half2float(one) - __half2float(lh)); - const __half hw = __float2half(__half2float(one) - __half2float(lw)); + const __half lh = __float2half(__half2float(h) - hLow); + const __half lw = __float2half(__half2float(w) - wLow); + const __half hh = __float2half(__half2float(one) - __half2float(lh)); + const __half hw = __float2half(__half2float(one) - __half2float(lw)); #endif - int32_t const wStride = nHeads * channels; - int32_t const hStride = width * wStride; - int32_t const hLowPtrOffset = hLow * hStride; - int32_t const hHighPtrOffset = hLowPtrOffset + hStride; - int32_t const wLowPtrOffset = wLow * wStride; - int32_t const wHighPtrOffset = wLowPtrOffset + wStride; - int32_t const basePtr = m * channels + c; + int32_t const wStride = nHeads * channels; + int32_t const hStride = width * wStride; + int32_t const hLowPtrOffset = hLow * hStride; + int32_t const hHighPtrOffset = hLowPtrOffset + hStride; + int32_t const wLowPtrOffset = wLow * wStride; + int32_t const wHighPtrOffset = wLowPtrOffset + wStride; + int32_t const basePtr = m * channels + c; - __half v1 = kZERO; - if (hLow >= 0 && wLow >= 0) { - int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; - v1 = bottomData[ptr1]; - } - __half v2 = kZERO; - if (hLow >= 0 && wHigh <= width - 1) { - int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; - v2 = bottomData[ptr2]; - } - __half v3 = kZERO; - if (hHigh <= height - 1 && wLow >= 0) { - int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; - v3 = bottomData[ptr3]; - } - __half v4 = kZERO; - if (hHigh <= height - 1 && wHigh <= width - 1) { - int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; - v4 = bottomData[ptr4]; - } + __half v1 = kZERO; + if (hLow >= 0 && wLow >= 0) + { + int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; + v1 = bottomData[ptr1]; + } + __half v2 = kZERO; + 
if (hLow >= 0 && wHigh <= width - 1)
+    {
+        int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr;
+        v2                 = bottomData[ptr2];
+    }
+    __half v3 = kZERO;
+    if (hHigh <= height - 1 && wLow >= 0)
+    {
+        int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr;
+        v3                 = bottomData[ptr3];
+    }
+    __half v4 = kZERO;
+    if (hHigh <= height - 1 && wHigh <= width - 1)
+    {
+        int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr;
+        v4                 = bottomData[ptr4];
+    }

 #if __CUDA_ARCH__ >= 530
-  __half w1 = __hmul(__hmul(hh, hw), v1);
-  __half w2 = __hmul(__hmul(hh, lw), v2);
-  __half w3 = __hmul(__hmul(lh, hw), v3);
-  __half w4 = __hmul(__hmul(lh, lw), v4);
+    __half w1 = __hmul(__hmul(hh, hw), v1);
+    __half w2 = __hmul(__hmul(hh, lw), v2);
+    __half w3 = __hmul(__hmul(lh, hw), v3);
+    __half w4 = __hmul(__hmul(lh, lw), v4);

-  w1 = __hadd(w1, w2);
-  w3 = __hadd(w3, w4);
+    w1 = __hadd(w1, w2);
+    w3 = __hadd(w3, w4);

-  const __half val = __hadd(w1, w3);
+    const __half val = __hadd(w1, w3);
 #else
-  __half w1 = __float2half((__half2float(hh) * __half2float(hw)) * __half2float(v1));
-  __half w2 = __float2half((__half2float(hh) * __half2float(lw)) * __half2float(v2));
-  __half w3 = __float2half((__half2float(lh) * __half2float(hw)) * __half2float(v3));
-  __half w4 = __float2half((__half2float(lh) * __half2float(lw)) * __half2float(v4));
+    __half w1 = __float2half((__half2float(hh) * __half2float(hw)) * __half2float(v1));
+    __half w2 = __float2half((__half2float(hh) * __half2float(lw)) * __half2float(v2));
+    __half w3 = __float2half((__half2float(lh) * __half2float(hw)) * __half2float(v3));
+    __half w4 = __float2half((__half2float(lh) * __half2float(lw)) * __half2float(v4));

-  w1 = __float2half(__half2float(w1) + __half2float(w2));
-  w3 = __float2half(__half2float(w3) + __half2float(w4));
+    w1 = __float2half(__half2float(w1) + __half2float(w2));
+    w3 = __float2half(__half2float(w3) + __half2float(w4));

-  const __half val = __float2half(__half2float(w1) + __half2float(w3));
+    const __half val = __float2half(__half2float(w1) + __half2float(w3));
 #endif
-  return val;
+    return val;
 }

 #if 1
-template <typename scalar_t>
-__global__ void ms_deformable_im2col_gpu_kernel(
-    int32_t const n, scalar_t const* dataValue, int32_t const* dataSpatialShapes,
-    int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc,
-    scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize,
-    int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery,
-    int32_t const numPoint, scalar_t* dataCol) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    int32_t _temp = index;
-    int32_t const cCol = _temp % channels;
-    _temp /= channels;
-    int32_t const samplingIndex = _temp;
-    int32_t const mCol = _temp % numHeads;
-    _temp /= numHeads;
-    _temp /= numQuery;
-    int32_t const bCol = _temp;
+template<typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(int32_t const   n,
+                                                scalar_t const* dataValue,
+                                                int32_t const*  dataSpatialShapes,
+                                                int32_t const*  dataLevelStartIndex,
+                                                scalar_t const* dataSamplingLoc,
+                                                scalar_t const* dataAttnWeight,
+                                                int32_t const   batchSize,
+                                                int32_t const   spatialSize,
+                                                int32_t const   numHeads,
+                                                int32_t const   channels,
+                                                int32_t const   numLevels,
+                                                int32_t const   numQuery,
+                                                int32_t const   numPoint,
+                                                scalar_t*       dataCol)
+{
+    CUDA_1D_KERNEL_LOOP(index, n)
+    {
+        int32_t       _temp = index;
+        int32_t const cCol  = _temp % channels;
+        _temp /= channels;
+        int32_t const samplingIndex = _temp;
+        int32_t const mCol          = _temp % numHeads;
+        _temp /= numHeads;
+        _temp /= numQuery;
+        int32_t const bCol = _temp;
-
scalar_t* dataColPtr = dataCol + index; - int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; - int32_t dataLocWPtr = dataWeightPtr << 1; - int32_t const qidStride = numHeads * channels; - int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; - scalar_t col = 0; + scalar_t* dataColPtr = dataCol + index; + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; + scalar_t col = 0; - for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - int32_t const levelStartId = dataLevelStartIndex[lCol]; - int32_t const spatialHPtr = lCol << 1; - int32_t const spatialH = dataSpatialShapes[spatialHPtr]; - int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; - scalar_t const* dataValuePtr = - dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); - for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - scalar_t const locW = dataSamplingLoc[dataLocWPtr]; - scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; - scalar_t const weight = dataAttnWeight[dataWeightPtr]; + for (int32_t lCol = 0; lCol < numLevels; ++lCol) + { + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; + scalar_t const* dataValuePtr = + dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); + for (int32_t pCol = 0; pCol < numPoint; ++pCol) + { + scalar_t const locW = dataSamplingLoc[dataLocWPtr]; + scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; + scalar_t const weight = dataAttnWeight[dataWeightPtr]; - scalar_t const hIm = locH * spatialH - 0.5; - scalar_t const wIm = locW * spatialW - 0.5; + scalar_t const hIm = locH * spatialH - 0.5; + scalar_t const wIm = locW * spatialW - 0.5; - if (hIm > -1 && wIm > -1 && hIm < spatialH && wIm < spatialW) { - col += ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, - channels, hIm, wIm, mCol, cCol) * - weight; - } + if (hIm > -1 && wIm > -1 && hIm < spatialH && wIm < spatialW) + { + col += ms_deform_attn_im2col_bilinear(dataValuePtr, + spatialH, + spatialW, + numHeads, + channels, + hIm, + wIm, + mCol, + cCol) * + weight; + } - dataWeightPtr += 1; - dataLocWPtr += 2; - } + dataWeightPtr += 1; + dataLocWPtr += 2; + } + } + *dataColPtr = col; } - *dataColPtr = col; - } } -template <> -__global__ void ms_deformable_im2col_gpu_kernel<__half>( - int32_t const n, const __half* dataValue, int32_t const* dataSpatialShapes, - int32_t const* dataLevelStartIndex, const __half* dataSamplingLoc, const __half* dataAttnWeight, - int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, - int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, - __half* dataCol) { - CUDA_1D_KERNEL_LOOP(index, n) { - int32_t _temp = index; - int32_t const cCol = _temp % channels; - _temp /= channels; - int32_t const samplingIndex = _temp; - int32_t const mCol = _temp % numHeads; - _temp /= numHeads; - _temp /= numQuery; - int32_t const bCol = _temp; +template<> +__global__ void ms_deformable_im2col_gpu_kernel<__half>(int32_t const n, + const __half* dataValue, + int32_t const* dataSpatialShapes, + int32_t const* dataLevelStartIndex, + const __half* dataSamplingLoc, + const __half* dataAttnWeight, + int32_t const batchSize, + int32_t const spatialSize, + 
int32_t const numHeads, + int32_t const channels, + int32_t const numLevels, + int32_t const numQuery, + int32_t const numPoint, + __half* dataCol) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + int32_t _temp = index; + int32_t const cCol = _temp % channels; + _temp /= channels; + int32_t const samplingIndex = _temp; + int32_t const mCol = _temp % numHeads; + _temp /= numHeads; + _temp /= numQuery; + int32_t const bCol = _temp; - __half* dataColPtr = dataCol + index; - int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; - int32_t dataLocWPtr = dataWeightPtr << 1; - int32_t const qidStride = numHeads * channels; - int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; - const __half kZERO_POINT_FIVE = __float2half(0.5f); - const __half kMINUS_ONE = __float2half(-1.0f); - const __half kZERO = __int2half_rz(0); - __half tpVal = kZERO; - __half col = kZERO; + __half* dataColPtr = dataCol + index; + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; + const __half kZERO_POINT_FIVE = __float2half(0.5f); + const __half kMINUS_ONE = __float2half(-1.0f); + const __half kZERO = __int2half_rz(0); + __half tpVal = kZERO; + __half col = kZERO; - for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - int32_t const levelStartId = dataLevelStartIndex[lCol]; - int32_t const spatialHPtr = lCol << 1; - int32_t const spatialH = dataSpatialShapes[spatialHPtr]; - int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; - const __half spatialHHalf = __int2half_rd(spatialH); - const __half spatialWHalf = __int2half_rd(spatialW); - const __half* dataValuePtr = dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); - for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - const __half locW = dataSamplingLoc[dataLocWPtr]; - const __half locH = dataSamplingLoc[dataLocWPtr + 1]; - const __half weight = dataAttnWeight[dataWeightPtr]; -#if __CUDA_ARCH__ >= 530 - const __half hIm = __hsub(__hmul(locH, spatialHHalf), kZERO_POINT_FIVE); - const __half wIm = __hsub(__hmul(locW, spatialWHalf), kZERO_POINT_FIVE); + for (int32_t lCol = 0; lCol < numLevels; ++lCol) + { + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; + const __half spatialHHalf = __int2half_rd(spatialH); + const __half spatialWHalf = __int2half_rd(spatialW); + const __half* dataValuePtr = dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); + for (int32_t pCol = 0; pCol < numPoint; ++pCol) + { + const __half locW = dataSamplingLoc[dataLocWPtr]; + const __half locH = dataSamplingLoc[dataLocWPtr + 1]; + const __half weight = dataAttnWeight[dataWeightPtr]; + #if __CUDA_ARCH__ >= 530 + const __half hIm = __hsub(__hmul(locH, spatialHHalf), kZERO_POINT_FIVE); + const __half wIm = __hsub(__hmul(locW, spatialWHalf), kZERO_POINT_FIVE); - if (__hgt(hIm, kMINUS_ONE) && __hgt(wIm, kMINUS_ONE) && __hlt(hIm, spatialHHalf) && - __hlt(wIm, spatialWHalf)) { - tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, - channels, hIm, wIm, mCol, cCol); - col = __hadd(col, __hmul(tpVal, weight)); - } -#else - const __half hIm = __float2half(__half2float(locH) * __half2float(spatialHHalf) - - __half2float(kZERO_POINT_FIVE)); - const __half wIm = 
__float2half(__half2float(locW) * __half2float(spatialWHalf) -
-                                 __half2float(kZERO_POINT_FIVE));
+                if (__hgt(hIm, kMINUS_ONE) && __hgt(wIm, kMINUS_ONE) && __hlt(hIm, spatialHHalf) &&
+                    __hlt(wIm, spatialWHalf))
+                {
+                    tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr,
+                                                           spatialH,
+                                                           spatialW,
+                                                           numHeads,
+                                                           channels,
+                                                           hIm,
+                                                           wIm,
+                                                           mCol,
+                                                           cCol);
+                    col   = __hadd(col, __hmul(tpVal, weight));
+                }
+    #else
+                const __half hIm = __float2half(__half2float(locH) * __half2float(spatialHHalf) -
+                                                __half2float(kZERO_POINT_FIVE));
+                const __half wIm = __float2half(__half2float(locW) * __half2float(spatialWHalf) -
+                                                __half2float(kZERO_POINT_FIVE));

-      if ((__half2float(hIm) > __half2float(kMINUS_ONE)) &&
-          (__half2float(wIm) > __half2float(kMINUS_ONE)) &&
-          (__half2float(hIm) < __half2float(spatialHHalf)) &&
-          (__half2float(wIm) < __half2float(spatialWHalf))) {
-        tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads,
-                                               channels, hIm, wIm, mCol, cCol);
-        col = __float2half(__half2float(col) + (__half2float(tpVal) * __half2float(weight)));
+                if ((__half2float(hIm) > __half2float(kMINUS_ONE)) &&
+                    (__half2float(wIm) > __half2float(kMINUS_ONE)) &&
+                    (__half2float(hIm) < __half2float(spatialHHalf)) &&
+                    (__half2float(wIm) < __half2float(spatialWHalf)))
+                {
+                    tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr,
+                                                           spatialH,
+                                                           spatialW,
+                                                           numHeads,
+                                                           channels,
+                                                           hIm,
+                                                           wIm,
+                                                           mCol,
+                                                           cCol);
+                    col = __float2half(__half2float(col) + (__half2float(tpVal) * __half2float(weight)));
+                }
+    #endif
+                dataWeightPtr += 1;
+                dataLocWPtr += 2;
+            }
       }
-#endif
-      dataWeightPtr += 1;
-      dataLocWPtr += 2;
-    }
+        *dataColPtr = col;
   }
-    *dataColPtr = col;
-  }
 }
 #endif
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp
index adbe2566fd..b052c8ce7c 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp
@@ -4,12 +4,20 @@
 #include
 #include

-template <typename scalar_t>
-int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
-                                    const int32_t* levelStartIndex, const scalar_t* samplingLoc,
-                                    const scalar_t* attnWeight, scalar_t* output, int32_t batch,
-                                    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels,
-                                    int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
-                                    cudaStream_t stream);
+template<typename scalar_t>
+int32_t ms_deform_attn_cuda_forward(const scalar_t* value,
+                                    const int32_t* spatialShapes,
+                                    const int32_t* levelStartIndex,
+                                    const scalar_t* samplingLoc,
+                                    const scalar_t* attnWeight,
+                                    scalar_t* output,
+                                    int32_t batch,
+                                    int32_t mSpatialSize,
+                                    int32_t mNumHeads,
+                                    int32_t mChannels,
+                                    int32_t mNumLevels,
+                                    int32_t mNumQuery,
+                                    int32_t mNumPoint,
+                                    cudaStream_t stream);

 #endif
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp
index 988893125d..0d71885676 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp
@@ -9,233 +9,315 @@
 #include "trt_roi_align_kernel.hpp"
 #include "trt_serialize.hpp"

-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"MMCVRoiAlign"};
-}  // namespace
-
-TRTRoIAlign::TRTRoIAlign(const std::string &name, int outWidth, int outHeight, float spatialScale,
-                         int
sampleRatio, int poolMode, bool aligned) - : TRTPluginBase(name), - mOutWidth(outWidth), - mOutHeight(outHeight), - mSpatialScale(spatialScale), - mSampleRatio(sampleRatio), - mPoolMode(poolMode), - mAligned(aligned) {} - -TRTRoIAlign::TRTRoIAlign(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mOutWidth); - deserialize_value(&data, &length, &mOutHeight); - deserialize_value(&data, &length, &mSpatialScale); - deserialize_value(&data, &length, &mSampleRatio); - deserialize_value(&data, &length, &mPoolMode); - deserialize_value(&data, &length, &mAligned); -} - -nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const TRT_NOEXCEPT { - TRTRoIAlign *plugin = new TRTRoIAlign(mLayerName, mOutWidth, mOutHeight, mSpatialScale, - mSampleRatio, mPoolMode, mAligned); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[1].d[0]; - ret.d[1] = inputs[0].d[1]; - ret.d[2] = exprBuilder.constant(mOutHeight); - ret.d[3] = exprBuilder.constant(mOutWidth); - - return ret; -} - -bool TRTRoIAlign::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -void TRTRoIAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - size_t output_size = 0; - size_t word_size = 0; - switch (mPoolMode) { - case 0: // max - output_size = - outputs[0].dims.d[0] * outputs[0].dims.d[1] * outputs[0].dims.d[2] * outputs[0].dims.d[3]; - word_size = mmdeploy::getElementSize(outputs[0].type); - return output_size * word_size * 2; - break; - case 1: - return 0; - break; - default: - return 0; - } - return 0; -} - -int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - int channels = inputDesc[0].dims.d[1]; - int height = inputDesc[0].dims.d[2]; - int width = inputDesc[0].dims.d[3]; - - int output_size = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2] * - outputDesc[0].dims.d[3]; - int word_size = mmdeploy::getElementSize(outputDesc[0].type); - - const void *feat = inputs[0]; - const void *rois = inputs[1]; - void *output = outputs[0]; - void *argmax_y = nullptr; - void *argmax_x = nullptr; - - switch (mPoolMode) { - case 0: // max - argmax_y = workSpace; - argmax_x = (char *)argmax_y + output_size * word_size; - break; - case 1: // avg - break; - } - - switch (outputDesc[0].type) { - case nvinfer1::DataType::kFLOAT: - TRTRoIAlignForwardCUDAKernelLauncher( - (const float *)feat, (const float *)rois, (float *)output, (float *)argmax_y, - (float *)argmax_x, output_size, channels, height, width, mOutHeight, mOutWidth, - mSpatialScale, mSampleRatio, mPoolMode, mAligned, stream); - break; - - default: - break; - } - - return 0; -} - 
-nvinfer1::DataType TRTRoIAlign::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
-                                                  int nbInputs) const TRT_NOEXCEPT {
-  return inputTypes[0];
-}
-
-// IPluginV2 Methods
-const char *TRTRoIAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
-
-const char *TRTRoIAlign::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
-
-int TRTRoIAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; }
-
-size_t TRTRoIAlign::getSerializationSize() const TRT_NOEXCEPT {
-  return serialized_size(mOutWidth) + serialized_size(mOutHeight) + serialized_size(mSpatialScale) +
-         serialized_size(mSampleRatio) + serialized_size(mPoolMode) + serialized_size(mAligned);
-}
-
-void TRTRoIAlign::serialize(void *buffer) const TRT_NOEXCEPT {
-  serialize_value(&buffer, mOutWidth);
-  serialize_value(&buffer, mOutHeight);
-  serialize_value(&buffer, mSpatialScale);
-  serialize_value(&buffer, mSampleRatio);
-  serialize_value(&buffer, mPoolMode);
-  serialize_value(&buffer, mAligned);
-}
-
-TRTRoIAlignCreator::TRTRoIAlignCreator() {
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("output_height"));
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("output_width"));
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("spatial_scale"));
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("sampling_ratio"));
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("mode"));
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("aligned"));
-  mFC.nbFields = mPluginAttributes.size();
-  mFC.fields = mPluginAttributes.data();
-}
-
-const char *TRTRoIAlignCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
-
-const char *TRTRoIAlignCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
-
-nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
-    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
-  int outWidth = 7;
-  int outHeight = 7;
-  float spatialScale = 1.0;
-  int sampleRatio = 0;
-  int poolMode = -1;
-  bool aligned = true;
-  for (int i = 0; i < fc->nbFields; i++) {
-    if (fc->fields[i].data == nullptr) {
-      continue;
-    }
-    std::string field_name(fc->fields[i].name);
-
-    if (field_name.compare("output_height") == 0) {
-      outHeight = static_cast<const int *>(fc->fields[i].data)[0];
-    }
-
-    if (field_name.compare("output_width") == 0) {
-      outWidth = static_cast<const int *>(fc->fields[i].data)[0];
-    }
-
-    if (field_name.compare("spatial_scale") == 0) {
-      spatialScale = static_cast<const float *>(fc->fields[i].data)[0];
-    }
-
-    if (field_name.compare("sampling_ratio") == 0) {
-      sampleRatio = static_cast<const int *>(fc->fields[i].data)[0];
-    }
-
-    if (field_name.compare("mode") == 0) {
-      int data_size = fc->fields[i].length;
-      ASSERT(data_size > 0);
-      const char *data_start = static_cast<const char *>(fc->fields[i].data);
-      std::string pool_mode(data_start);
-      if (pool_mode == "avg") {
-        poolMode = 1;
-      } else if (pool_mode == "max") {
-        poolMode = 0;
-      } else {
-        std::cout << "Unknown pool mode \"" << pool_mode << "\"." << std::endl;
-      }
-      ASSERT(poolMode >= 0);
-    }
-
-    if (field_name.compare("aligned") == 0) {
-      int aligned_int = static_cast<const int *>(fc->fields[i].data)[0];
-      aligned = aligned_int != 0;
-    }
-  }
-
-  ASSERT(outHeight > 0);
-  ASSERT(outWidth > 0);
-  ASSERT(spatialScale > 0.);
-  ASSERT(poolMode >= 0);
-
-  TRTRoIAlign *plugin =
-      new TRTRoIAlign(name, outWidth, outHeight, spatialScale, sampleRatio, poolMode, aligned);
-  plugin->setPluginNamespace(getPluginNamespace());
-  return plugin;
-}
-
-nvinfer1::IPluginV2 *TRTRoIAlignCreator::deserializePlugin(const char *name, const void *serialData,
-                                                           size_t serialLength) TRT_NOEXCEPT {
-  auto plugin = new TRTRoIAlign(name, serialData, serialLength);
-  plugin->setPluginNamespace(getPluginNamespace());
-  return plugin;
-}
-REGISTER_TENSORRT_PLUGIN(TRTRoIAlignCreator);
+namespace mmdeploy
+{
+    namespace
+    {
+        static const char* PLUGIN_VERSION{"1"};
+        static const char* PLUGIN_NAME{"MMCVRoiAlign"};
+    }  // namespace
+
+    TRTRoIAlign::TRTRoIAlign(const std::string& name,
+                             int outWidth,
+                             int outHeight,
+                             float spatialScale,
+                             int sampleRatio,
+                             int poolMode,
+                             bool aligned)
+        : TRTPluginBase(name)
+        , mOutWidth(outWidth)
+        , mOutHeight(outHeight)
+        , mSpatialScale(spatialScale)
+        , mSampleRatio(sampleRatio)
+        , mPoolMode(poolMode)
+        , mAligned(aligned)
+    {
+    }
+
+    TRTRoIAlign::TRTRoIAlign(const std::string name, const void* data, size_t length)
+        : TRTPluginBase(name)
+    {
+        deserialize_value(&data, &length, &mOutWidth);
+        deserialize_value(&data, &length, &mOutHeight);
+        deserialize_value(&data, &length, &mSpatialScale);
+        deserialize_value(&data, &length, &mSampleRatio);
+        deserialize_value(&data, &length, &mPoolMode);
+        deserialize_value(&data, &length, &mAligned);
+    }
+
+    nvinfer1::IPluginV2DynamicExt* TRTRoIAlign::clone() const TRT_NOEXCEPT
+    {
+        TRTRoIAlign* plugin = new TRTRoIAlign(mLayerName,
+                                              mOutWidth,
+                                              mOutHeight,
+                                              mSpatialScale,
+                                              mSampleRatio,
+                                              mPoolMode,
+                                              mAligned);
+        plugin->setPluginNamespace(getPluginNamespace());
+
+        return plugin;
+    }
+
+    nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions(int outputIndex,
+                                                         const nvinfer1::DimsExprs* inputs,
+                                                         int nbInputs,
+                                                         nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT
+    {
+        nvinfer1::DimsExprs ret;
+        ret.nbDims = 4;
+        ret.d[0]   = inputs[1].d[0];
+        ret.d[1]   = inputs[0].d[1];
+        ret.d[2]   = exprBuilder.constant(mOutHeight);
+        ret.d[3]   = exprBuilder.constant(mOutWidth);
+
+        return ret;
+    }
+
+    bool TRTRoIAlign::supportsFormatCombination(int pos,
+                                                const nvinfer1::PluginTensorDesc* ioDesc,
+                                                int nbInputs,
+                                                int nbOutputs) TRT_NOEXCEPT
+    {
+        return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
+               ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
+    }
+
+    void TRTRoIAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
+                                      int nbInputs,
+                                      const nvinfer1::DynamicPluginTensorDesc* outputs,
+                                      int nbOutputs) TRT_NOEXCEPT {}
+
+    size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                                         int nbInputs,
+                                         const nvinfer1::PluginTensorDesc* outputs,
+                                         int nbOutputs) const TRT_NOEXCEPT
+    {
+        size_t output_size = 0;
+        size_t word_size   = 0;
+        switch (mPoolMode)
+        {
+            case 0:  // max
+                output_size = outputs[0].dims.d[0] * outputs[0].dims.d[1] * outputs[0].dims.d[2] * outputs[0].dims.d[3];
+                word_size   = mmdeploy::getElementSize(outputs[0].type);
+                return output_size * word_size * 2;
+                break;
+            case 1:
+                return 0;
+                break;
+            default:
+                return 0;
+        }
+        return 0;
+    }
+
+    int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                             const nvinfer1::PluginTensorDesc* outputDesc,
+                             const void* const* inputs,
+                             void* const* outputs,
+                             void* workSpace,
+                             cudaStream_t stream) TRT_NOEXCEPT
+    {
+        int channels = inputDesc[0].dims.d[1];
+        int height   = inputDesc[0].dims.d[2];
+        int width    = inputDesc[0].dims.d[3];
+
+        int output_size = outputDesc[0].dims.d[0] *
+                          outputDesc[0].dims.d[1] *
+                          outputDesc[0].dims.d[2] *
+                          outputDesc[0].dims.d[3];
+        int word_size = mmdeploy::getElementSize(outputDesc[0].type);
+
+        const void* feat     = inputs[0];
+        const void* rois     = inputs[1];
+        void*       output   = outputs[0];
+        void*       argmax_y = nullptr;
+        void*       argmax_x = nullptr;
+
+        switch (mPoolMode)
+        {
+            case 0:  // max
+                argmax_y = workSpace;
+                argmax_x = (char*)argmax_y + output_size * word_size;
+                break;
+            case 1:  // avg
+                break;
+        }
+
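+        // Workspace layout (max mode): two buffers of output_size elements each,
+        // argmax_y followed by argmax_x, matching the 2 * output_size * word_size
+        // bytes reported by getWorkspaceSize() above.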
+        switch (outputDesc[0].type)
+        {
+            case nvinfer1::DataType::kFLOAT:
+                TRTRoIAlignForwardCUDAKernelLauncher<float>(
+                    (const float*)feat,
+                    (const float*)rois,
+                    (float*)output,
+                    (float*)argmax_y,
+                    (float*)argmax_x,
+                    output_size,
+                    channels,
+                    height,
+                    width,
+                    mOutHeight,
+                    mOutWidth,
+                    mSpatialScale,
+                    mSampleRatio,
+                    mPoolMode,
+                    mAligned,
+                    stream);
+                break;
+
+            default:
+                break;
+        }
+
+        return 0;
+    }
+
+    nvinfer1::DataType TRTRoIAlign::getOutputDataType(int index,
+                                                      const nvinfer1::DataType* inputTypes,
+                                                      int nbInputs) const TRT_NOEXCEPT
+    {
+        return inputTypes[0];
+    }
+
+    // IPluginV2 Methods
+    const char* TRTRoIAlign::getPluginType() const TRT_NOEXCEPT
+    {
+        return PLUGIN_NAME;
+    }
+
+    const char* TRTRoIAlign::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return PLUGIN_VERSION;
+    }
+
+    int TRTRoIAlign::getNbOutputs() const TRT_NOEXCEPT
+    {
+        return 1;
+    }
+
+    size_t TRTRoIAlign::getSerializationSize() const TRT_NOEXCEPT
+    {
+        return serialized_size(mOutWidth) + serialized_size(mOutHeight) + serialized_size(mSpatialScale) +
+               serialized_size(mSampleRatio) + serialized_size(mPoolMode) + serialized_size(mAligned);
+    }
+
+    void TRTRoIAlign::serialize(void* buffer) const TRT_NOEXCEPT
+    {
+        serialize_value(&buffer, mOutWidth);
+        serialize_value(&buffer, mOutHeight);
+        serialize_value(&buffer, mSpatialScale);
+        serialize_value(&buffer, mSampleRatio);
+        serialize_value(&buffer, mPoolMode);
+        serialize_value(&buffer, mAligned);
+    }
+
+    TRTRoIAlignCreator::TRTRoIAlignCreator()
+    {
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("output_height"));
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("output_width"));
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("spatial_scale"));
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("sampling_ratio"));
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("mode"));
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("aligned"));
+        mFC.nbFields = mPluginAttributes.size();
+        mFC.fields   = mPluginAttributes.data();
+    }
+
+    const char* TRTRoIAlignCreator::getPluginName() const TRT_NOEXCEPT
+    {
+        return PLUGIN_NAME;
+    }
+
+    const char* TRTRoIAlignCreator::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return PLUGIN_VERSION;
+    }
+
+    nvinfer1::IPluginV2* TRTRoIAlignCreator::createPlugin(
+        const char* name,
+        const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT
+    {
+        int   outWidth     = 7;
+        int   outHeight    = 7;
+        float spatialScale = 1.0;
+        int   sampleRatio  = 0;
+        int   poolMode     = -1;
+        bool  aligned      = true;
+        for (int i = 0; i < fc->nbFields; i++)
+        {
+            if (fc->fields[i].data == nullptr)
+            {
+                continue;
+            }
+            std::string field_name(fc->fields[i].name);
+
+            if (field_name.compare("output_height") == 0)
+            {
+                outHeight = static_cast<const int*>(fc->fields[i].data)[0];
+            }
+
+            if (field_name.compare("output_width") == 0)
+            {
+                outWidth = static_cast<const int*>(fc->fields[i].data)[0];
+            }
+
+            if (field_name.compare("spatial_scale") == 0)
+            {
+                spatialScale = static_cast<const float*>(fc->fields[i].data)[0];
+            }
+
+            if (field_name.compare("sampling_ratio") == 0)
+            {
+                sampleRatio = static_cast<const int*>(fc->fields[i].data)[0];
+            }
+
+            if (field_name.compare("mode") == 0)
+            {
+                int data_size = fc->fields[i].length;
+                ASSERT(data_size > 0);
+                const char* data_start = static_cast<const char*>(fc->fields[i].data);
+                std::string pool_mode(data_start);
+                if (pool_mode == "avg")
+                {
+                    poolMode = 1;
+                }
+                else if (pool_mode == "max")
+                {
+                    poolMode = 0;
+                }
+                else
+                {
+                    std::cout << "Unknown pool mode \"" << pool_mode << "\"." << std::endl;
+                }
+                ASSERT(poolMode >= 0);
+            }
+
+            if (field_name.compare("aligned") == 0)
+            {
+                int aligned_int = static_cast<const int*>(fc->fields[i].data)[0];
+                aligned         = aligned_int != 0;
+            }
+        }
+
+        ASSERT(outHeight > 0);
+        ASSERT(outWidth > 0);
+        ASSERT(spatialScale > 0.);
+        ASSERT(poolMode >= 0);
+
+        TRTRoIAlign* plugin =
+            new TRTRoIAlign(name, outWidth, outHeight, spatialScale, sampleRatio, poolMode, aligned);
+        plugin->setPluginNamespace(getPluginNamespace());
+        return plugin;
+    }
+
+    nvinfer1::IPluginV2* TRTRoIAlignCreator::deserializePlugin(const char* name,
+                                                               const void* serialData,
+                                                               size_t serialLength) TRT_NOEXCEPT
+    {
+        auto plugin = new TRTRoIAlign(name, serialData, serialLength);
+        plugin->setPluginNamespace(getPluginNamespace());
+        return plugin;
+    }
+    REGISTER_TENSORRT_PLUGIN(TRTRoIAlignCreator);
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp
index cfc14758f7..605c1a4333 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp
@@ -8,65 +8,91 @@
 #include

 #include "trt_plugin_base.hpp"

-namespace mmdeploy {
-class TRTRoIAlign : public TRTPluginBase {
- public:
-  TRTRoIAlign(const std::string &name, int outWidth, int outHeight, float spatialScale,
-              int sampleRatio, int poolMode, bool aligned);
-
-  TRTRoIAlign(const std::string name, const void *data, size_t length);
-
-  TRTRoIAlign() = delete;
-
-  // IPluginV2DynamicExt Methods
-  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
-  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
-                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
-      TRT_NOEXCEPT override;
-  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
-                                 int nbOutputs) TRT_NOEXCEPT override;
-  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
-                       const nvinfer1::DynamicPluginTensorDesc *out,
-                       int nbOutputs) TRT_NOEXCEPT override;
-  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
-                          const nvinfer1::PluginTensorDesc *outputs,
-                          int nbOutputs) const TRT_NOEXCEPT override;
-  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
-              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
-              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
-
-  // IPluginV2Ext Methods
-  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
-                                       int nbInputs) const TRT_NOEXCEPT override;
-
-  // IPluginV2 Methods
-  const char *getPluginType() const TRT_NOEXCEPT override;
-  const char *getPluginVersion() const TRT_NOEXCEPT override;
-  int getNbOutputs() const TRT_NOEXCEPT override;
-  size_t
getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - - private: - int mOutWidth; - int mOutHeight; - float mSpatialScale; - int mSampleRatio; - int mPoolMode; // 1:avg 0:max - bool mAligned; -}; - -class TRTRoIAlignCreator : public TRTPluginCreatorBase { - public: - TRTRoIAlignCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class TRTRoIAlign : public TRTPluginBase + { + public: + TRTRoIAlign(const std::string& name, + int outWidth, + int outHeight, + float spatialScale, + int sampleRatio, + int poolMode, + bool aligned); + + TRTRoIAlign(const std::string name, + const void* data, + size_t length); + + TRTRoIAlign() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + private: + int mOutWidth; + int mOutHeight; + float mSpatialScale; + int mSampleRatio; + int mPoolMode; // 1:avg 0:max + bool mAligned; + }; + + class TRTRoIAlignCreator : public TRTPluginCreatorBase + { + public: + TRTRoIAlignCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_ROI_ALIGN_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu index 4e1a825d4f..a8ba93b5ad 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu +++ 
b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu
@@ -4,104 +4,157 @@
 #include "trt_roi_align_kernel.hpp"

 /*** Forward ***/
-template <typename T>
-__global__ void roi_align_forward_cuda_kernel(const int nthreads, const T* input, const T* rois,
-                                              T* output, T* argmax_y, T* argmax_x,
-                                              const int pooled_height, const int pooled_width,
-                                              const T spatial_scale, const int sampling_ratio,
-                                              const int pool_mode,  // 0 - max pool, 1 - avg pool
-                                              const bool aligned, const int channels,
-                                              const int height, const int width) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    // (n, c, ph, pw) is an element in the pooled output
-    int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height;
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
+template<typename T>
+__global__ void roi_align_forward_cuda_kernel(const int nthreads,
+                                              const T* input,
+                                              const T* rois,
+                                              T* output,
+                                              T* argmax_y,
+                                              T* argmax_x,
+                                              const int pooled_height,
+                                              const int pooled_width,
+                                              const T spatial_scale,
+                                              const int sampling_ratio,
+                                              const int pool_mode,  // 0 - max pool, 1 - avg pool
+                                              const bool aligned,
+                                              const int channels,
+                                              const int height,
+                                              const int width)
+{
+    CUDA_1D_KERNEL_LOOP(index, nthreads)
+    {
+        // (n, c, ph, pw) is an element in the pooled output
+        int pw = index % pooled_width;
+        int ph = (index / pooled_width) % pooled_height;
+        int c  = (index / pooled_width / pooled_height) % channels;
+        int n  = index / pooled_width / pooled_height / channels;

-    const T* offset_rois = rois + n * 5;
-    int roi_batch_ind = offset_rois[0];
+        const T* offset_rois   = rois + n * 5;
+        int      roi_batch_ind = offset_rois[0];

-    // Do not using rounding; this implementation detail is critical
-    T offset = aligned ? (T)0.5 : (T)0.0;
-    T roi_start_w = offset_rois[1] * spatial_scale - offset;
-    T roi_start_h = offset_rois[2] * spatial_scale - offset;
-    T roi_end_w = offset_rois[3] * spatial_scale - offset;
-    T roi_end_h = offset_rois[4] * spatial_scale - offset;
+        // Do not use rounding; this implementation detail is critical
+        T offset      = aligned ? (T)0.5 : (T)0.0;
+        T roi_start_w = offset_rois[1] * spatial_scale - offset;
+        T roi_start_h = offset_rois[2] * spatial_scale - offset;
+        T roi_end_w   = offset_rois[3] * spatial_scale - offset;
+        T roi_end_h   = offset_rois[4] * spatial_scale - offset;

-    T roi_width = roi_end_w - roi_start_w;
-    T roi_height = roi_end_h - roi_start_h;
-    if (!aligned) {  // for backward-compatibility only
-      roi_width = max(roi_width, (T)1.);
-      roi_height = max(roi_height, (T)1.);
-    }
+        T roi_width  = roi_end_w - roi_start_w;
+        T roi_height = roi_end_h - roi_start_h;
+        if (!aligned)
+        {  // for backward-compatibility only
+            roi_width  = max(roi_width, (T)1.);
+            roi_height = max(roi_height, (T)1.);
+        }

-    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
-    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+        T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+        T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

-    const T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
+        const T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
+
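+        // offset_input points at the start of the (roi_batch_ind, c) plane of the
+        // NCHW input, so bilinear_interpolate() below works in (y, x) coordinates
+        // of a single height x width feature map.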
sampling_ratio : static_cast(ceilf(roi_width / pooled_width)); + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) ? sampling_ratio : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : static_cast(ceilf(roi_width / pooled_width)); - if (pool_mode == 0) { - // We do max pooling inside a bin - T maxval = -FLT_MAX; - T maxidx_y = -1.f, maxidx_x = -1.f; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - T val = bilinear_interpolate(offset_input, height, width, y, x); - if (val > maxval) { - maxval = val; - maxidx_y = y; - maxidx_x = x; - } + if (pool_mode == 0) + { + // We do max pooling inside a bin + T maxval = -FLT_MAX; + T maxidx_y = -1.f, maxidx_x = -1.f; + for (int iy = 0; iy < roi_bin_grid_h; iy++) + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x); + if (val > maxval) + { + maxval = val; + maxidx_y = y; + maxidx_x = x; + } + } + } + output[index] = maxval; + argmax_y[index] = maxidx_y; + argmax_x[index] = maxidx_x; } - } - output[index] = maxval; - argmax_y[index] = maxidx_y; - argmax_x[index] = maxidx_x; - } else if (pool_mode == 1) { - // We do average pooling inside a bin - const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); - T output_val = 0.; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - T val = bilinear_interpolate(offset_input, height, width, y, x); - output_val += val; + else if (pool_mode == 1) + { + // We do average pooling inside a bin + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x); + output_val += val; + } + } + output[index] = output_val / count; } - } - output[index] = output_val / count; } - } } -template -void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois, - scalar_t* output, scalar_t* argmax_y, scalar_t* argmax_x, - int output_size, int channels, int height, int width, - int aligned_height, int aligned_width, - scalar_t spatial_scale, int sampling_ratio, int pool_mode, - bool aligned, cudaStream_t stream) { - roi_align_forward_cuda_kernel - <<>>( - output_size, input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width, - static_cast(spatial_scale), sampling_ratio, pool_mode, aligned, channels, - height, width); +template +void 
TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, + const scalar_t* rois, + scalar_t* output, + scalar_t* argmax_y, + scalar_t* argmax_x, + int output_size, + int channels, + int height, + int width, + int aligned_height, + int aligned_width, + scalar_t spatial_scale, + int sampling_ratio, + int pool_mode, + bool aligned, + cudaStream_t stream) +{ + roi_align_forward_cuda_kernel + <<>>(output_size, + input, + rois, + output, + argmax_y, + argmax_x, + aligned_height, + aligned_width, + static_cast(spatial_scale), + sampling_ratio, + pool_mode, + aligned, + channels, + height, + width); } -template void TRTRoIAlignForwardCUDAKernelLauncher( - const float* input, const float* rois, float* output, float* argmax_y, float* argmax_x, - int output_size, int channels, int height, int width, int aligned_height, int aligned_width, - float spatial_scale, int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream); +template void TRTRoIAlignForwardCUDAKernelLauncher(const float* input, + const float* rois, + float* output, + float* argmax_y, + float* argmax_x, + int output_size, + int channels, + int height, + int width, + int aligned_height, + int aligned_width, + float spatial_scale, + int sampling_ratio, + int pool_mode, + bool aligned, + cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp index 3db656bff9..38906636a4 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp @@ -4,12 +4,22 @@ #include "common_cuda_helper.hpp" -template -void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois, - scalar_t* output, scalar_t* argmax_y, scalar_t* argmax_x, - int output_size, int channels, int height, int width, - int aligned_height, int aligned_width, - scalar_t spatial_scale, int sampling_ratio, int pool_mode, - bool aligned, cudaStream_t stream); +template +void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, + const scalar_t* rois, + scalar_t* output, + scalar_t* argmax_y, + scalar_t* argmax_x, + int output_size, + int channels, + int height, + int width, + int aligned_height, + int aligned_width, + scalar_t spatial_scale, + int sampling_ratio, + int pool_mode, + bool aligned, + cudaStream_t stream); #endif // ROI_ALIGN_CUDA_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp index a4ecb2356a..551c6ce996 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp @@ -10,174 +10,242 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"ScaledDotProductAttentionTRT"}; -} // namespace - -ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string &name) - : TRTPluginBase(name), mask_dim(0) {} - -ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, const void *data, - size_t length) - : TRTPluginBase(name), mask_dim(0) {} - -ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {} - -nvinfer1::IPluginV2DynamicExt *ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT { 
- ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(mLayerName); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - if (outputIndex == 0) return inputs[0]; - nvinfer1::DimsExprs ret; - ret.nbDims = 3; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[0].d[1]; - ret.d[2] = inputs[1].d[1]; - - return ret; -} - -bool ScaledDotProductAttentionTRT::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -// Attach the plugin object to an execution context and grant the plugin the -// access to some context resource. -void ScaledDotProductAttentionTRT::attachToContext(cudnnContext *cudnnContext, - cublasContext *cublasContext, - IGpuAllocator *gpuAllocator) TRT_NOEXCEPT { - _cublas_handle = cublasContext; - _cudnn_handle = cudnnContext; - cudnnCreateTensorDescriptor(&_x_desc); - cudnnCreateTensorDescriptor(&_y_desc); - cudnnCreateTensorDescriptor(&_mask_desc); -} - -// Detach the plugin object from its execution context. -void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT { - cudnnDestroyTensorDescriptor(_y_desc); - cudnnDestroyTensorDescriptor(_x_desc); - cudnnDestroyTensorDescriptor(_mask_desc); -} - -void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT { - if (nbInputs != 4) { - mask_dim = 0; - } else { - mask_dim = in[3].desc.dims.nbDims; - } -} - -int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1; - if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1; - int B = inputDesc[0].dims.d[0]; // batch * heads - int Nt = inputDesc[0].dims.d[1]; - int Ns = inputDesc[1].dims.d[1]; - int E = inputDesc[0].dims.d[2]; // embeding size - - const void *query = inputs[0]; - const void *key = inputs[1]; - const void *value = inputs[2]; - const void *mask = nullptr; - - int mask_dims[3]; - mask_dims[0] = 0; - if (mask_dim > 0) { - mask = inputs[3]; - // check if mask need broadcast - if (mask_dim == 2) { - mask_dims[0] = 1; - mask_dims[1] = inputDesc[3].dims.d[0]; - mask_dims[2] = inputDesc[3].dims.d[1]; - } else { - mask_dims[0] = inputDesc[3].dims.d[0]; - mask_dims[1] = inputDesc[3].dims.d[1]; - mask_dims[2] = inputDesc[3].dims.d[2]; - } - } - - void *output = outputs[0]; - void *attn = outputs[1]; - - auto data_type = inputDesc[0].type; - cudnnDataType_t cudnn_dtype{}; - convert_trt2cudnn_dtype(data_type, &cudnn_dtype); - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - dot_product_attention_impl((float *)query, (float *)key, (float *)value, (float *)mask, - (float *)attn, (float *)output, B, Nt, Ns, E, &mask_dims[0], - _x_desc, _y_desc, _mask_desc, cudnn_dtype, stream, - _cublas_handle, _cudnn_handle); - break; - 
default: - return 1; - } - - return 0; -} - -nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT { return 2; } - -size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT { return 0; } - -void ScaledDotProductAttentionTRT::serialize(void *buffer) const TRT_NOEXCEPT {} - -////////////////////// creator ///////////////////////////// - -ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {} - -const char *ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(name); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"ScaledDotProductAttentionTRT"}; + } // namespace + + ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string& name) + : TRTPluginBase(name) + , mask_dim(0) + { + } + + ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, + const void* data, + size_t length) + : TRTPluginBase(name) + , mask_dim(0) + { + } + + ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {} + + nvinfer1::IPluginV2DynamicExt* ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT + { + ScaledDotProductAttentionTRT* plugin = new ScaledDotProductAttentionTRT(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + if (outputIndex == 0) return inputs[0]; + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = inputs[1].d[1]; + + return ret; + } + + bool ScaledDotProductAttentionTRT::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + // Attach the plugin object to an execution context and grant the plugin the + // access to some context resource. 
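For reference, dot_product_attention_impl (the kernel further down) evaluates attn = softmax(Q·Kᵀ/√E + mask) and output = attn·V using two cuBLAS strided-batched GEMMs around a cuDNN softmax; attachToContext below caches the cuBLAS/cuDNN handles it needs. A minimal CPU sketch of the same computation, assuming contiguous row-major tensors and a 2-D mask broadcast over the batch (illustrative names, not part of this patch):

#include <algorithm>
#include <cmath>

// CPU reference for scaled dot-product attention (sketch only; the CUDA
// path uses cuBLAS strided-batched GEMMs plus cudnnSoftmaxForward instead).
// query: [B, Nt, E], key/value: [B, Ns, E], mask: [Nt, Ns] or nullptr.
void scaled_dot_product_attention_ref(const float* query,
                                      const float* key,
                                      const float* value,
                                      const float* mask,
                                      float*       attn,    // [B, Nt, Ns]
                                      float*       output,  // [B, Nt, E]
                                      int B, int Nt, int Ns, int E)
{
    const float scale = 1.0f / std::sqrt(static_cast<float>(E));
    for (int b = 0; b < B; ++b)
    {
        const float* q = query + b * Nt * E;
        const float* k = key + b * Ns * E;
        const float* v = value + b * Ns * E;
        float*       a = attn + b * Nt * Ns;
        float*       o = output + b * Nt * E;
        for (int t = 0; t < Nt; ++t)
        {
            // attn[t][s] = scale * (q[t] . k[s]) + mask[t][s]
            for (int s = 0; s < Ns; ++s)
            {
                float dot = 0.0f;
                for (int e = 0; e < E; ++e) dot += q[t * E + e] * k[s * E + e];
                a[t * Ns + s] = scale * dot + (mask ? mask[t * Ns + s] : 0.0f);
            }
            // numerically stable softmax over the Ns axis
            const float m   = *std::max_element(a + t * Ns, a + (t + 1) * Ns);
            float       sum = 0.0f;
            for (int s = 0; s < Ns; ++s)
            {
                a[t * Ns + s] = std::exp(a[t * Ns + s] - m);
                sum += a[t * Ns + s];
            }
            for (int s = 0; s < Ns; ++s) a[t * Ns + s] /= sum;
            // output[t] = attn[t] @ V
            for (int e = 0; e < E; ++e)
            {
                float acc = 0.0f;
                for (int s = 0; s < Ns; ++s) acc += a[t * Ns + s] * v[s * E + e];
                o[t * E + e] = acc;
            }
        }
    }
}
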
+ void ScaledDotProductAttentionTRT::attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + IGpuAllocator* gpuAllocator) TRT_NOEXCEPT + { + _cublas_handle = cublasContext; + _cudnn_handle = cudnnContext; + cudnnCreateTensorDescriptor(&_x_desc); + cudnnCreateTensorDescriptor(&_y_desc); + cudnnCreateTensorDescriptor(&_mask_desc); + } + + // Detach the plugin object from its execution context. + void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT + { + cudnnDestroyTensorDescriptor(_y_desc); + cudnnDestroyTensorDescriptor(_x_desc); + cudnnDestroyTensorDescriptor(_mask_desc); + } + + void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT + { + if (nbInputs != 4) + { + mask_dim = 0; + } + else + { + mask_dim = in[3].desc.dims.nbDims; + } + } + + int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1; + if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1; + int B = inputDesc[0].dims.d[0]; // batch * heads + int Nt = inputDesc[0].dims.d[1]; + int Ns = inputDesc[1].dims.d[1]; + int E = inputDesc[0].dims.d[2]; // embeding size + + const void* query = inputs[0]; + const void* key = inputs[1]; + const void* value = inputs[2]; + const void* mask = nullptr; + + int mask_dims[3]; + mask_dims[0] = 0; + if (mask_dim > 0) + { + mask = inputs[3]; + // check if mask need broadcast + if (mask_dim == 2) + { + mask_dims[0] = 1; + mask_dims[1] = inputDesc[3].dims.d[0]; + mask_dims[2] = inputDesc[3].dims.d[1]; + } + else + { + mask_dims[0] = inputDesc[3].dims.d[0]; + mask_dims[1] = inputDesc[3].dims.d[1]; + mask_dims[2] = inputDesc[3].dims.d[2]; + } + } + + void* output = outputs[0]; + void* attn = outputs[1]; + + auto data_type = inputDesc[0].type; + cudnnDataType_t cudnn_dtype{}; + convert_trt2cudnn_dtype(data_type, &cudnn_dtype); + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + dot_product_attention_impl((float*)query, + (float*)key, + (float*)value, + (float*)mask, + (float*)attn, + (float*)output, + B, + Nt, + Ns, + E, + &mask_dims[0], + _x_desc, + _y_desc, + _mask_desc, + cudnn_dtype, + stream, + _cublas_handle, + _cudnn_handle); + break; + default: + return 1; + } + + return 0; + } + + nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType( + int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT + { + return 2; + } + + size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT + { + return 0; + } + + void ScaledDotProductAttentionTRT::serialize(void* buffer) const TRT_NOEXCEPT {} + + ////////////////////// creator ///////////////////////////// + + ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {} + + const char* ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT + { + 
return PLUGIN_NAME; + } + + const char* ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* ScaledDotProductAttentionTRTCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + ScaledDotProductAttentionTRT* plugin = new ScaledDotProductAttentionTRT(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* ScaledDotProductAttentionTRTCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp index 86d35616a9..4aea4c1e20 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp @@ -9,65 +9,86 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class ScaledDotProductAttentionTRT : public TRTPluginBase { - public: - ScaledDotProductAttentionTRT(const std::string &name); - - ScaledDotProductAttentionTRT(const std::string name, const void *data, size_t length); - - ScaledDotProductAttentionTRT() = delete; - - ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override; - - virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - void attachToContext(cudnnContext *cudnn, cublasContext *cublas, - nvinfer1::IGpuAllocator *allocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; - - private: - int mask_dim; - cublasHandle_t _cublas_handle{}; - cudnnHandle_t _cudnn_handle{}; - cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{}; -}; - -class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase { - public: - ScaledDotProductAttentionTRTCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char 
*getPluginVersion() const TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class ScaledDotProductAttentionTRT : public TRTPluginBase + { + public: + ScaledDotProductAttentionTRT(const std::string& name); + + ScaledDotProductAttentionTRT(const std::string name, + const void* data, + size_t length); + + ScaledDotProductAttentionTRT() = delete; + + ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override; + + virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + void attachToContext(cudnnContext* cudnn, + cublasContext* cublas, + nvinfer1::IGpuAllocator* allocator) TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + private: + int mask_dim; + cublasHandle_t _cublas_handle{}; + cudnnHandle_t _cudnn_handle{}; + cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{}; + }; + + class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase + { + public: + ScaledDotProductAttentionTRTCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu index a0ee16c998..9775265b78 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu @@ -11,93 +11,228 @@ #include "scaled_dot_product_attention_kernel.hpp" #include "trt_plugin_helper.hpp" -template -cublasStatus_t 
cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const scalar_t* alpha, const scalar_t* A, int lda, - long long int strideA, const scalar_t* B, int ldb, - long long int strideB, const scalar_t* beta, - scalar_t* C, int ldc, long long int strideC, - int batchCount); +template +cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const scalar_t* alpha, + const scalar_t* A, + int lda, + long long int strideA, + const scalar_t* B, + int ldb, + long long int strideB, + const scalar_t* beta, + scalar_t* C, + int ldc, + long long int strideC, + int batchCount); -template <> -cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float* alpha, const float* A, int lda, - long long int strideA, const float* B, int ldb, - long long int strideB, const float* beta, - float* C, int ldc, long long int strideC, - int batchCount) { - return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, - strideB, beta, C, ldc, strideC, batchCount); +template<> +cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* A, + int lda, + long long int strideA, + const float* B, + int ldb, + long long int strideB, + const float* beta, + float* C, + int ldc, + long long int strideC, + int batchCount) +{ + return cublasSgemmStridedBatched(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + lda, + strideA, + B, + ldb, + strideB, + beta, + C, + ldc, + strideC, + batchCount); } -template <> -cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const __half* alpha, const __half* A, int lda, - long long int strideA, const __half* B, int ldb, - long long int strideB, const __half* beta, - __half* C, int ldc, long long int strideC, - int batchCount) { - return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, - strideB, beta, C, ldc, strideC, batchCount); +template<> +cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const __half* alpha, + const __half* A, + int lda, + long long int strideA, + const __half* B, + int ldb, + long long int strideB, + const __half* beta, + __half* C, + int ldc, + long long int strideC, + int batchCount) +{ + return cublasHgemmStridedBatched(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + lda, + strideA, + B, + ldb, + strideB, + beta, + C, + ldc, + strideC, + batchCount); } -template -void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value, - const scalar_t* mask, scalar_t* attn, scalar_t* output, int B, - int Nt, int Ns, int E, const int* mask_dims, - cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc, - cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, - cudaStream_t stream, cublasHandle_t cublas_handle, - cudnnHandle_t cudnn_handle) { - { - // Q @ K - const int m = Ns; - const int n = Nt; - const int k = E; - const auto alpha = scalar_t(1.0f / sqrt(float(E))); - const auto beta = scalar_t(0); - 
cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k, - Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B); - } +template +void dot_product_attention_impl(const scalar_t* query, + const scalar_t* key, + const scalar_t* value, + const scalar_t* mask, + scalar_t* attn, + scalar_t* output, + int B, + int Nt, + int Ns, + int E, + const int* mask_dims, + cudnnTensorDescriptor_t& x_desc, + cudnnTensorDescriptor_t& y_desc, + cudnnTensorDescriptor_t& mask_desc, + cudnnDataType_t cudnn_dtype, + cudaStream_t stream, + cublasHandle_t cublas_handle, + cudnnHandle_t cudnn_handle) +{ + { + // Q @ K + const int m = Ns; + const int n = Nt; + const int k = E; + const auto alpha = scalar_t(1.0f / sqrt(float(E))); + const auto beta = scalar_t(0); + cublasgemmStridedBatchedWrap(cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + m, + n, + k, + &alpha, + key, + k, + Ns * E, + query, + k, + Nt * E, + &beta, + attn, + m, + Nt * Ns, + B); + } - if (mask_dims != nullptr && mask_dims[0] != 0) { - const auto alpha = scalar_t(1); - const auto beta = scalar_t(1); - cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0], - mask_dims[1], mask_dims[2]); - cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns); - cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn); - } + if (mask_dims != nullptr && mask_dims[0] != 0) + { + const auto alpha = scalar_t(1); + const auto beta = scalar_t(1); + cudnnSetTensor4dDescriptor(mask_desc, + CUDNN_TENSOR_NCHW, + cudnn_dtype, + 1, + mask_dims[0], + mask_dims[1], + mask_dims[2]); + cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns); + cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn); + } - { - // softmax attention - const auto alpha = scalar_t(1); - const auto beta = scalar_t(0); - cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); - cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); - cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, - x_desc, attn, &beta, y_desc, attn); - } + { + // softmax attention + const auto alpha = scalar_t(1); + const auto beta = scalar_t(0); + cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); + cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); + cudnnSoftmaxForward(cudnn_handle, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_INSTANCE, + &alpha, + x_desc, + attn, + &beta, + y_desc, + attn); + } - { - // attn @ v - const int m = E; - const int n = Nt; - const int k = Ns; - const auto alpha = scalar_t(1); - const auto beta = scalar_t(0); - cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value, m, - Ns * E, (const scalar_t*)(attn), k, Ns * Nt, &beta, output, m, - Nt * E, B); - } + { + // attn @ v + const int m = E; + const int n = Nt; + const int k = Ns; + const auto alpha = scalar_t(1); + const auto beta = scalar_t(0); + cublasgemmStridedBatchedWrap(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + m, + n, + k, + &alpha, + value, + m, + Ns * E, + (const scalar_t*)(attn), + k, + Ns * Nt, + &beta, + output, + m, + Nt * E, + B); + } } -template void dot_product_attention_impl( - const float* query, const float* key, const float* value, const float* mask, float* attn, - float* output, int B, int Nt, int Ns, int E, const int* mask_dims, - cudnnTensorDescriptor_t& x_desc, 
cudnnTensorDescriptor_t& y_desc, - cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream, - cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle); +template void dot_product_attention_impl(const float* query, + const float* key, + const float* value, + const float* mask, + float* attn, + float* output, + int B, + int Nt, + int Ns, + int E, + const int* mask_dims, + cudnnTensorDescriptor_t& x_desc, + cudnnTensorDescriptor_t& y_desc, + cudnnTensorDescriptor_t& mask_desc, + cudnnDataType_t cudnn_dtype, + cudaStream_t stream, + cublasHandle_t cublas_handle, + cudnnHandle_t cudnn_handle); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp index d1cdc7773a..b11a341aa9 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp @@ -5,13 +5,24 @@ #include #include -template -void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value, - const scalar_t* mask, scalar_t* attn, scalar_t* output, int B, - int Nt, int Ns, int E, const int* mask_dims, - cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc, - cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, - cudaStream_t stream, cublasHandle_t cublas_handle, - cudnnHandle_t cudnn_handle); +template +void dot_product_attention_impl(const scalar_t* query, + const scalar_t* key, + const scalar_t* value, + const scalar_t* mask, + scalar_t* attn, + scalar_t* output, + int B, + int Nt, + int Ns, + int E, + const int* mask_dims, + cudnnTensorDescriptor_t& x_desc, + cudnnTensorDescriptor_t& y_desc, + cudnnTensorDescriptor_t& mask_desc, + cudnnDataType_t cudnn_dtype, + cudaStream_t stream, + cublasHandle_t cublas_handle, + cudnnHandle_t cudnn_handle); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp index 13c637f408..ca5fe92dcc 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp @@ -2,155 +2,218 @@ #include "NvInferVersion.h" // ScatterND is supported since TensorRT8 #if NV_TENSORRT_MAJOR <= 7 -#include -#include - -#include - -#include "trt_scatternd.hpp" -#include "trt_scatternd_kernel.hpp" -#include "trt_serialize.hpp" - -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"ScatterND"}; -} // namespace - -TRTScatterND::TRTScatterND(const std::string &name) : TRTPluginBase(name) {} - -TRTScatterND::TRTScatterND(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) {} - -nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const TRT_NOEXCEPT { - TRTScatterND *plugin = new TRTScatterND(mLayerName); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTScatterND::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - return inputs[0]; -} - -bool TRTScatterND::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos < nbInputs) { - switch (pos) { - case 0: 
- // data - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || - (ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - case 1: - // indices - return ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - case 2: - // updates - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - default: - return true; + #include + #include + + #include + + #include "trt_scatternd.hpp" + #include "trt_scatternd_kernel.hpp" + #include "trt_serialize.hpp" + +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"ScatterND"}; + } // namespace + + TRTScatterND::TRTScatterND(const std::string& name) + : TRTPluginBase(name) + { + } + + TRTScatterND::TRTScatterND(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { } - } else { - switch (pos - nbInputs) { - case 0: - // output - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - default: + + nvinfer1::IPluginV2DynamicExt* TRTScatterND::clone() const TRT_NOEXCEPT + { + TRTScatterND* plugin = new TRTScatterND(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTScatterND::getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + return inputs[0]; + } + + bool TRTScatterND::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos < nbInputs) + { + switch (pos) + { + case 0: + // data + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || + (ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + case 1: + // indices + return ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + case 2: + // updates + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + default: + return true; + } + } + else + { + switch (pos - nbInputs) + { + case 0: + // output + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + default: + return true; + } + } return true; } - } - return true; -} - -void TRTScatterND::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t TRTScatterND::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - const int *dims = &(inputDesc[0].dims.d[0]); - const int *indices_dims = &(inputDesc[1].dims.d[0]); - int nbDims = inputDesc[0].dims.nbDims; - int indice_nbDims = inputDesc[1].dims.nbDims; - - const void *data = inputs[0]; - const void *indices = inputs[1]; - const void *update = inputs[2]; - void *output = outputs[0]; - - auto data_type = inputDesc[0].type; - - switch (data_type) { - case 
nvinfer1::DataType::kFLOAT: - TRTONNXScatterNDKernelLauncher((float *)data, (int *)indices, (float *)update, dims, - nbDims, indices_dims, indice_nbDims, (float *)output, - stream); - break; - - case nvinfer1::DataType::kINT32: - TRTONNXScatterNDKernelLauncher((int *)data, (int *)indices, (int *)update, dims, nbDims, - indices_dims, indice_nbDims, (int *)output, stream); - break; - default: - break; - } - - return 0; -} - -nvinfer1::DataType TRTScatterND::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *TRTScatterND::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTScatterND::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int TRTScatterND::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTScatterND::getSerializationSize() const TRT_NOEXCEPT { return 0; } - -void TRTScatterND::serialize(void *buffer) const TRT_NOEXCEPT {} - -TRTScatterNDCreator::TRTScatterNDCreator() { - mPluginAttributes.clear(); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTScatterNDCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTScatterNDCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *TRTScatterNDCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - TRTScatterND *plugin = new TRTScatterND(name); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTScatterNDCreator::deserializePlugin(const char *name, - const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - auto plugin = new TRTScatterND(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTScatterNDCreator); + + void TRTScatterND::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + + size_t TRTScatterND::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + const int* dims = &(inputDesc[0].dims.d[0]); + const int* indices_dims = &(inputDesc[1].dims.d[0]); + int nbDims = inputDesc[0].dims.nbDims; + int indice_nbDims = inputDesc[1].dims.nbDims; + + const void* data = inputs[0]; + const void* indices = inputs[1]; + const void* update = inputs[2]; + void* output = outputs[0]; + + auto data_type = inputDesc[0].type; + + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + TRTONNXScatterNDKernelLauncher((float*)data, + (int*)indices, + (float*)update, + dims, + nbDims, + indices_dims, + indice_nbDims, + (float*)output, + stream); + break; + + case nvinfer1::DataType::kINT32: + TRTONNXScatterNDKernelLauncher((int*)data, + (int*)indices, + (int*)update, + dims, + nbDims, + indices_dims, + indice_nbDims, + (int*)output, + stream); + break; + default: + break; + } + + return 0; + } + + nvinfer1::DataType TRTScatterND::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) 
const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* TRTScatterND::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTScatterND::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTScatterND::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTScatterND::getSerializationSize() const TRT_NOEXCEPT + { + return 0; + } + + void TRTScatterND::serialize(void* buffer) const TRT_NOEXCEPT {} + + TRTScatterNDCreator::TRTScatterNDCreator() + { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTScatterNDCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTScatterNDCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTScatterNDCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + TRTScatterND* plugin = new TRTScatterND(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTScatterNDCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new TRTScatterND(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(TRTScatterNDCreator); } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp index d6b859855e..6afbbe450e 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp @@ -9,56 +9,77 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTScatterND : public TRTPluginBase { - public: - TRTScatterND(const std::string &name); - - TRTScatterND(const std::string name, const void *data, size_t length); - - TRTScatterND() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; 
-}; - -class TRTScatterNDCreator : public TRTPluginCreatorBase { - public: - TRTScatterNDCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class TRTScatterND : public TRTPluginBase + { + public: + TRTScatterND(const std::string& name); + + TRTScatterND(const std::string name, + const void* data, + size_t length); + + TRTScatterND() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + }; + + class TRTScatterNDCreator : public TRTPluginCreatorBase + { + public: + TRTScatterNDCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_SCATTERND_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu index c763992e9f..cd5a235afa 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu @@ -8,68 +8,98 @@ using mmdeploy::TensorDesc; -template -__global__ void onnx_scatternd_kernel(const int n, const int* indices, const T* update, T* output, - TensorDesc tensor_desc, TensorDesc indice_desc) { - const int indice_cols = indice_desc.shape[indice_desc.dim - 1]; - const int copy_stride = tensor_desc.stride[indice_cols - 1]; - const int* stride = &(tensor_desc.stride[0]); - CUDA_1D_KERNEL_LOOP(index, n) { - int output_offset = 0; - const int* indices_current = 
indices + index * indice_cols; - for (int i = 0; i < indice_cols; ++i) { - output_offset += stride[i] * indices_current[i]; +template +__global__ void onnx_scatternd_kernel(const int n, + const int* indices, + const T* update, + T* output, + TensorDesc tensor_desc, + TensorDesc indice_desc) +{ + const int indice_cols = indice_desc.shape[indice_desc.dim - 1]; + const int copy_stride = tensor_desc.stride[indice_cols - 1]; + const int* stride = &(tensor_desc.stride[0]); + CUDA_1D_KERNEL_LOOP(index, n) + { + int output_offset = 0; + const int* indices_current = indices + index * indice_cols; + for (int i = 0; i < indice_cols; ++i) + { + output_offset += stride[i] * indices_current[i]; + } + memcpy(output + output_offset, update + index * copy_stride, copy_stride * sizeof(T)); } - memcpy(output + output_offset, update + index * copy_stride, copy_stride * sizeof(T)); - } } -template -void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update, - const int* dims, int nbDims, const int* indices_dims, - int indice_nbDims, T* output, cudaStream_t stream) { - // fill tensordesc and initial - TensorDesc tensor_desc; - memset((void*)&tensor_desc, 0, sizeof(TensorDesc)); - tensor_desc.dim = nbDims; - tensor_desc.shape[nbDims - 1] = dims[nbDims - 1]; - tensor_desc.stride[nbDims - 1] = 1; - for (int i = nbDims - 2; i >= 0; --i) { - tensor_desc.shape[i] = dims[i]; - tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1]; - } - const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0]; +template +void TRTONNXScatterNDKernelLauncher(const T* data, + const int* indices, + const T* update, + const int* dims, + int nbDims, + const int* indices_dims, + int indice_nbDims, + T* output, + cudaStream_t stream) +{ + // fill tensordesc and initial + TensorDesc tensor_desc; + memset((void*)&tensor_desc, 0, sizeof(TensorDesc)); + tensor_desc.dim = nbDims; + tensor_desc.shape[nbDims - 1] = dims[nbDims - 1]; + tensor_desc.stride[nbDims - 1] = 1; + for (int i = nbDims - 2; i >= 0; --i) + { + tensor_desc.shape[i] = dims[i]; + tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1]; + } + const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0]; - TensorDesc indice_desc; - memset((void*)&indice_desc, 0, sizeof(TensorDesc)); - indice_desc.dim = indice_nbDims; - indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1]; - indice_desc.stride[indice_nbDims - 1] = 1; - for (int i = indice_nbDims - 2; i >= 0; --i) { - indice_desc.shape[i] = indices_dims[i]; - indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1]; - } + TensorDesc indice_desc; + memset((void*)&indice_desc, 0, sizeof(TensorDesc)); + indice_desc.dim = indice_nbDims; + indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1]; + indice_desc.stride[indice_nbDims - 1] = 1; + for (int i = indice_nbDims - 2; i >= 0; --i) + { + indice_desc.shape[i] = indices_dims[i]; + indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1]; + } - // output = np.copy(data) - cudaMemcpyAsync(output, data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + // output = np.copy(data) + cudaMemcpyAsync(output, data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - int num_update_indice = 1; - for (int i = 0; i < indice_nbDims - 1; ++i) { - num_update_indice *= indice_desc.shape[i]; - } - // scatter - const int col_block = DIVUP(num_update_indice, THREADS_PER_BLOCK); - onnx_scatternd_kernel<<>>( - num_update_indice, indices, update, output, tensor_desc, 
indice_desc); + int num_update_indice = 1; + for (int i = 0; i < indice_nbDims - 1; ++i) + { + num_update_indice *= indice_desc.shape[i]; + } + // scatter + const int col_block = DIVUP(num_update_indice, THREADS_PER_BLOCK); + onnx_scatternd_kernel<<>>(num_update_indice, + indices, + update, + output, + tensor_desc, + indice_desc); } -template void TRTONNXScatterNDKernelLauncher(const float* data, const int* indices, - const float* update, const int* dims, - int nbDims, const int* indices_dims, - int indice_nbDims, float* output, +template void TRTONNXScatterNDKernelLauncher(const float* data, + const int* indices, + const float* update, + const int* dims, + int nbDims, + const int* indices_dims, + int indice_nbDims, + float* output, cudaStream_t stream); -template void TRTONNXScatterNDKernelLauncher(const int* data, const int* indices, - const int* update, const int* dims, int nbDims, - const int* indices_dims, int indice_nbDims, - int* output, cudaStream_t stream); +template void TRTONNXScatterNDKernelLauncher(const int* data, + const int* indices, + const int* update, + const int* dims, + int nbDims, + const int* indices_dims, + int indice_nbDims, + int* output, + cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp index b64b66494d..093ccda4f0 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp @@ -3,9 +3,15 @@ #define TRT_SCATTERND_KERNEL_HPP #include -template -void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update, - const int* dims, int nbDims, const int* indices_dims, - int indice_nbDims, T* output, cudaStream_t stream); +template +void TRTONNXScatterNDKernelLauncher(const T* data, + const int* indices, + const T* update, + const int* dims, + int nbDims, + const int* indices_dims, + int indice_nbDims, + T* output, + cudaStream_t stream); #endif // TRT_SCATTERND_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt b/csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt index 4a6120d0f8..91e0254570 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt @@ -1,41 +1,48 @@ # Copyright (c) OpenMMLab. All rights reserved. 
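The launcher above follows the ONNX ScatterND contract: copy data into output, then, for each row of indices, overwrite one copy_stride-sized slice with the matching slice of updates. A CPU sketch of the same rule, assuming precomputed row-major strides (illustrative names, not part of this patch):

#include <cstring>

// ONNX ScatterND reference (sketch). indices: [num_updates, indice_cols];
// each row selects a slice of `output` of length stride[indice_cols - 1].
void scatternd_ref(const float* data,
                   const int*   indices,
                   const float* updates,
                   const int*   stride,     // row-major strides of `data`
                   int          data_size,  // total element count of `data`
                   int          num_updates,
                   int          indice_cols,
                   float*       output)
{
    // output = np.copy(data)
    std::memcpy(output, data, data_size * sizeof(float));
    const int copy_stride = stride[indice_cols - 1];
    for (int i = 0; i < num_updates; ++i)
    {
        int offset = 0;
        for (int c = 0; c < indice_cols; ++c)
            offset += stride[c] * indices[i * indice_cols + c];
        std::memcpy(output + offset, updates + i * copy_stride, copy_stride * sizeof(float));
    }
}
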
if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES) - project(mmdeploy_torchscript_ops CUDA CXX) - file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp *.cu) + project(mmdeploy_torchscript_ops CUDA CXX) + file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp *.cu) else() - project(mmdeploy_torchscript_ops CXX) - file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp) + project(mmdeploy_torchscript_ops CXX) + file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp) endif() find_package(Torch REQUIRED) if(MSVC) - # workaround to fix building torchscript ops on windows - set(_TORCH_TARGET torch_cuda_cu torch_cuda_cpp torch_cpu) - foreach(_target IN LISTS _TORCH_TARGET) - if(TARGET ${_target}) - get_property(FIXED_TORCH_CPU_COMPILE_OPTIONS TARGET ${_target} PROPERTY INTERFACE_COMPILE_OPTIONS) - string(REPLACE ";" " " FIXED_TORCH_CPU_COMPILE_OPTIONS "${FIXED_TORCH_CPU_COMPILE_OPTIONS}") - set_property(TARGET ${_target} PROPERTY INTERFACE_COMPILE_OPTIONS -Xcompiler "${FIXED_TORCH_CPU_COMPILE_OPTIONS}") - else() - message(WARNING "Target ${_target} not found.") - endif() - endforeach() + # workaround to fix building torchscript ops on windows + set(_TORCH_TARGET torch_cuda_cu torch_cuda_cpp torch_cpu) + foreach(_target IN LISTS _TORCH_TARGET) + if(TARGET ${_target}) + get_property( + FIXED_TORCH_CPU_COMPILE_OPTIONS + TARGET ${_target} + PROPERTY INTERFACE_COMPILE_OPTIONS) + string(REPLACE ";" " " FIXED_TORCH_CPU_COMPILE_OPTIONS + "${FIXED_TORCH_CPU_COMPILE_OPTIONS}") + set_property( + TARGET ${_target} PROPERTY INTERFACE_COMPILE_OPTIONS -Xcompiler + "${FIXED_TORCH_CPU_COMPILE_OPTIONS}") + else() + message(WARNING "Target ${_target} not found.") + endif() + endforeach() endif() add_library(${PROJECT_NAME}_obj OBJECT "${BACKEND_OPS_SRCS}") -set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE 1) +set_target_properties(${PROJECT_NAME}_obj PROPERTIES POSITION_INDEPENDENT_CODE + 1) target_compile_definitions(${PROJECT_NAME}_obj - PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1) + PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1) target_include_directories(${PROJECT_NAME}_obj - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../common) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../common) target_include_directories(${PROJECT_NAME}_obj - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common) + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common) if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES) - target_include_directories(${PROJECT_NAME}_obj - PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include) + target_include_directories(${PROJECT_NAME}_obj + PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include) endif() target_link_libraries(${PROJECT_NAME}_obj PRIVATE ${TORCH_LIBRARIES}) mmdeploy_export(${PROJECT_NAME}_obj) diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp b/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp index f236ac9b66..777b2b1eed 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp @@ -1,13 +1,14 @@ // Copyright (c) OpenMMLab. All rights reserved. 
#include "torch/script.h" -TORCH_LIBRARY(mmdeploy, m) { - m.def( - "modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor " - "mask, " - "int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int " - "dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor") - .def( - "coreml_nms(Tensor boxes, Tensor scores, float iou_threshold, " - "float score_threshold, int max_boxes) -> Tensor[]"); +TORCH_LIBRARY(mmdeploy, m) +{ + m.def( + "modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor " + "mask, " + "int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int " + "dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor") + .def( + "coreml_nms(Tensor boxes, Tensor scores, float iou_threshold, " + "float score_threshold, int max_boxes) -> Tensor[]"); } diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp b/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp index a78b701349..f83a0ec313 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp @@ -4,28 +4,36 @@ #include #include "torch/script.h" -namespace mmdeploy { - -using at::Tensor; - -std::vector<Tensor> coreml_nms_cpu(Tensor boxes, Tensor scores, double iou_threshold, - double score_threshold, int64_t max_boxes) { - assert(boxes.dim() == 3); // bboxes with shape (batch_size, num_bboxes, 4) - assert(boxes.size(2) == 4); - assert(boxes.size(0) == scores.size(0)); // check batch size - assert(boxes.size(1) == scores.size(1)); // check num boxes - - auto batch_size = boxes.size(0); - auto num_boxes = boxes.size(1); - auto num_classes = scores.size(2); - - Tensor ret_boxes = at::zeros({batch_size, max_boxes, 4}); - Tensor ret_scores = at::zeros({batch_size, max_boxes, num_classes}); - Tensor indices = at::zeros({batch_size, max_boxes}, at::kInt); - Tensor num_outputs = at::zeros({batch_size}, at::kInt); - - return std::vector<Tensor>({ret_boxes, ret_scores, indices, num_outputs}); -} - -TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) { m.impl("coreml_nms", coreml_nms_cpu); } +namespace mmdeploy +{ + + using at::Tensor; + + std::vector<Tensor> coreml_nms_cpu(Tensor boxes, + Tensor scores, + double iou_threshold, + double score_threshold, + int64_t max_boxes) + { + assert(boxes.dim() == 3); // bboxes with shape (batch_size, num_bboxes, 4) + assert(boxes.size(2) == 4); + assert(boxes.size(0) == scores.size(0)); // check batch size + assert(boxes.size(1) == scores.size(1)); // check num boxes + + auto batch_size = boxes.size(0); + auto num_boxes = boxes.size(1); + auto num_classes = scores.size(2); + + Tensor ret_boxes = at::zeros({batch_size, max_boxes, 4}); + Tensor ret_scores = at::zeros({batch_size, max_boxes, num_classes}); + Tensor indices = at::zeros({batch_size, max_boxes}, at::kInt); + Tensor num_outputs = at::zeros({batch_size}, at::kInt); + + return std::vector<Tensor>({ret_boxes, ret_scores, indices, num_outputs}); + } + + TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) + { + m.impl("coreml_nms", coreml_nms_cpu); + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp index c6d980919f..3a9b32e83b 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp
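The deform-conv forward passes in the files below size their output with the standard dilated-convolution formula; a small hypothetical helper (not from this PR) makes the arithmetic explicit:

#include <cstdint>

// Output extent of one spatial axis: the effective kernel span grows with
// dilation to dilation * (kernel - 1) + 1, then the usual pad/stride rule applies.
inline int64_t conv_out_size(int64_t in, int64_t pad, int64_t dilation, int64_t kernel, int64_t stride)
{
    return (in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1;
}

// Example: a 7-wide axis, 3-wide kernel, pad 1, stride 2, dilation 1 gives
// conv_out_size(7, 1, 1, 3, 2) == 4, matching height_out/width_out below.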
+++ b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp @@ -3,92 +3,133 @@ #include "torch/script.h" -namespace mmdeploy { - -void modulated_deformable_im2col_cpu( - const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, - const int64_t batch_size, const int64_t channels, const int64_t height_im, - const int64_t width_im, const int64_t height_col, const int64_t width_col, - const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h, - const int64_t dilation_w, int64_t deformable_group, at::Tensor data_col) { - // num_axes should be smaller than block size - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_im.scalar_type(), "modulated_deformable_im2col_cpu", ([&] { - const scalar_t *data_im_ = data_im.data_ptr<scalar_t>(); - const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>(); - const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>(); - scalar_t *data_col_ = data_col.data_ptr<scalar_t>(); - - deformable_im2col_2d(data_im_, data_offset_, data_mask_, height_im, width_im, - kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, channels, deformable_group, - height_col, width_col, data_mask_ != nullptr, data_col_); - })); -} - -at::Tensor modulated_deform_conv_forward_cpu(at::Tensor input, at::Tensor weight, at::Tensor bias, - at::Tensor offset, at::Tensor mask, int64_t kernel_h, - int64_t kernel_w, int64_t stride_h, int64_t stride_w, - int64_t pad_h, int64_t pad_w, int64_t dilation_h, - int64_t dilation_w, int64_t group, - int64_t deformable_group, bool with_bias) { - at::DeviceGuard guard(input.device()); - - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); - - const int channels_out = weight.size(0); - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); - - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, - kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, - channels_kernel * group); - - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - // resize output - at::Tensor output = - at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options()); - // resize temporary columns - at::Tensor columns = at::zeros( - {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, input.options()); - - // divide into group - weight = - weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); - for (int b = 0; b < batch; b++) { - modulated_deformable_im2col_cpu(input[b], offset[b], mask[b], 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, - stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns); - - for (int g = 0; g < group; g++) { - output[b][g] = - output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); +namespace mmdeploy +{ + + void modulated_deformable_im2col_cpu(const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int64_t 
batch_size, + const int64_t channels, + const int64_t height_im, + const int64_t width_im, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + int64_t deformable_group, + at::Tensor data_col) + { + // num_axes should be smaller than block size + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(data_im.scalar_type(), + "modulated_deformable_im2col_cpu", + ([&] + { + const scalar_t* data_im_ = data_im.data_ptr<scalar_t>(); + const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>(); + const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>(); + scalar_t* data_col_ = data_col.data_ptr<scalar_t>(); + + deformable_im2col_2d(data_im_, + data_offset_, + data_mask_, + height_im, + width_im, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channels, + deformable_group, + height_col, + width_col, + data_mask_ != nullptr, + data_col_); })); } - } - output = output.view( - {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); - - if (with_bias) { - output += bias.view({1, bias.size(0), 1, 1}); - } - - return output; -} + at::Tensor modulated_deform_conv_forward_cpu(at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor offset, + at::Tensor mask, + int64_t kernel_h, + int64_t kernel_w, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t group, + int64_t deformable_group, + bool with_bias) + { + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel * group); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + // resize output + at::Tensor output = + at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options()); + // resize temporary columns + at::Tensor columns = at::zeros( + {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, + input.options()); + + // divide into group + weight = + weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); + for (int b = 0; b < batch; b++) + { + modulated_deformable_im2col_cpu(input[b], offset[b], mask[b], 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, columns); + + for (int g = 0; g < group; g++) + { + output[b][g] = + output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); + } + } + + output = output.view( + {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); + + if (with_bias) + { + output += bias.view({1, bias.size(0), 1, 1}); + } + + 
return output; + } -TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) { - m.impl("modulated_deform_conv", modulated_deform_conv_forward_cpu); -} + TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) + { + m.impl("modulated_deform_conv", modulated_deform_conv_forward_cpu); + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu index 3f9b6aef08..53cb5fd65c 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu @@ -3,95 +3,157 @@ #include "modulated_deform_conv/modulated_deform_conv_cuda.cuh" #include "torch/script.h" -namespace mmdeploy { - -void modulated_deformable_im2col_cuda( - const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, - const int64_t batch_size, const int64_t channels, const int64_t height_im, - const int64_t width_im, const int64_t height_col, const int64_t width_col, - const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h, - const int64_t dilation_w, const int64_t deformable_group, at::Tensor data_col) { - // num_axes should be smaller than block size - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * batch_size * height_col * width_col; - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_im.scalar_type(), "modulated_deformable_im2col_cuda", ([&] { - const scalar_t *data_im_ = data_im.data_ptr(); - const scalar_t *data_offset_ = data_offset.data_ptr(); - const scalar_t *data_mask_ = data_mask.data_ptr(); - scalar_t *data_col_ = data_col.data_ptr(); - modulated_deformable_im2col_gpu_kernel - <<>>( - num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, - kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - channel_per_deformable_group, batch_size, channels, deformable_group, height_col, - width_col, data_col_); - })); -} - -at::Tensor modulated_deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, at::Tensor bias, - at::Tensor offset, at::Tensor mask, int64_t kernel_h, - int64_t kernel_w, int64_t stride_h, int64_t stride_w, - int64_t pad_h, int64_t pad_w, int64_t dilation_h, - int64_t dilation_w, int64_t group, - int64_t deformable_group, bool with_bias) { - at::DeviceGuard guard(input.device()); - - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); - - const int channels_out = weight.size(0); - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); - - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, - kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, - channels_kernel * group); - - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - // resize output - at::Tensor output = - at::zeros({batch, group, channels_out / group, height_out, width_out}, 
input.options()); - // resize temporary columns - at::Tensor columns = at::zeros( - {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, input.options()); - - // divide into group - weight = - weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); - for (int b = 0; b < batch; b++) { - modulated_deformable_im2col_cuda(input[b], offset[b], mask[b], 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, - stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns); - - for (int g = 0; g < group; g++) { - output[b][g] = - output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); +namespace mmdeploy +{ + + void modulated_deformable_im2col_cuda(const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int64_t batch_size, + const int64_t channels, + const int64_t height_im, + const int64_t width_im, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t deformable_group, + at::Tensor data_col) + { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), + "modulated_deformable_im2col_cuda", + ([&] + { + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + const scalar_t* data_mask_ = data_mask.data_ptr(); + scalar_t* data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel + <<>>(num_kernels, + data_im_, + data_offset_, + data_mask_, + height_im, + width_im, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + channels, + deformable_group, + height_col, + width_col, + data_col_); })); } - } - output = output.view( - {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); + at::Tensor modulated_deform_conv_forward_cuda(at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor offset, + at::Tensor mask, + int64_t kernel_h, + int64_t kernel_w, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t group, + int64_t deformable_group, + bool with_bias) + { + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel * group); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - if (with_bias) { - 
output += bias.view({1, bias.size(0), 1, 1}); - } + // resize output + at::Tensor output = + at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options()); + // resize temporary columns + at::Tensor columns = at::zeros( + {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, + input.options()); - return output; -} + // divide into group + weight = + weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); + for (int b = 0; b < batch; b++) + { + modulated_deformable_im2col_cuda(input[b], + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + columns); -TORCH_LIBRARY_IMPL(mmdeploy, CUDA, m) { - m.impl("modulated_deform_conv", modulated_deform_conv_forward_cuda); -} + for (int g = 0; g < group; g++) + { + output[b][g] = + output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); + } + } + + output = output.view( + {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); + + if (with_bias) + { + output += bias.view({1, bias.size(0), 1, 1}); + } + + return output; + } + + TORCH_LIBRARY_IMPL(mmdeploy, CUDA, m) + { + m.impl("modulated_deform_conv", modulated_deform_conv_forward_cuda); + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/CMakeLists.txt b/csrc/mmdeploy/backend_ops/torchscript/optimizer/CMakeLists.txt index 1b5e75ccca..c528972177 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/CMakeLists.txt @@ -3,16 +3,18 @@ project(ts_optimizer) find_package(Torch REQUIRED) -find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") -if (NOT TARGET pybind11) - add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) -endif () +find_library(TORCH_PYTHON_LIBRARY torch_python + PATHS "${TORCH_INSTALL_PREFIX}/lib") +if(NOT TARGET pybind11) + add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) +endif() file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp) pybind11_add_module(${PROJECT_NAME} ${OPTIMIZER_SRCS}) -target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} + ${TORCH_PYTHON_LIBRARY}) target_link_directories(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops) set_target_properties( - ${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY - ${CMAKE_SOURCE_DIR}/mmdeploy/backend/torchscript) + ${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY + ${CMAKE_SOURCE_DIR}/mmdeploy/backend/torchscript) diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp index 3b8bb0f632..49d3d8930a 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp @@ -12,36 +12,45 @@ #include "passes/onnx/merge_shape_concate.h" #include "passes/onnx/onnx_peephole.h" -namespace mmdeploy { -namespace torch_jit { +namespace mmdeploy +{ + namespace torch_jit + { -void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript", - const std::string& backend = "torchscript") { - if (ir == "torchscript") { - model = optimize_for_torchscript(model); - } else if (ir == "onnx") { - model = optimize_for_onnx(model); - } else { - 
fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(), - backend.c_str()); - exit(-1); - } -} + void optimize_for_backend(torch::jit::Module& model, + const std::string& ir = "torchscript", + const std::string& backend = "torchscript") + { + if (ir == "torchscript") + { + model = optimize_for_torchscript(model); + } + else if (ir == "onnx") + { + model = optimize_for_onnx(model); + } + else + { + fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(), backend.c_str()); + exit(-1); + } + } -PYBIND11_MODULE(ts_optimizer, m) { - namespace py = pybind11; - m.def("optimize_for_backend", optimize_for_backend, py::arg("module"), - py::arg("ir") = std::string("torchscript"), - py::arg("backend") = std::string("torchscript")); - py::module_ onnx_module = m.def_submodule("onnx"); - onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph")); - onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph")); - onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); - onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"), - py::arg("params")); - onnx_module.def("_jit_pass_common_subgraph_elimination", CommonSubgraphElimination, - py::arg("graph"), py::arg("params")); -} + PYBIND11_MODULE(ts_optimizer, m) + { + namespace py = pybind11; + m.def("optimize_for_backend", + optimize_for_backend, + py::arg("module"), + py::arg("ir") = std::string("torchscript"), + py::arg("backend") = std::string("torchscript")); + py::module_ onnx_module = m.def_submodule("onnx"); + onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph")); + onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph")); + onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); + onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"), py::arg("params")); + onnx_module.def("_jit_pass_common_subgraph_elimination", CommonSubgraphElimination, py::arg("graph"), py::arg("params")); + } -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp index 10ce9829d5..4834f1d3d5 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp @@ -8,306 +8,399 @@ #include #include -namespace mmdeploy { -namespace torch_jit { - -using torch::jit::AttributeKind; -using torch::jit::ClassType; -using torch::jit::Node; -using torch::jit::Symbol; -using torch::jit::Value; - -namespace prim { -using namespace ::c10::prim; -} - -namespace attr { -using namespace ::c10::attr; -} - -/** - * \brief A class implementing an API for comparing subgraphs. - */ -class SubgraphMatcher::SubgraphMatcherImpl { - public: - explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute) - : pattern_(pattern), match_attribute_(match_attribute) {} - - /** - * \brief Compare matchGraph with the part of the graph denoted by a node \p - * ANCHOR. - * - * The anchor node would be compared against the deepest node in the - * match-graph. A node is considered matching if its number of inputs/outputs - * is the same as in the corresponding matchGraph node, its type is the same, - * and all nodes producing input-values also match. 
- */ - bool matchesSubgraphFromAnchorNode(Node* anchor); - - /** \brief Return match map for nodes. */ - std::unordered_map nodes_map() const { return nodes_map_; } - - /** \brief Return match map for values. */ - std::unordered_map values_map() const { return values_map_; } - - private: - bool matchValues(const Value* v1, Value* v2); - bool matchNodes(const Node* n1, Node* n2); - bool matchAttributes(const Node* n1, Node* n2); - - static bool isInput(const Value* v); - static bool isOutput(const Value* v); - - std::unordered_map nodes_map_; - std::unordered_map values_map_; - - const MatchAttribute match_attribute_; - const Graph& pattern_; - const Node* anchor_ = nullptr; -}; - -bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) { - return v->node()->kind() == prim::Param; -} - -bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) { - for (const Value* output : v->owningGraph()->outputs()) { - if (v == output) { - return true; - } - } - return false; -} - -/** - * Compare two Values. V1 is from pattern, V2 is from the actual graph. - * - * The values are considered matching if: - * 1) the nodes defining them match - * 2) they have the same number of uses, except they are entry or exit nodes. - */ -bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) { - // Check if we've already visited these values. - if (values_map_.count(v1)) { - if (values_map_.at(v1) != v2) { - GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), - " did not match because %", v1->debugName(), " has already been matched with %", - values_map_.at(v1)->debugName(), ".\n"); - return false; - } - return true; - } - - // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is - // PARAM, we're comparing entering values - in these two cases the number of - // uses don't need to be the same. - if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) { - GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), - " did not match because number of their uses is different.\n"); - return false; - } - - // Add the values to the map before calling matchNodes to avoid infinite - // recursion. 
- GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n"); - values_map_[v1] = v2; - return matchNodes(v1->node(), v2->node()); -} - -bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) { - if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) { - GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2); - return false; - } - for (const Symbol& attr_name : n1->attributeNames()) { - if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) { - GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), - "' did not match:\n", *n1, *n2); - return false; - } - std::vector n1is, n2is; - std::vector n1fs, n2fs; - switch (n1->kindOf(attr_name)) { - case AttributeKind::s: - if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) { - GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), - "' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1, - *n2); - return false; +namespace mmdeploy +{ + namespace torch_jit + { + + using torch::jit::AttributeKind; + using torch::jit::ClassType; + using torch::jit::Node; + using torch::jit::Symbol; + using torch::jit::Value; + + namespace prim + { + using namespace ::c10::prim; } - break; - case AttributeKind::f: - if (n1->f(attr_name) != n2->f(attr_name)) { - GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), - "' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1, - *n2); - return false; + + namespace attr + { + using namespace ::c10::attr; + } + + /** + * \brief A class implementing an API for comparing subgraphs. + */ + class SubgraphMatcher::SubgraphMatcherImpl + { + public: + explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute) + : pattern_(pattern) + , match_attribute_(match_attribute) + { + } + + /** + * \brief Compare matchGraph with the part of the graph denoted by a node \p + * ANCHOR. + * + * The anchor node would be compared against the deepest node in the + * match-graph. A node is considered matching if its number of inputs/outputs + * is the same as in the corresponding matchGraph node, its type is the same, + * and all nodes producing input-values also match. + */ + bool matchesSubgraphFromAnchorNode(Node* anchor); + + /** \brief Return match map for nodes. */ + std::unordered_map nodes_map() const + { + return nodes_map_; + } + + /** \brief Return match map for values. */ + std::unordered_map values_map() const + { + return values_map_; + } + + private: + bool matchValues(const Value* v1, Value* v2); + bool matchNodes(const Node* n1, Node* n2); + bool matchAttributes(const Node* n1, Node* n2); + + static bool isInput(const Value* v); + static bool isOutput(const Value* v); + + std::unordered_map nodes_map_; + std::unordered_map values_map_; + + const MatchAttribute match_attribute_; + const Graph& pattern_; + const Node* anchor_ = nullptr; + }; + + bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) + { + return v->node()->kind() == prim::Param; + } + + bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) + { + for (const Value* output : v->owningGraph()->outputs()) + { + if (v == output) + { + return true; + } + } + return false; + } + + /** + * Compare two Values. V1 is from pattern, V2 is from the actual graph. 
+ * + * The values are considered matching if: + * 1) the nodes defining them match + * 2) they have the same number of uses, except they are entry or exit nodes. + */ + bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) + { + // Check if we've already visited these values. + if (values_map_.count(v1)) + { + if (values_map_.at(v1) != v2) + { + GRAPH_DEBUG("Values %", + v1->debugName(), + " and %", + v2->debugName(), + " did not match because %", + v1->debugName(), + " has already been matched with %", + values_map_.at(v1)->debugName(), + ".\n"); + return false; + } + return true; + } + + // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is + // PARAM, we're comparing entering values - in these two cases the number of + // uses don't need to be the same. + if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) + { + GRAPH_DEBUG("Values %", + v1->debugName(), + " and %", + v2->debugName(), + " did not match because number of their uses is different.\n"); + return false; + } + + // Add the values to the map before calling matchNodes to avoid infinite + // recursion. + GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n"); + values_map_[v1] = v2; + return matchNodes(v1->node(), v2->node()); + } + + bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) + { + if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) + { + GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2); + return false; + } + for (const Symbol& attr_name : n1->attributeNames()) + { + if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) + { + GRAPH_DEBUG("Nodes did not match because type of attribute '", + attr_name.toQualString(), + "' did not match:\n", + *n1, + *n2); + return false; + } + std::vector n1is, n2is; + std::vector n1fs, n2fs; + switch (n1->kindOf(attr_name)) + { + case AttributeKind::s: + if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) + { + GRAPH_DEBUG("Nodes did not match because attribute '", + attr_name.toQualString(), + "' did not match: ", + n1->s(attr_name), + " != ", + n2->s(attr_name), + " \n", + *n1, + *n2); + return false; + } + break; + case AttributeKind::f: + if (n1->f(attr_name) != n2->f(attr_name)) + { + GRAPH_DEBUG("Nodes did not match because attribute '", + attr_name.toQualString(), + "' did not match:", + n1->f(attr_name), + " != ", + n2->f(attr_name), + " \n", + *n1, + *n2); + return false; + } + break; + case AttributeKind::i: + if (n1->i(attr_name) != n2->i(attr_name)) + { + GRAPH_DEBUG("Nodes did not match because attribute '", + attr_name.toQualString(), + "' did not match:", + n1->i(attr_name), + " != ", + n2->i(attr_name), + " \n", + *n1, + *n2); + return false; + } + break; + case AttributeKind::is: + n1is = n1->is(attr_name); + n2is = n2->is(attr_name); + if (n1is.size() != n2is.size()) return false; + for (size_t i = 0; i < n1is.size(); ++i) + { + if (n1is[i] != n2is[i]) return false; + } + break; + case AttributeKind::fs: + n1fs = n1->fs(attr_name); + n2fs = n2->fs(attr_name); + if (n1fs.size() != n2fs.size()) return false; + for (size_t i = 0; i < n1fs.size(); ++i) + { + if (n1fs[i] != n2fs[i]) return false; + } + break; + default: + { + // Other attributes types not supported yet + GRAPH_DEBUG("Nodes did not match because type of attribute '", + attr_name.toQualString(), + "' is not supported.\n", + *n1, + *n2); + return false; + } + } + } + return true; + } + + static bool 
endsWith(const std::string& str, const std::string& suffix) + { + return str.size() >= suffix.size() && + 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } - break; - case AttributeKind::i: - if (n1->i(attr_name) != n2->i(attr_name)) { - GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), - "' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1, - *n2); - return false; + + /** + * Compare two Nodes. N1 is from pattern, N2 is from the actual graph. + * + * The nodes are considered matching if: + * 1) N1 and N2 are of the same kind. + * 2) Number of inputs and outputs is the same. + * 3) All input and output values match. + * + * A special case is when N1 is PARAM - this is considered outside the pattern, + * so it matches everything. + */ + bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) + { + // Check if we've already visited these nodes. + if (nodes_map_.count(n1)) + { + return nodes_map_.at(n1) == n2; + } + + // Param node in pattern graph matches everything. + if (n1->kind() == prim::Param) + { + GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); + return true; + } + + // We don't allow matches to span across blocks, so check if N2 is in the same + // block as the first (anchor) node. + if (n2->owningBlock() != anchor_->owningBlock()) + { + GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2); + return false; + } + + // Special handling for matching modules + if (n1->kind() == Symbol::fromQualString("match::module")) + { + if (n2->kind() == prim::GetAttr) + { + if (!n1->hasAttributeS("name")) + { + GRAPH_DEBUG( + "Nodes did not match because special node match::module does not have 'name' " + "attribute:\n", + *n1, + *n2); + return false; + } + auto t = n2->output()->type()->expect(); + auto real_typename = t->name()->qualifiedName(); + auto pattern_typename = n1->s(attr::name); + if (!endsWith(real_typename, pattern_typename)) + { + GRAPH_DEBUG("Nodes did not match because expected module type is different:\n"); + GRAPH_DEBUG(" actualtype: ", real_typename, "\n"); + GRAPH_DEBUG(" expected type: ", pattern_typename, "\n"); + GRAPH_DEBUG("Nodes:", *n1, *n2); + return false; + } + } + } + else + { + if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() || + n1->inputs().size() != n2->inputs().size()) + { + GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2); + return false; + } + + if (match_attribute_ != NO_MATCH) + { + if (!matchAttributes(n1, n2)) + { + return false; + } + } + } + + // Add nodes to the map before calling matchValues to avoid infinite + // recursion. + nodes_map_[n1] = n2; + for (const auto i : c10::irange(n1->outputs().size())) + { + if (!matchValues(n1->outputs()[i], n2->outputs()[i])) + { + return false; + } + } + for (const auto i : c10::irange(n1->inputs().size())) + { + if (!matchValues(n1->inputs()[i], n2->inputs()[i])) + { + return false; + } + } + + GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); + return true; + } + + /** + * Recursively try to match pattern with the actual graph starting from the + * exiting node in the pattern and anchor node in the actual graph. 
+ */ + bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) + { + GRAPH_UPDATE("Starting match from a new anchor: ", *anchor); + nodes_map_.clear(); + values_map_.clear(); + anchor_ = anchor; + + const Node* bottom_node = *(pattern_.nodes().end()); + bottom_node = bottom_node->input(0)->node(); + + if (!matchNodes(bottom_node, anchor)) + { + return false; + } + + for (const Value* output : pattern_.outputs()) + { + AT_ASSERT(values_map_.count(output)); + } + + GRAPH_UPDATE("Pattern matched!\n"); + return true; } - break; - case AttributeKind::is: - n1is = n1->is(attr_name); - n2is = n2->is(attr_name); - if (n1is.size() != n2is.size()) return false; - for (size_t i = 0; i < n1is.size(); ++i) { - if (n1is[i] != n2is[i]) return false; + + SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute) + : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) + { } - break; - case AttributeKind::fs: - n1fs = n1->fs(attr_name); - n2fs = n2->fs(attr_name); - if (n1fs.size() != n2fs.size()) return false; - for (size_t i = 0; i < n1fs.size(); ++i) { - if (n1fs[i] != n2fs[i]) return false; + + SubgraphMatcher::~SubgraphMatcher() = default; + + bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) + { + return impl_->matchesSubgraphFromAnchorNode(anchor); } - break; - default: { - // Other attributes types not supported yet - GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), - "' is not supported.\n", *n1, *n2); - return false; - } - } - } - return true; -} - -static bool endsWith(const std::string& str, const std::string& suffix) { - return str.size() >= suffix.size() && - 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); -} - -/** - * Compare two Nodes. N1 is from pattern, N2 is from the actual graph. - * - * The nodes are considered matching if: - * 1) N1 and N2 are of the same kind. - * 2) Number of inputs and outputs is the same. - * 3) All input and output values match. - * - * A special case is when N1 is PARAM - this is considered outside the pattern, - * so it matches everything. - */ -bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) { - // Check if we've already visited these nodes. - if (nodes_map_.count(n1)) { - return nodes_map_.at(n1) == n2; - } - - // Param node in pattern graph matches everything. - if (n1->kind() == prim::Param) { - GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); - return true; - } - - // We don't allow matches to span across blocks, so check if N2 is in the same - // block as the first (anchor) node. 
- if (n2->owningBlock() != anchor_->owningBlock()) { - GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2); - return false; - } - - // Special handling for matching modules - if (n1->kind() == Symbol::fromQualString("match::module")) { - if (n2->kind() == prim::GetAttr) { - if (!n1->hasAttributeS("name")) { - GRAPH_DEBUG( - "Nodes did not match because special node match::module does not have 'name' " - "attribute:\n", - *n1, *n2); - return false; - } - auto t = n2->output()->type()->expect(); - auto real_typename = t->name()->qualifiedName(); - auto pattern_typename = n1->s(attr::name); - if (!endsWith(real_typename, pattern_typename)) { - GRAPH_DEBUG("Nodes did not match because expected module type is different:\n"); - GRAPH_DEBUG(" actualtype: ", real_typename, "\n"); - GRAPH_DEBUG(" expected type: ", pattern_typename, "\n"); - GRAPH_DEBUG("Nodes:", *n1, *n2); - return false; - } - } - } else { - if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() || - n1->inputs().size() != n2->inputs().size()) { - GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2); - return false; - } - - if (match_attribute_ != NO_MATCH) { - if (!matchAttributes(n1, n2)) { - return false; - } - } - } - - // Add nodes to the map before calling matchValues to avoid infinite - // recursion. - nodes_map_[n1] = n2; - for (const auto i : c10::irange(n1->outputs().size())) { - if (!matchValues(n1->outputs()[i], n2->outputs()[i])) { - return false; - } - } - for (const auto i : c10::irange(n1->inputs().size())) { - if (!matchValues(n1->inputs()[i], n2->inputs()[i])) { - return false; - } - } - - GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); - return true; -} - -/** - * Recursively try to match pattern with the actual graph starting from the - * exiting node in the pattern and anchor node in the actual graph. 
- */ -bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) { - GRAPH_UPDATE("Starting match from a new anchor: ", *anchor); - nodes_map_.clear(); - values_map_.clear(); - anchor_ = anchor; - - const Node* bottom_node = *(pattern_.nodes().end()); - bottom_node = bottom_node->input(0)->node(); - - if (!matchNodes(bottom_node, anchor)) { - return false; - } - - for (const Value* output : pattern_.outputs()) { - AT_ASSERT(values_map_.count(output)); - } - - GRAPH_UPDATE("Pattern matched!\n"); - return true; -} - -SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute) - : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {} - -SubgraphMatcher::~SubgraphMatcher() = default; - -bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) { - return impl_->matchesSubgraphFromAnchorNode(anchor); -} - -std::unordered_map<const Node*, Node*> SubgraphMatcher::nodes_map() const { - return impl_->nodes_map(); -} - -std::unordered_map<const Value*, Value*> SubgraphMatcher::values_map() const { - return impl_->values_map(); -} - -} // namespace torch_jit + + std::unordered_map<const Node*, Node*> SubgraphMatcher::nodes_map() const + { + return impl_->nodes_map(); + } + + std::unordered_map<const Value*, Value*> SubgraphMatcher::values_map() const + { + return impl_->values_map(); + } + + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h index e2488e252c..ffe1b51aa8 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h @@ -5,34 +5,42 @@ #include #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; -using torch::jit::Node; -using torch::jit::Value; - -enum MatchAttribute { FORCE_MATCH, TRY_MATCH, NO_MATCH }; - -class SubgraphMatcher { - public: - explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH); - - ~SubgraphMatcher(); - - bool matchesSubgraphFromAnchorNode(Node* anchor); - - /** \brief Return match map for nodes. */ - std::unordered_map<const Node*, Node*> nodes_map() const; - - /** \brief Return match map for values. */ - std::unordered_map<const Value*, Value*> values_map() const; - - private: - class SubgraphMatcherImpl; - std::unique_ptr<SubgraphMatcherImpl> impl_; -}; - -} // namespace torch_jit +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; + using torch::jit::Node; + using torch::jit::Value; + + enum MatchAttribute + { + FORCE_MATCH, + TRY_MATCH, + NO_MATCH + }; + + class SubgraphMatcher + { + public: + explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH); + + ~SubgraphMatcher(); + + bool matchesSubgraphFromAnchorNode(Node* anchor); + + /** \brief Return match map for nodes. */ + std::unordered_map<const Node*, Node*> nodes_map() const; + + /** \brief Return match map for values. 
*/ + std::unordered_map<const Value*, Value*> values_map() const; + + private: + class SubgraphMatcherImpl; + std::unique_ptr<SubgraphMatcherImpl> impl_; + }; + + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp index 05ef9d54cd..2178bb3a4e 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp @@ -12,59 +12,63 @@ #include #if TORCH_VERSION_MINOR >= 9 -#include -#include -#include + #include + #include + #include #endif -namespace mmdeploy { +namespace mmdeploy +{ -using torch::jit::Graph; -const std::shared_ptr<Graph>& required_passes(const std::shared_ptr<Graph>& graph) { - RemoveExpands(graph); - CanonicalizeOps(graph); - EliminateDeadCode(graph); - return graph; -} + using torch::jit::Graph; + const std::shared_ptr<Graph>& required_passes(const std::shared_ptr<Graph>& graph) + { + RemoveExpands(graph); + CanonicalizeOps(graph); + EliminateDeadCode(graph); + return graph; + } -Module optimize_for_torchscript(const Module& model) { - auto frozen_model = freeze_module(model); - auto graph = frozen_model.get_method("forward").graph(); - OptimizeFrozenGraph(graph, true); + Module optimize_for_torchscript(const Module& model) + { + auto frozen_model = freeze_module(model); + auto graph = frozen_model.get_method("forward").graph(); + OptimizeFrozenGraph(graph, true); #if TORCH_VERSION_MINOR >= 9 - FuseFrozenConvAddRelu(graph); - ConvertFrozenOpsToMKLDNN(graph); - FrozenLinearTranspose(graph); + FuseFrozenConvAddRelu(graph); + ConvertFrozenOpsToMKLDNN(graph); + FrozenLinearTranspose(graph); #endif - graph = required_passes(graph); - EliminateCommonSubexpression(graph); - PeepholeOptimize(graph); - ConstantPropagation(graph); - ConstantPooling(graph); + graph = required_passes(graph); + EliminateCommonSubexpression(graph); + PeepholeOptimize(graph); + ConstantPropagation(graph); + ConstantPooling(graph); - // TODO: add more custom passes + // TODO: add more custom passes - return frozen_model; -} + return frozen_model; + } -Module optimize_for_onnx(const Module& model) { - auto frozen_model = freeze_module(model, {"training"}); - auto graph = frozen_model.get_method("forward").graph(); - OptimizeFrozenGraph(graph, true); + Module optimize_for_onnx(const Module& model) + { + auto frozen_model = freeze_module(model, {"training"}); + auto graph = frozen_model.get_method("forward").graph(); + OptimizeFrozenGraph(graph, true); #if TORCH_VERSION_MINOR >= 9 - FuseFrozenConvAddRelu(graph); - ConvertFrozenOpsToMKLDNN(graph); - FrozenLinearTranspose(graph); + FuseFrozenConvAddRelu(graph); + ConvertFrozenOpsToMKLDNN(graph); + FrozenLinearTranspose(graph); #endif - // TODO: add more custom passes + // TODO: add more custom passes - return frozen_model; -} + return frozen_model; + } -// TODO: add optimizer for other backend/onnx + // TODO: add optimizer for other backend/onnx } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h index d0d91c627d..fc5a3725d1 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h @@ -1,10 +1,11 @@ // Copyright (c) OpenMMLab. All rights reserved.
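optimize_for_torchscript above freezes the module first so the later constant-folding passes see parameters as constants. A hedged caller-side sketch using the public torch::jit C++ API; load_and_optimize is a hypothetical name, not an mmdeploy function:

#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/script.h>

torch::jit::Module load_and_optimize(const std::string& path)
{
    torch::jit::Module model = torch::jit::load(path);
    model.eval();  // freezing expects eval mode so training-only state can be folded
    // freeze_module inlines submodules and turns parameters/attributes into
    // constants, which is what lets ConstantPropagation/ConstantPooling above fire.
    torch::jit::Module frozen = torch::jit::freeze_module(model);
    return frozen;
}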
#include -namespace mmdeploy { -using torch::jit::script::Module; +namespace mmdeploy +{ + using torch::jit::script::Module; -Module optimize_for_torchscript(const Module &model); + Module optimize_for_torchscript(const Module& model); -Module optimize_for_onnx(const Module &model); + Module optimize_for_onnx(const Module& model); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp index c6541e630a..c26db5a34f 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp @@ -4,135 +4,161 @@ #include #include -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::EqualNode; -using torch::jit::HashNode; -using torch::jit::Node; -using torch::jit::Value; - -struct EqualNodeWithParams { - EqualNodeWithParams(std::unordered_map& params) : params_(params) {} - - bool operator()(const Node* lhs, const Node* rhs) const { - auto lhs_inputs = lhs->inputs(); - auto rhs_inputs = rhs->inputs(); - } - - private: - std::unordered_map& params_; -}; - -struct CommonSubexpressionEliminator { - using ParamMapType = std::unordered_map>; - CommonSubexpressionEliminator(std::shared_ptr graph, - std::unordered_map& params) - : graph_(std::move(graph)), params_(params) {} - - bool run(std::function parent_lookup_fn) { - ParamMapType param_map; - return run(graph_->block(), std::move(parent_lookup_fn), param_map); - } - - // The function implements common subexpression elimination. - // Since the nodes are visited in topological order, one pass is enough. - // returns true if CSE made changes to a graph - bool run(Block* block, std::function parent_lookup_fn, ParamMapType& param_map) { - std::unordered_set subexprs; - bool changed = false; - for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { - auto node = *it; - - // check if inputs come from params(graph input) - auto node_inputs = node->inputs(); - for (auto input : node_inputs) { - if (input->node()->kind() == Symbol::fromQualString("prim::Param")) { - auto debug_name = input->debugName(); - - // check if input in params_ - if (params_.find(debug_name) == params_.end()) continue; - - // check if input is already visited. - if (param_map.find(debug_name) != param_map.end()) continue; - - // check if there is a param has same value with input - auto val = params_[debug_name]; - bool update_map = true; - for (auto kv : param_map) { - auto param_val = kv.second.first; - if (val.device() != param_val.device()) continue; - if (val.dtype() != param_val.dtype()) continue; - if (!val.equal(param_val)) continue; - input->replaceAllUsesWith(kv.second.second); - update_map = false; - break; - } - - // add input to param_map - if (update_map) { - param_map.emplace(debug_name, - std::make_pair(std::move(val), std::move(input))); - } - } - } - - if (!node->blocks().empty()) { - // Traverse sub-blocks. 
- for (auto block : node->blocks()) { - changed |= run( - block, - [&](Node* n) { - auto existing = subexprs.find(n); - if (existing != subexprs.end()) { - return *existing; +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::EqualNode; + using torch::jit::HashNode; + using torch::jit::Node; + using torch::jit::Value; + + struct EqualNodeWithParams + { + EqualNodeWithParams(std::unordered_map& params) + : params_(params) + { + } + + bool operator()(const Node* lhs, const Node* rhs) const + { + auto lhs_inputs = lhs->inputs(); + auto rhs_inputs = rhs->inputs(); + } + + private: + std::unordered_map& params_; + }; + + struct CommonSubexpressionEliminator + { + using ParamMapType = std::unordered_map>; + CommonSubexpressionEliminator(std::shared_ptr graph, + std::unordered_map& params) + : graph_(std::move(graph)) + , params_(params) + { + } + + bool run(std::function parent_lookup_fn) + { + ParamMapType param_map; + return run(graph_->block(), std::move(parent_lookup_fn), param_map); + } + + // The function implements common subexpression elimination. + // Since the nodes are visited in topological order, one pass is enough. + // returns true if CSE made changes to a graph + bool run(Block* block, std::function parent_lookup_fn, ParamMapType& param_map) + { + std::unordered_set subexprs; + bool changed = false; + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) + { + auto node = *it; + + // check if inputs come from params(graph input) + auto node_inputs = node->inputs(); + for (auto input : node_inputs) + { + if (input->node()->kind() == Symbol::fromQualString("prim::Param")) + { + auto debug_name = input->debugName(); + + // check if input in params_ + if (params_.find(debug_name) == params_.end()) continue; + + // check if input is already visited. + if (param_map.find(debug_name) != param_map.end()) continue; + + // check if there is a param has same value with input + auto val = params_[debug_name]; + bool update_map = true; + for (auto kv : param_map) + { + auto param_val = kv.second.first; + if (val.device() != param_val.device()) continue; + if (val.dtype() != param_val.dtype()) continue; + if (!val.equal(param_val)) continue; + input->replaceAllUsesWith(kv.second.second); + update_map = false; + break; + } + + // add input to param_map + if (update_map) + { + param_map.emplace(debug_name, + std::make_pair(std::move(val), std::move(input))); + } + } + } + + if (!node->blocks().empty()) + { + // Traverse sub-blocks. + for (auto block : node->blocks()) + { + changed |= run( + block, + [&](Node* n) + { + auto existing = subexprs.find(n); + if (existing != subexprs.end()) + { + return *existing; + } + + return parent_lookup_fn(n); + }, + param_map); + } + + continue; + } + + // Check for CSE opportunities in the parent block. + auto parent_lookup = parent_lookup_fn(node); + auto g_out = node->owningGraph()->outputs(); + if (parent_lookup != nullptr) + { + changed = true; + node->replaceAllUsesWith(parent_lookup); + it.destroyCurrent(); + continue; + } + + // Check whether the same subexpression already exists. + auto subit = subexprs.insert(node); + if (!subit.second) + { + // Subexpression exists, replace the uses of node, and destroy it. + auto existing = *subit.first; + + changed = true; + node->replaceAllUsesWith(existing); + // Destroy the node. 
+ it.destroyCurrent(); + } } - return parent_lookup_fn(n); - }, - param_map); - } + return changed; + } - continue; - } - - // Check for CSE opportunities in the parent block. - auto parent_lookup = parent_lookup_fn(node); - auto g_out = node->owningGraph()->outputs(); - if (parent_lookup != nullptr) { - changed = true; - node->replaceAllUsesWith(parent_lookup); - it.destroyCurrent(); - continue; - } - - // Check whether the same subexpression already exists. - auto subit = subexprs.insert(node); - if (!subit.second) { - // Subexpression exists, replace the uses of node, and destroy it. - auto existing = *subit.first; - - changed = true; - node->replaceAllUsesWith(existing); - // Destroy the node. - it.destroyCurrent(); - } - } - - return changed; - } - - private: - std::shared_ptr graph_; - std::unordered_map& params_; -}; - -void CommonSubgraphElimination(std::shared_ptr& graph, - std::unordered_map& params) { - CommonSubexpressionEliminator cse(graph, params); - cse.run([](Node*) { return nullptr; }); -} -} // namespace torch_jit + private: + std::shared_ptr graph_; + std::unordered_map& params_; + }; + + void CommonSubgraphElimination(std::shared_ptr& graph, + std::unordered_map& params) + { + CommonSubexpressionEliminator cse(graph, params); + cse.run([](Node*) + { return nullptr; }); + } + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h index d90b98073e..da108ff733 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h @@ -3,18 +3,20 @@ #define _COMMON_SUBGRAPH_ELIMINATION_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::Tensor; -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::Tensor; + using torch::jit::Graph; -// This pass is used eliminate the common subgraph. -// There are two main difference between the one in torch/csrc/jit/pass -// 1. AliasDb is not needed in ONNX model -// 2. params might also participated in the elimination -void CommonSubgraphElimination(std::shared_ptr& graph, - std::unordered_map& params); -} // namespace torch_jit + // This pass is used eliminate the common subgraph. + // There are two main difference between the one in torch/csrc/jit/pass + // 1. AliasDb is not needed in ONNX model + // 2. params might also participated in the elimination + void CommonSubgraphElimination(std::shared_ptr& graph, + std::unordered_map& params); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp index 73f8965412..db44bdb4c1 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp @@ -9,89 +9,94 @@ #include "utils.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::IValue; -using torch::jit::Match; -using torch::jit::TensorType; -using torch::jit::TypeKind; -using torch::jit::Value; - -static bool matchClsHead(const Match& match, const std::unordered_map& map) { - // TODO: check if value map in latest pytorch can ease the filter. 
- - // check cat -1 - { - // check if the shape of second inputs is 1 - auto cat_v1 = match.values_map.at(map.at("cat1")); - if (cat_v1->type()->kind() != TypeKind::TensorType) return false; - auto cat_v1_type = cat_v1->type()->cast(); - auto cat_v1_size = cat_v1_type->sizes().concrete_sizes(); - if (!cat_v1_size.has_value()) return false; - IValue cat_v1_size_value(cat_v1_size.value()); - auto size_list = cat_v1_size_value.toIntList(); - if (size_list.size() != 1 || size_list[0] != 1) return false; - } - - // check unsqueeze - auto cat_v0 = match.values_map.at(map.at("cat0")); - auto unsqueeze_node = cat_v0->node(); - { - if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false; - auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes")); - if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false; - } - - // check gather - auto gather_node = unsqueeze_node->input()->node(); - auto gather_inputs = gather_node->inputs(); - { - if (!is_kind(gather_node, "onnx::Gather")) return false; - auto gather_axis = gather_node->i(Symbol::attr("axis")); - if (gather_axis != 0) return false; - } - - auto x = match.values_map.at(map.at("x")); - // check shape - auto shape_node = gather_inputs[0]->node(); - { - if (!is_kind(shape_node, "onnx::Shape")) return false; - if (shape_node->input() != x) return false; - } - - // check constant - auto const_node = gather_inputs[1]->node(); - { - if (!is_kind(const_node, "onnx::Constant")) return false; - auto ival = const_node->t(Symbol::attr("value")); - if (ival.dim() != 0) return false; - auto ival_dataptr = ival.data_ptr(); - if (ival_dataptr[0] != 0) return false; - } - - // check if reshape is the output of the graph - auto reshape_pattern = map.at("reshape"); - auto reshape_node = match.values_map.at(reshape_pattern); - auto uses = reshape_node->uses(); - for (auto use : uses) { - auto user = use.user; - if (is_kind(user, "prim::Return")) return false; - } - - return true; -} - -// from: -// x->shape->gather->unsqueeze->concat -// | | -// gap--------------------------reshape -// -// to: -// x->gap->flatten -void FlattenClsHead(std::shared_ptr& graph) { - std::string pattern = R"IR( +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::IValue; + using torch::jit::Match; + using torch::jit::TensorType; + using torch::jit::TypeKind; + using torch::jit::Value; + + static bool matchClsHead(const Match& match, const std::unordered_map& map) + { + // TODO: check if value map in latest pytorch can ease the filter. 
+ + // check cat -1 + { + // check if the shape of second inputs is 1 + auto cat_v1 = match.values_map.at(map.at("cat1")); + if (cat_v1->type()->kind() != TypeKind::TensorType) return false; + auto cat_v1_type = cat_v1->type()->cast(); + auto cat_v1_size = cat_v1_type->sizes().concrete_sizes(); + if (!cat_v1_size.has_value()) return false; + IValue cat_v1_size_value(cat_v1_size.value()); + auto size_list = cat_v1_size_value.toIntList(); + if (size_list.size() != 1 || size_list[0] != 1) return false; + } + + // check unsqueeze + auto cat_v0 = match.values_map.at(map.at("cat0")); + auto unsqueeze_node = cat_v0->node(); + { + if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false; + auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes")); + if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false; + } + + // check gather + auto gather_node = unsqueeze_node->input()->node(); + auto gather_inputs = gather_node->inputs(); + { + if (!is_kind(gather_node, "onnx::Gather")) return false; + auto gather_axis = gather_node->i(Symbol::attr("axis")); + if (gather_axis != 0) return false; + } + + auto x = match.values_map.at(map.at("x")); + // check shape + auto shape_node = gather_inputs[0]->node(); + { + if (!is_kind(shape_node, "onnx::Shape")) return false; + if (shape_node->input() != x) return false; + } + + // check constant + auto const_node = gather_inputs[1]->node(); + { + if (!is_kind(const_node, "onnx::Constant")) return false; + auto ival = const_node->t(Symbol::attr("value")); + if (ival.dim() != 0) return false; + auto ival_dataptr = ival.data_ptr(); + if (ival_dataptr[0] != 0) return false; + } + + // check if reshape is the output of the graph + auto reshape_pattern = map.at("reshape"); + auto reshape_node = match.values_map.at(reshape_pattern); + auto uses = reshape_node->uses(); + for (auto use : uses) + { + auto user = use.user; + if (is_kind(user, "prim::Return")) return false; + } + + return true; + } + + // from: + // x->shape->gather->unsqueeze->concat + // | | + // gap--------------------------reshape + // + // to: + // x->gap->flatten + void FlattenClsHead(std::shared_ptr& graph) + { + std::string pattern = R"IR( graph(%x, %cat0, %cat1): %gap = onnx::GlobalAveragePool(%x) %cat = onnx::Concat[axis=0](%cat0, %cat1) @@ -99,21 +104,22 @@ void FlattenClsHead(std::shared_ptr& graph) { return (%reshape) )IR"; - std::string replacement = R"IR( + std::string replacement = R"IR( graph(%x, %cat0, %cat1): %gap = onnx::GlobalAveragePool(%x) %flatten = onnx::Flatten(%gap) return (%flatten) )IR"; - torch::jit::SubgraphRewriter subgraph_rewriter; - subgraph_rewriter.RegisterRewritePattern(pattern, replacement); - subgraph_rewriter.runOnGraph(graph, matchClsHead); + torch::jit::SubgraphRewriter subgraph_rewriter; + subgraph_rewriter.RegisterRewritePattern(pattern, replacement); + subgraph_rewriter.runOnGraph(graph, matchClsHead); - torch::jit::EliminateDeadCode( - graph->block(), true, - torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); -} + torch::jit::EliminateDeadCode( + graph->block(), + true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); + } -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h index b66b700d1c..64d8ea3352 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h +++ 
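The FlattenClsHead rewrite above is sound because after onnx::GlobalAveragePool the spatial dimensions are 1x1: the Shape/Gather/Unsqueeze/Concat chain merely rebuilds an {N, -1} shape tensor, and Reshape with that shape equals Flatten on axis 1. A small self-contained check of the shape identity (plain C++, not ONNX runtime code):

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Flatten(axis) folds all dims before `axis` into one and the rest into
// another, so a {N, C, 1, 1} pooled tensor flattens to the same {N, C}
// that Reshape(x, {N, -1}) produces.
std::vector<long> flattenedShape(const std::vector<long>& dims, int axis = 1)
{
    long outer = std::accumulate(dims.begin(), dims.begin() + axis, 1L, std::multiplies<long>());
    long inner = std::accumulate(dims.begin() + axis, dims.end(), 1L, std::multiplies<long>());
    return {outer, inner};
}

int main()
{
    assert((flattenedShape({8, 512, 1, 1}) == std::vector<long>{8, 512}));
}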
b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h @@ -3,12 +3,14 @@ #define _FLATTEN_CLS_HEAD_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; -void FlattenClsHead(std::shared_ptr& graph); -} // namespace torch_jit + void FlattenClsHead(std::shared_ptr& graph); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp index 8dc5847753..bc784671ea 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp @@ -6,131 +6,155 @@ #include "common_subgraph_elimination.h" #include "torch/csrc/jit/ir/irparser.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::IValue; -using torch::jit::Node; - -bool RemoveBoolCast(Node* node) { - auto bottom_node = node->input()->node(); - if (bottom_node->kind() != Symbol::onnx("Greater") && - bottom_node->kind() != Symbol::onnx("Less")) { - return false; - } - node->output()->replaceAllUsesWith(bottom_node->output()); - return true; -} - -bool FuseSelectAssign(Node* node, std::unordered_map& params, - std::unordered_map& vmap, SubgraphMatcher& matcher) { - auto values_map = matcher.values_map(); - - auto cmp1 = values_map[vmap["cmp_1"]]->node(); - auto cmp2 = values_map[vmap["cmp_2"]]->node(); - if (cmp1 != cmp2) { - // cmp_1 == cmp_2, cmp in (Great, Less) - if (cmp1->kind() != cmp2->kind()) return false; - if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) - return false; - - // check threshold - Node* cmps[] = {cmp1, cmp2}; - float thres = 0.0f; - Node* x = nullptr; - for (int i = 0; i < 2; ++i) { - auto cmp = cmps[i]; - auto threshold = cmp->inputs()[1]->node(); - if (threshold->kind() != Symbol::onnx("Constant")) return false; - auto thres_val = threshold->t(Symbol::attr("value")); - if (i == 0) { - thres = thres_val.data_ptr()[0]; - x = cmp->inputs()[0]->node(); - } else { - float tmp_val = thres_val.data_ptr()[0]; - if (fabs(thres - tmp_val) > 1e-10) { - return false; +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::IValue; + using torch::jit::Node; + + bool RemoveBoolCast(Node* node) + { + auto bottom_node = node->input()->node(); + if (bottom_node->kind() != Symbol::onnx("Greater") && + bottom_node->kind() != Symbol::onnx("Less")) + { + return false; + } + node->output()->replaceAllUsesWith(bottom_node->output()); + return true; } - if (x != cmp->inputs()[0]->node()) { - return false; + + bool FuseSelectAssign(Node* node, + std::unordered_map& params, + std::unordered_map& vmap, + SubgraphMatcher& matcher) + { + auto values_map = matcher.values_map(); + + auto cmp1 = values_map[vmap["cmp_1"]]->node(); + auto cmp2 = values_map[vmap["cmp_2"]]->node(); + if (cmp1 != cmp2) + { + // cmp_1 == cmp_2, cmp in (Great, Less) + if (cmp1->kind() != cmp2->kind()) return false; + if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) + return false; + + // check threshold + Node* cmps[] = {cmp1, cmp2}; + float thres = 0.0f; + Node* x = nullptr; + for (int i = 0; i < 2; ++i) + { + auto cmp = cmps[i]; + auto threshold = 
cmp->inputs()[1]->node(); + if (threshold->kind() != Symbol::onnx("Constant")) return false; + auto thres_val = threshold->t(Symbol::attr("value")); + if (i == 0) + { + thres = thres_val.data_ptr()[0]; + x = cmp->inputs()[0]->node(); + } + else + { + float tmp_val = thres_val.data_ptr()[0]; + if (fabs(thres - tmp_val) > 1e-10) + { + return false; + } + if (x != cmp->inputs()[0]->node()) + { + return false; + } + } + } + } + + { + // check shape of reshape + Node* shape = values_map[vmap["reshape_1_shape"]]->node(); + auto shape_val = shape->t(Symbol::attr("value")); + if (shape_val.dim() != 1) return false; + if (shape_val.data_ptr()[0] != -1) return false; + } + + { + // check transpose + Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; + for (auto tran : trans) + { + auto tran_perm = tran->is(Symbol::attr("perm")); + if (tran_perm.size() != 2) return false; + if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; + } + } + + { + // check gather indice + Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); + auto inds_val = gather_inds->t(Symbol::attr("value")); + if (inds_val.dim() != 0) return false; + if (inds_val.data_ptr()[0] != 0) return false; + } + + { + // check slice start + Node* slice = values_map[vmap["slice_2"]]->node(); + auto start_name = slice->inputs()[1]->debugName(); + auto start_val = params[start_name]; + if (start_val.dim() != 1) return false; + if (start_val.data_ptr()[0] != 0) return false; + } + + // create new node + auto graph = node->owningGraph(); + auto z = values_map[vmap["z"]]; + auto y = values_map[vmap["y"]]; + auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); + where_node->insertBefore(node); + where_node->output()->copyMetadata(node->output()); + node->output()->replaceAllUsesWith(where_node->output()); + return true; + } + + void FuseSelectAssign(Block* block, + std::unordered_map& params, + std::unordered_map& vmap, + SubgraphMatcher& matcher) + { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) + { + auto node = *it; + ++it; + for (auto block : node->blocks()) + { + FuseSelectAssign(block, params, vmap, matcher); + } + + if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) + { + RemoveBoolCast(node); + } + else if (matcher.matchesSubgraphFromAnchorNode(node)) + { + FuseSelectAssign(node, params, vmap, matcher); + } + } } - } - } - } - - { - // check shape of reshape - Node* shape = values_map[vmap["reshape_1_shape"]]->node(); - auto shape_val = shape->t(Symbol::attr("value")); - if (shape_val.dim() != 1) return false; - if (shape_val.data_ptr()[0] != -1) return false; - } - - { - // check transpose - Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; - for (auto tran : trans) { - auto tran_perm = tran->is(Symbol::attr("perm")); - if (tran_perm.size() != 2) return false; - if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; - } - } - - { - // check gather indice - Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); - auto inds_val = gather_inds->t(Symbol::attr("value")); - if (inds_val.dim() != 0) return false; - if (inds_val.data_ptr()[0] != 0) return false; - } - - { - // check slice start - Node* slice = values_map[vmap["slice_2"]]->node(); - auto start_name = slice->inputs()[1]->debugName(); - auto start_val = params[start_name]; - if (start_val.dim() != 1) return false; - if (start_val.data_ptr()[0] != 0) return false; 
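As its header comment says, FuseSelectAssign targets the masked assignment y[x > thres] = z[x > thres], which the exporter lowers to the NonZero/Transpose/scatter subgraph matched above; once the checks pass, the whole pattern collapses into a single onnx::Where(cond, z, y). A host-side sketch of the semantics the fusion relies on (illustrative, elementwise over flat buffers):

#include <cstddef>
#include <vector>

// out[i] = cond[i] ? z[i] : y[i] -- what onnx::Where computes directly,
// and what the scatter-based pattern assembles element by element.
std::vector<float> where(const std::vector<bool>&  cond,
                         const std::vector<float>& z,
                         const std::vector<float>& y)
{
    std::vector<float> out(y.size());
    for (std::size_t i = 0; i < y.size(); ++i)
    {
        out[i] = cond[i] ? z[i] : y[i];
    }
    return out;
}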
- } - - // create new node - auto graph = node->owningGraph(); - auto z = values_map[vmap["z"]]; - auto y = values_map[vmap["y"]]; - auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); - where_node->insertBefore(node); - where_node->output()->copyMetadata(node->output()); - node->output()->replaceAllUsesWith(where_node->output()); - return true; -} - -void FuseSelectAssign(Block* block, std::unordered_map& params, - std::unordered_map& vmap, SubgraphMatcher& matcher) { - auto graph = block->owningGraph(); - auto it = block->nodes().begin(); - while (it != block->nodes().end()) { - auto node = *it; - ++it; - for (auto block : node->blocks()) { - FuseSelectAssign(block, params, vmap, matcher); - } - - if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) { - RemoveBoolCast(node); - } else if (matcher.matchesSubgraphFromAnchorNode(node)) { - FuseSelectAssign(node, params, vmap, matcher); - } - } -} - -void FuseSelectAssign(std::shared_ptr& graph, - std::unordered_map& params) { - // cse before search - CommonSubgraphElimination(graph, params); - - std::string pattern_str = R"IR( + + void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params) + { + // cse before search + CommonSubgraphElimination(graph, params); + + std::string pattern_str = R"IR( graph(%y, %z, %cmp_1, %cmp_2, %start, %axes, %shape_2): %nz_1 = onnx::NonZero(%cmp_1) %trans_1 = onnx::Transpose(%nz_1) @@ -149,15 +173,16 @@ void FuseSelectAssign(std::shared_ptr& graph, return (%scatter_2) )IR"; - Graph pattern; - std::unordered_map vmap; - torch::jit::parseIR(pattern_str, &pattern, vmap); - - SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); - FuseSelectAssign(graph->block(), params, vmap, matcher); - torch::jit::EliminateDeadCode( - graph->block(), true, - torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); -} -} // namespace torch_jit + Graph pattern; + std::unordered_map vmap; + torch::jit::parseIR(pattern_str, &pattern, vmap); + + SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); + FuseSelectAssign(graph->block(), params, vmap, matcher); + torch::jit::EliminateDeadCode( + graph->block(), + true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); + } + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h index afa0dc56d6..0e80ec1d67 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h @@ -3,15 +3,17 @@ #define _FUSE_SELECT_ASSIGN_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::Tensor; -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::Tensor; + using torch::jit::Graph; -// this pass is used to fuse y[x>thres] = z[x>thres] -void FuseSelectAssign(std::shared_ptr& graph, - std::unordered_map& params); -} // namespace torch_jit + // this pass is used to fuse y[x>thres] = z[x>thres] + void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp index 3da4933b15..dea6909f8b 100644 --- 
a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp @@ -5,111 +5,131 @@ #include "utils.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::IValue; -using torch::jit::Node; -using torch::jit::TensorType; -using torch::jit::Value; - -void MergeShapeConcate(Node* node) { - auto inputs = node->inputs(); - - std::vector gather_value; - Value* shape_from = nullptr; - - std::vector node_to_remove{node}; - - // check pattern shape->gather->unsqueeze->concate - for (auto input : inputs) { - auto unsqueeze_node = input->node(); - if (!is_kind(unsqueeze_node, "onnx::Unsqueeze") || unsqueeze_node->output()->uses().size() != 1) - return; - - if (unsqueeze_node->hasAttribute(Symbol::attr("axes"))) { - auto axes = unsqueeze_node->is(Symbol::attr("axes")); - if (axes.size() != 1 && axes[0] != 0) return; - } - - auto gather_node = unsqueeze_node->input(0)->node(); - if (!is_kind(gather_node, "onnx::Gather") || gather_node->i(Symbol::attr("axis")) != 0 || - gather_node->output()->uses().size() != 1) - return; - - auto gather_inputs = gather_node->inputs(); - auto gather_data = gather_inputs[0]; - auto gather_indices = gather_inputs[1]; - auto shape_node = gather_data->node(); - if (!is_kind(shape_node, "onnx::Shape") || shape_node->output()->uses().size() != 1) return; - - auto current_shape_from = shape_node->input(); - if (!shape_from) { - shape_from = current_shape_from; - } else { - if (shape_from != current_shape_from) return; - } - - auto constant_node = gather_indices->node(); - if (!is_kind(constant_node, "onnx::Constant")) return; - - auto gather_indices_val = constant_node->t(Symbol::attr("value")); - int64_t* data_ptr = gather_indices_val.data_ptr(); - if (gather_indices_val.dim() == 0) { - gather_value.push_back(data_ptr[0]); - } else { - int element_size = gather_indices_val.element_size(); - for (int j = 0; j < element_size; ++j) { - gather_value.push_back(data_ptr[j]); - } - } - - node_to_remove.insert(node_to_remove.end(), {unsqueeze_node, gather_node, shape_node}); - } - - // create constant value - auto graph = node->owningGraph(); - auto const_node = graph->create(Symbol::onnx("Constant")); - const_node->t_(Symbol::attr("value"), at::tensor(gather_value)); - auto first_node = node->owningGraph()->block()->nodes().front(); - if (const_node != first_node) const_node->insertBefore(first_node); - - // recreate shape node - auto shape_node = graph->create(Symbol::onnx("Shape"), {shape_from}); - shape_node->insertBefore(node); - - // create gather node - auto gather_node = - graph->create(Symbol::onnx("Gather"), {shape_node->output(), const_node->output()}); - - // insert into graph - gather_node->insertAfter(node); - node->output()->replaceAllUsesWith(gather_node->output()); - - for (auto n : node_to_remove) { - n->destroy(); - } -} - -void MergeShapeConcate(Block* block) { - auto graph = block->owningGraph(); - auto it = block->nodes().begin(); - while (it != block->nodes().end()) { - auto node = *it; - ++it; - for (auto block : node->blocks()) { - MergeShapeConcate(block); - } - - if (is_kind(node, "onnx::Concat")) { - MergeShapeConcate(node); - } - } -} - -void MergeShapeConcate(const std::shared_ptr& graph) { MergeShapeConcate(graph->block()); } - -} // namespace torch_jit +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::IValue; + using 
torch::jit::Node; + using torch::jit::TensorType; + using torch::jit::Value; + + void MergeShapeConcate(Node* node) + { + auto inputs = node->inputs(); + + std::vector gather_value; + Value* shape_from = nullptr; + + std::vector node_to_remove{node}; + + // check pattern shape->gather->unsqueeze->concate + for (auto input : inputs) + { + auto unsqueeze_node = input->node(); + if (!is_kind(unsqueeze_node, "onnx::Unsqueeze") || unsqueeze_node->output()->uses().size() != 1) + return; + + if (unsqueeze_node->hasAttribute(Symbol::attr("axes"))) + { + auto axes = unsqueeze_node->is(Symbol::attr("axes")); + if (axes.size() != 1 && axes[0] != 0) return; + } + + auto gather_node = unsqueeze_node->input(0)->node(); + if (!is_kind(gather_node, "onnx::Gather") || gather_node->i(Symbol::attr("axis")) != 0 || + gather_node->output()->uses().size() != 1) + return; + + auto gather_inputs = gather_node->inputs(); + auto gather_data = gather_inputs[0]; + auto gather_indices = gather_inputs[1]; + auto shape_node = gather_data->node(); + if (!is_kind(shape_node, "onnx::Shape") || shape_node->output()->uses().size() != 1) return; + + auto current_shape_from = shape_node->input(); + if (!shape_from) + { + shape_from = current_shape_from; + } + else + { + if (shape_from != current_shape_from) return; + } + + auto constant_node = gather_indices->node(); + if (!is_kind(constant_node, "onnx::Constant")) return; + + auto gather_indices_val = constant_node->t(Symbol::attr("value")); + int64_t* data_ptr = gather_indices_val.data_ptr(); + if (gather_indices_val.dim() == 0) + { + gather_value.push_back(data_ptr[0]); + } + else + { + int element_size = gather_indices_val.element_size(); + for (int j = 0; j < element_size; ++j) + { + gather_value.push_back(data_ptr[j]); + } + } + + node_to_remove.insert(node_to_remove.end(), {unsqueeze_node, gather_node, shape_node}); + } + + // create constant value + auto graph = node->owningGraph(); + auto const_node = graph->create(Symbol::onnx("Constant")); + const_node->t_(Symbol::attr("value"), at::tensor(gather_value)); + auto first_node = node->owningGraph()->block()->nodes().front(); + if (const_node != first_node) const_node->insertBefore(first_node); + + // recreate shape node + auto shape_node = graph->create(Symbol::onnx("Shape"), {shape_from}); + shape_node->insertBefore(node); + + // create gather node + auto gather_node = + graph->create(Symbol::onnx("Gather"), {shape_node->output(), const_node->output()}); + + // insert into graph + gather_node->insertAfter(node); + node->output()->replaceAllUsesWith(gather_node->output()); + + for (auto n : node_to_remove) + { + n->destroy(); + } + } + + void MergeShapeConcate(Block* block) + { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) + { + auto node = *it; + ++it; + for (auto block : node->blocks()) + { + MergeShapeConcate(block); + } + + if (is_kind(node, "onnx::Concat")) + { + MergeShapeConcate(node); + } + } + } + + void MergeShapeConcate(const std::shared_ptr& graph) + { + MergeShapeConcate(graph->block()); + } + + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h index 8656da63c2..13a67f0f47 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h @@ -3,12 +3,14 @@ 
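MergeShapeConcate above replaces a fan of per-dimension Shape -> Gather(axis=0) -> Unsqueeze(axes=[0]) chains feeding one Concat with a single Gather over the same shape using all collected indices. A simplified host-side equivalent of the merged lookup:

#include <cstdint>
#include <vector>

// Gather(Shape(x), indices) with every index at once: e.g. {0, 2, 3}
// picks N, H, W out of an NCHW shape in one node instead of three
// gather/unsqueeze pairs plus a concat.
std::vector<int64_t> gatherShape(const std::vector<int64_t>& shape,
                                 const std::vector<int64_t>& indices)
{
    std::vector<int64_t> out;
    out.reserve(indices.size());
    for (auto i : indices)
    {
        out.push_back(shape[i]);
    }
    return out;
}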
#define _MERGE_SHAPE_CONCATE_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; -void MergeShapeConcate(const std::shared_ptr& graph); -} // namespace torch_jit + void MergeShapeConcate(const std::shared_ptr& graph); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp index f0ef5a5230..7c2f866b85 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp @@ -7,75 +7,90 @@ #include "utils.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::IValue; -using torch::jit::Node; -using torch::jit::TensorType; -using torch::jit::Value; - -void RemoveReshapeChain(Node* node) { - // reshape->reshape => reshape - auto output = node->output(); - if (!(output->hasUses())) { - return; - } - auto uses = output->uses(); - - for (auto use : uses) { - if (!is_kind(use.user, "onnx::Reshape") || use.offset != 0) { - return; - } - } - - auto input = node->inputs()[0]; - output->replaceAllUsesWith(input); - - node->destroy(); -} - -void RemoveRedundantCast(Node* node) { - // Cast(type n)->Cast(type n) => Cast(type n) - - auto to_type = node->i(Symbol::attr("to")); - auto input = node->input(); - - auto input_node = input->node(); - if (is_kind(input_node, "onnx::Cast") && input_node->i(Symbol::attr("to")) == to_type) { - auto output = node->output(); - - output->replaceAllUsesWith(input); - node->destroy(); - } -} - -void ONNXPeephole(Block* block) { - auto graph = block->owningGraph(); - auto it = block->nodes().begin(); - while (it != block->nodes().end()) { - auto node = *it; - ++it; - for (auto block : node->blocks()) { - ONNXPeephole(block); - } - - if (is_kind(node, "onnx::Reshape")) { - RemoveReshapeChain(node); - } else if (is_kind(node, "onnx::Cast")) { - RemoveRedundantCast(node); - } - } -} - -void ONNXPeephole(const std::shared_ptr& graph) { - ONNXPeephole(graph->block()); - torch::jit::EliminateDeadCode( - graph->block(), true, - torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); -} - -} // namespace torch_jit +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::IValue; + using torch::jit::Node; + using torch::jit::TensorType; + using torch::jit::Value; + + void RemoveReshapeChain(Node* node) + { + // reshape->reshape => reshape + auto output = node->output(); + if (!(output->hasUses())) + { + return; + } + auto uses = output->uses(); + + for (auto use : uses) + { + if (!is_kind(use.user, "onnx::Reshape") || use.offset != 0) + { + return; + } + } + + auto input = node->inputs()[0]; + output->replaceAllUsesWith(input); + + node->destroy(); + } + + void RemoveRedundantCast(Node* node) + { + // Cast(type n)->Cast(type n) => Cast(type n) + + auto to_type = node->i(Symbol::attr("to")); + auto input = node->input(); + + auto input_node = input->node(); + if (is_kind(input_node, "onnx::Cast") && input_node->i(Symbol::attr("to")) == to_type) + { + auto output = node->output(); + + output->replaceAllUsesWith(input); + node->destroy(); + } + } + + void ONNXPeephole(Block* block) + { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it 
!= block->nodes().end()) + { + auto node = *it; + ++it; + for (auto block : node->blocks()) + { + ONNXPeephole(block); + } + + if (is_kind(node, "onnx::Reshape")) + { + RemoveReshapeChain(node); + } + else if (is_kind(node, "onnx::Cast")) + { + RemoveRedundantCast(node); + } + } + } + + void ONNXPeephole(const std::shared_ptr& graph) + { + ONNXPeephole(graph->block()); + torch::jit::EliminateDeadCode(graph->block(), + true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); + } + + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h index f388da1bfa..21b7be15d1 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h @@ -3,13 +3,15 @@ #define _ONNX_PEEPHOLE_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; -void ONNXPeephole(const std::shared_ptr& graph); + void ONNXPeephole(const std::shared_ptr& graph); -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h index 1c92cd15a1..147e5b1349 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h @@ -3,18 +3,24 @@ #include -namespace mmdeploy { -namespace torch_jit { -using c10::Symbol; -using torch::jit::Node; +namespace mmdeploy +{ + namespace torch_jit + { + using c10::Symbol; + using torch::jit::Node; -inline bool is_kind(const Node* node, const Symbol& symbol) { return node->kind() == symbol; } + inline bool is_kind(const Node* node, const Symbol& symbol) + { + return node->kind() == symbol; + } -inline bool is_kind(const Node* node, const char* symbol_name) { - return is_kind(node, Symbol::fromQualString(symbol_name)); -} + inline bool is_kind(const Node* node, const char* symbol_name) + { + return is_kind(node, Symbol::fromQualString(symbol_name)); + } -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/codebase/CMakeLists.txt b/csrc/mmdeploy/codebase/CMakeLists.txt index f933b7fb92..172274efcb 100644 --- a/csrc/mmdeploy/codebase/CMakeLists.txt +++ b/csrc/mmdeploy/codebase/CMakeLists.txt @@ -3,29 +3,29 @@ project(mmdeploy_codebase) set(CODEBASES "") -if ("all" IN_LIST MMDEPLOY_CODEBASES) - list(APPEND CODEBASES "mmcls") - list(APPEND CODEBASES "mmdet") - list(APPEND CODEBASES "mmseg") - list(APPEND CODEBASES "mmocr") - list(APPEND CODEBASES "mmedit") - list(APPEND CODEBASES "mmpose") - list(APPEND CODEBASES "mmrotate") - list(APPEND CODEBASES "mmaction") -else () - set(CODEBASES ${MMDEPLOY_CODEBASES}) -endif () +if("all" IN_LIST MMDEPLOY_CODEBASES) + list(APPEND CODEBASES "mmcls") + list(APPEND CODEBASES "mmdet") + list(APPEND CODEBASES "mmseg") + list(APPEND CODEBASES "mmocr") + list(APPEND CODEBASES "mmedit") + list(APPEND CODEBASES "mmpose") + list(APPEND CODEBASES "mmrotate") + list(APPEND CODEBASES "mmaction") +else() + set(CODEBASES ${MMDEPLOY_CODEBASES}) +endif() -foreach (codebase IN LISTS CODEBASES) - message(STATUS "build codebase: ${codebase}") - if (codebase STREQUAL "mmpretrain") - 
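The two peephole rules above are local identities: a Reshape whose every consumer is itself a Reshape can forward its input, since only the final target shape matters, and a Cast feeding a Cast to the same dtype collapses to one. A minimal model of the redundant-cast check, with illustrative types standing in for torch::jit nodes:

// Host-side analogue of RemoveRedundantCast.
enum class DType
{
    kFloat,
    kInt64,
    kBool
};

struct CastNode
{
    DType           to;
    const CastNode* input_cast;  // nullptr if the producer is not a Cast
};

// True when the producer already casts to the same type, so this node's
// output can be replaced by its input and the node destroyed.
bool isRedundantCast(const CastNode& node)
{
    return node.input_cast != nullptr && node.input_cast->to == node.to;
}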
set(subdir_name "mmcls") - elseif (codebase STREQUAL "mmyolo") - set(subdir_name "mmdet") - elseif (codebase STREQUAL "mmagic") - set(subdir_name "mmedit") - else() - set(subdir_name ${codebase}) - endif() - add_subdirectory(${subdir_name}) -endforeach () +foreach(codebase IN LISTS CODEBASES) + message(STATUS "build codebase: ${codebase}") + if(codebase STREQUAL "mmpretrain") + set(subdir_name "mmcls") + elseif(codebase STREQUAL "mmyolo") + set(subdir_name "mmdet") + elseif(codebase STREQUAL "mmagic") + set(subdir_name "mmedit") + else() + set(subdir_name ${codebase}) + endif() + add_subdirectory(${subdir_name}) +endforeach() diff --git a/csrc/mmdeploy/codebase/common.h b/csrc/mmdeploy/codebase/common.h index 391f177590..a0d1bc4a18 100644 --- a/csrc/mmdeploy/codebase/common.h +++ b/csrc/mmdeploy/codebase/common.h @@ -9,69 +9,92 @@ #include "mmdeploy/core/utils/formatter.h" #include "mmdeploy/experimental/module_adapter.h" -namespace mmdeploy { +namespace mmdeploy +{ -using namespace framework; + using namespace framework; -class Context { - public: - explicit Context(const Value& config) { - MMDEPLOY_DEBUG("config: {}", config); - device_ = config["context"]["device"].get(); - stream_ = config["context"]["stream"].get(); - } + class Context + { + public: + explicit Context(const Value& config) + { + MMDEPLOY_DEBUG("config: {}", config); + device_ = config["context"]["device"].get(); + stream_ = config["context"]["stream"].get(); + } - Device& device() { return device_; } - Stream& stream() { return stream_; } + Device& device() + { + return device_; + } - protected: - Device device_; - Stream stream_; -}; + Stream& stream() + { + return stream_; + } -template -class CodebaseCreator : public Creator { - public: - std::string_view name() const noexcept override { return Tag::name; } - std::unique_ptr Create(const Value& cfg) override { - constexpr auto key{"component"}; - if (!cfg.contains(key)) { - MMDEPLOY_ERROR("no key '{}' in config {}", key, cfg); - throw_exception(eInvalidArgument); - } - if (!cfg[key].is_string()) { - MMDEPLOY_ERROR("key '{}' is not a string", key); - throw_exception(eInvalidArgument); - } - auto postprocess_type = cfg[key].get(); - auto creator = gRegistry().Get(postprocess_type); - if (creator == nullptr) { - MMDEPLOY_ERROR("Could not found entry '{}' in {}. 
Available components: {}", postprocess_type, - Tag::name, gRegistry().List()); - throw_exception(eEntryNotFound); - } - return creator->Create(cfg); - } -}; + protected: + Device device_; + Stream stream_; + }; -#define MMDEPLOY_DECLARE_CODEBASE(codebase_type, codebase_name) \ - class codebase_type : public Context { \ - public: \ - static constexpr const auto name = #codebase_name; \ - using type = std::unique_ptr; \ - explicit codebase_type(const Value& config) : Context(config) {} \ - }; \ - MMDEPLOY_DECLARE_REGISTRY(codebase_type, std::unique_ptr(const Value& config)); + template + class CodebaseCreator : public Creator + { + public: + std::string_view name() const noexcept override + { + return Tag::name; + } -#define MMDEPLOY_REGISTER_CODEBASE(codebase) \ - using codebase##_##Creator = CodebaseCreator; \ - MMDEPLOY_REGISTER_CREATOR(Module, codebase##_##Creator) \ - MMDEPLOY_DEFINE_REGISTRY(codebase) + std::unique_ptr Create(const Value& cfg) override + { + constexpr auto key{"component"}; + if (!cfg.contains(key)) + { + MMDEPLOY_ERROR("no key '{}' in config {}", key, cfg); + throw_exception(eInvalidArgument); + } + if (!cfg[key].is_string()) + { + MMDEPLOY_ERROR("key '{}' is not a string", key); + throw_exception(eInvalidArgument); + } + auto postprocess_type = cfg[key].get(); + auto creator = gRegistry().Get(postprocess_type); + if (creator == nullptr) + { + MMDEPLOY_ERROR("Could not found entry '{}' in {}. Available components: {}", + postprocess_type, + Tag::name, + gRegistry().List()); + throw_exception(eEntryNotFound); + } + return creator->Create(cfg); + } + }; -#define MMDEPLOY_REGISTER_CODEBASE_COMPONENT(codebase, component_type) \ - MMDEPLOY_REGISTER_FACTORY_FUNC(codebase, (component_type, 0), [](const Value& config) { \ - return CreateTask(component_type(config)); \ - }) +#define MMDEPLOY_DECLARE_CODEBASE(codebase_type, codebase_name) \ + class codebase_type : public Context \ + { \ + public: \ + static constexpr const auto name = #codebase_name; \ + using type = std::unique_ptr; \ + explicit codebase_type(const Value& config) \ + : Context(config) \ + { \ + } \ + }; \ + MMDEPLOY_DECLARE_REGISTRY(codebase_type, std::unique_ptr(const Value& config)); + +#define MMDEPLOY_REGISTER_CODEBASE(codebase) \ + using codebase##_##Creator = CodebaseCreator; \ + MMDEPLOY_REGISTER_CREATOR(Module, codebase##_##Creator) \ + MMDEPLOY_DEFINE_REGISTRY(codebase) + +#define MMDEPLOY_REGISTER_CODEBASE_COMPONENT(codebase, component_type) \ + MMDEPLOY_REGISTER_FACTORY_FUNC(codebase, (component_type, 0), [](const Value& config) { return CreateTask(component_type(config)); }) } // namespace mmdeploy diff --git a/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt b/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt index 2ea41f7271..380b7b6f46 100644 --- a/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt +++ b/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt @@ -5,11 +5,12 @@ project(mmdeploy_mmaction) file(GLOB SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp") mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") -target_link_libraries(${PROJECT_NAME} PRIVATE - mmdeploy_operation - mmdeploy_transform - mmdeploy_opencv_utils) +target_link_libraries( + ${PROJECT_NAME} PRIVATE mmdeploy_operation mmdeploy_transform + mmdeploy_opencv_utils) add_library(mmdeploy::mmaction ALIAS ${PROJECT_NAME}) -set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} video_recognizer CACHE INTERNAL "") +set(MMDEPLOY_TASKS + ${MMDEPLOY_TASKS} video_recognizer + CACHE INTERNAL "") diff --git a/csrc/mmdeploy/codebase/mmaction/base_head.cpp 
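CodebaseCreator::Create above is a registry lookup keyed by the "component" string in the config: a missing or non-string key is an invalid argument, and an unregistered name reports eEntryNotFound alongside the available entries. A minimal sketch of that factory-registry shape (illustrative names, not the real mmdeploy registry API):

#include <functional>
#include <map>
#include <memory>
#include <string>

struct Module
{
};

using Factory = std::function<std::unique_ptr<Module>()>;

std::map<std::string, Factory>& registry()
{
    static std::map<std::string, Factory> r;
    return r;
}

// Returns nullptr for an unknown component, where the real code logs the
// available entries and throws.
std::unique_ptr<Module> create(const std::string& component)
{
    auto it = registry().find(component);
    return it == registry().end() ? nullptr : it->second();
}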
b/csrc/mmdeploy/codebase/mmaction/base_head.cpp index 931c9663eb..2e541fd660 100644 --- a/csrc/mmdeploy/codebase/mmaction/base_head.cpp +++ b/csrc/mmdeploy/codebase/mmaction/base_head.cpp @@ -7,66 +7,75 @@ #include "mmdeploy/core/tensor.h" #include "mmdeploy/core/utils/device_utils.h" -namespace mmdeploy::mmaction { +namespace mmdeploy::mmaction +{ -class BaseHead : public MMAction { - public: - explicit BaseHead(const Value& cfg) : MMAction(cfg) { - if (cfg.contains("params")) { - topk_ = cfg["params"].value("topk", 1); - if (topk_ <= 0) { - MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_); - throw_exception(eInvalidArgument); - } - } - } + class BaseHead : public MMAction + { + public: + explicit BaseHead(const Value& cfg) + : MMAction(cfg) + { + if (cfg.contains("params")) + { + topk_ = cfg["params"].value("topk", 1); + if (topk_ <= 0) + { + MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_); + throw_exception(eInvalidArgument); + } + } + } - Result operator()(const Value& infer_res) { - MMDEPLOY_DEBUG("infer_res: {}", infer_res); - auto output = infer_res["output"].get(); + Result operator()(const Value& infer_res) + { + MMDEPLOY_DEBUG("infer_res: {}", infer_res); + auto output = infer_res["output"].get(); - if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) { - MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(), - (int)output.data_type()); - return Status(eNotSupported); - } + if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) + { + MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(), (int)output.data_type()); + return Status(eNotSupported); + } - auto class_num = (int)output.shape(1); + auto class_num = (int)output.shape(1); - OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream())); - OUTCOME_TRY(stream().Wait()); + OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream())); + OUTCOME_TRY(stream().Wait()); - return GetLabels(_scores, class_num); - } + return GetLabels(_scores, class_num); + } - private: - Value GetLabels(const Tensor& scores, int class_num) const { - auto scores_data = scores.data(); - Labels output; - output.reserve(topk_); - std::vector idx(class_num); - iota(begin(idx), end(idx), 0); - partial_sort(begin(idx), begin(idx) + topk_, end(idx), - [&](int i, int j) { return scores_data[i] > scores_data[j]; }); - for (int i = 0; i < topk_; ++i) { - auto label = Label{idx[i], scores_data[idx[i]]}; - MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score); - output.push_back(label); - } - return to_value(std::move(output)); - } + private: + Value GetLabels(const Tensor& scores, int class_num) const + { + auto scores_data = scores.data(); + Labels output; + output.reserve(topk_); + std::vector idx(class_num); + iota(begin(idx), end(idx), 0); + partial_sort(begin(idx), begin(idx) + topk_, end(idx), [&](int i, int j) + { return scores_data[i] > scores_data[j]; }); + for (int i = 0; i < topk_; ++i) + { + auto label = Label{idx[i], scores_data[idx[i]]}; + MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score); + output.push_back(label); + } + return to_value(std::move(output)); + } - private: - static constexpr const auto kHost = Device{0}; - int topk_{1}; -}; + private: + static constexpr const auto kHost = Device{0}; + int topk_{1}; + }; -MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, BaseHead); + MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, BaseHead); 
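GetLabels above selects the top-k classes with std::partial_sort over an index array, ordering only the first k slots (roughly O(n log k)) instead of sorting every class score. The same selection as a standalone function:

#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

// Sort only the first k index slots by descending score, then pair each
// selected index with its score -- the scheme BaseHead::GetLabels uses.
std::vector<std::pair<int, float>> topk(const std::vector<float>& scores, int k)
{
    k = std::min<int>(k, static_cast<int>(scores.size()));
    std::vector<int> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(), [&](int i, int j)
                      { return scores[i] > scores[j]; });
    std::vector<std::pair<int, float>> out;
    out.reserve(k);
    for (int i = 0; i < k; ++i)
    {
        out.emplace_back(idx[i], scores[idx[i]]);
    }
    return out;
}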
-using SlowFastHead = BaseHead; -MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, SlowFastHead); + using SlowFastHead = BaseHead; + MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, SlowFastHead); -using TSNHead = BaseHead; -MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, TSNHead); + using TSNHead = BaseHead; + MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, TSNHead); } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp index 7d8c6ac5c6..ff65fe184d 100644 --- a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp +++ b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp @@ -7,122 +7,141 @@ using namespace std; -namespace mmdeploy::mmaction { - -FormatShape::FormatShape(const Value& args) { - input_format_ = args.value("input_format", std::string("")); - if (input_format_ != "NCHW" && input_format_ != "NCTHW") { - MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'"); - throw_exception(eInvalidArgument); - } - permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); -} - -Result FormatShape::MergeInputs(const std::vector& images, Tensor& inputs) { - auto N = static_cast(images.size()); - auto H = images[0].shape(1); - auto W = images[0].shape(2); - auto C = images[0].shape(3); - auto& device = operation::gContext().device(); - auto& stream = operation::gContext().stream(); - - TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}}; - inputs = Tensor(desc); - auto offset = 0UL; - auto n_item = H * W * C; - auto copy_size = n_item * sizeof(float); - for (int i = 0; i < N; i++) { - auto src_buffer = images[i].buffer(); - auto dst_buffer = inputs.buffer(); - OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset)); - offset += copy_size; - } - return success(); -} - -Result FormatShape::Format(const std::vector& images, Tensor& output, int clip_len, - int num_clips) { - Tensor inputs; - OUTCOME_TRY(MergeInputs(images, inputs)); - - // Tensor dst; - if (input_format_ == "NCHW") { - OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output)); - } - if (input_format_ == "NCTHW") { - OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output)); - } - - TensorShape expand_dim = output.shape(); - expand_dim.insert(expand_dim.begin(), 1); - output.Reshape(expand_dim); - - return success(); -} - -Result FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { - const vector axes = {0, 3, 1, 2}; - OUTCOME_TRY(permute_.Apply(src, dst, axes)); - return success(); -} - -Result FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { - auto N = src.shape(0); - auto H = src.shape(1); - auto W = src.shape(2); - auto C = src.shape(3); - int L = clip_len; - if (N % L != 0) { - return Status(eInvalidArgument); - } - int M = N / L; - src.Reshape({M, L, H, W, C}); - const vector axes = {0, 4, 1, 2, 3}; - OUTCOME_TRY(permute_.Apply(src, dst, axes)); - return success(); -} - -Result FormatShape::Apply(Value& data) { - MMDEPLOY_DEBUG("input: {}", data); - - if (!data.is_array()) { - MMDEPLOY_ERROR("input of format shape should be array"); - return Status(eInvalidArgument); - } - if (!(data[0].contains("imgs") || data[0].contains("img"))) { - MMDEPLOY_ERROR("input should contains imgs or img"); - return Status(eInvalidArgument); - } - - int n_image = data.size(); - int clip_len = data[0]["clip_len"].get(); - int num_clips = data[0]["num_clips"].get(); - std::vector images; - - if (data[0].contains("imgs")) { - int n_crop = 
data[0]["imgs"].size(); - int total = n_image * n_crop; - images.reserve(total); - for (int i = 0; i < n_crop; i++) { - for (int j = 0; j < n_image; j++) { - images.push_back(data[j]["imgs"][i].get()); - } +namespace mmdeploy::mmaction +{ + + FormatShape::FormatShape(const Value& args) + { + input_format_ = args.value("input_format", std::string("")); + if (input_format_ != "NCHW" && input_format_ != "NCTHW") + { + MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'"); + throw_exception(eInvalidArgument); + } + permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); } - } else if (data[0].contains("img")) { - images.reserve(n_image); - for (int i = 0; i < n_image; i++) { - images.push_back(data[i]["img"].get()); + + Result FormatShape::MergeInputs(const std::vector& images, Tensor& inputs) + { + auto N = static_cast(images.size()); + auto H = images[0].shape(1); + auto W = images[0].shape(2); + auto C = images[0].shape(3); + auto& device = operation::gContext().device(); + auto& stream = operation::gContext().stream(); + + TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}}; + inputs = Tensor(desc); + auto offset = 0UL; + auto n_item = H * W * C; + auto copy_size = n_item * sizeof(float); + for (int i = 0; i < N; i++) + { + auto src_buffer = images[i].buffer(); + auto dst_buffer = inputs.buffer(); + OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset)); + offset += copy_size; + } + return success(); + } + + Result FormatShape::Format(const std::vector& images, Tensor& output, int clip_len, int num_clips) + { + Tensor inputs; + OUTCOME_TRY(MergeInputs(images, inputs)); + + // Tensor dst; + if (input_format_ == "NCHW") + { + OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output)); + } + if (input_format_ == "NCTHW") + { + OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output)); + } + + TensorShape expand_dim = output.shape(); + expand_dim.insert(expand_dim.begin(), 1); + output.Reshape(expand_dim); + + return success(); } - } - Tensor dst; - data = Value{}; - OUTCOME_TRY(Format(images, dst, clip_len, num_clips)); - data["img"] = std::move(dst); + Result FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) + { + const vector axes = {0, 3, 1, 2}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); + return success(); + } - return success(); -} + Result FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) + { + auto N = src.shape(0); + auto H = src.shape(1); + auto W = src.shape(2); + auto C = src.shape(3); + int L = clip_len; + if (N % L != 0) + { + return Status(eInvalidArgument); + } + int M = N / L; + src.Reshape({M, L, H, W, C}); + const vector axes = {0, 4, 1, 2, 3}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); + return success(); + } + + Result FormatShape::Apply(Value& data) + { + MMDEPLOY_DEBUG("input: {}", data); + + if (!data.is_array()) + { + MMDEPLOY_ERROR("input of format shape should be array"); + return Status(eInvalidArgument); + } + if (!(data[0].contains("imgs") || data[0].contains("img"))) + { + MMDEPLOY_ERROR("input should contains imgs or img"); + return Status(eInvalidArgument); + } + + int n_image = data.size(); + int clip_len = data[0]["clip_len"].get(); + int num_clips = data[0]["num_clips"].get(); + std::vector images; + + if (data[0].contains("imgs")) + { + int n_crop = data[0]["imgs"].size(); + int total = n_image * n_crop; + images.reserve(total); + for (int i = 0; i < n_crop; i++) + { + for (int j = 0; j < n_image; j++) + { + 
images.push_back(data[j]["imgs"][i].get()); + } + } + } + else if (data[0].contains("img")) + { + images.reserve(n_image); + for (int i = 0; i < n_image; i++) + { + images.push_back(data[i]["img"].get()); + } + } + + Tensor dst; + data = Value{}; + OUTCOME_TRY(Format(images, dst, clip_len, num_clips)); + data["img"] = std::move(dst); + + return success(); + } -MMDEPLOY_REGISTER_TRANSFORM(FormatShape); + MMDEPLOY_REGISTER_TRANSFORM(FormatShape); } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.h b/csrc/mmdeploy/codebase/mmaction/format_shape.h index 97e4f99356..7ea0326c84 100644 --- a/csrc/mmdeploy/codebase/mmaction/format_shape.h +++ b/csrc/mmdeploy/codebase/mmaction/format_shape.h @@ -12,27 +12,28 @@ #include "mmdeploy/operation/vision.h" #include "mmdeploy/preprocess/transform/transform.h" -namespace mmdeploy::mmaction { +namespace mmdeploy::mmaction +{ -class FormatShape : public Transform { - public: - explicit FormatShape(const Value& args); + class FormatShape : public Transform + { + public: + explicit FormatShape(const Value& args); - Result Apply(Value& data) override; + Result Apply(Value& data) override; - Result Format(const std::vector& images, Tensor& output, int clip_len, - int num_clips); + Result Format(const std::vector& images, Tensor& output, int clip_len, int num_clips); - Result FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); + Result FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); - Result FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); + Result FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); - Result MergeInputs(const std::vector& images, Tensor& inputs); + Result MergeInputs(const std::vector& images, Tensor& inputs); - private: - std::string input_format_; - operation::Managed permute_; -}; + private: + std::string input_format_; + operation::Managed permute_; + }; } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/mmaction.cpp b/csrc/mmdeploy/codebase/mmaction/mmaction.cpp index dc590a1800..7de226ecd1 100644 --- a/csrc/mmdeploy/codebase/mmaction/mmaction.cpp +++ b/csrc/mmdeploy/codebase/mmaction/mmaction.cpp @@ -2,8 +2,9 @@ #include "mmdeploy/codebase/mmaction/mmaction.h" -namespace mmdeploy::mmaction { +namespace mmdeploy::mmaction +{ -MMDEPLOY_REGISTER_CODEBASE(MMAction); + MMDEPLOY_REGISTER_CODEBASE(MMAction); } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/mmaction.h b/csrc/mmdeploy/codebase/mmaction/mmaction.h index ef097e6f20..a3add86894 100644 --- a/csrc/mmdeploy/codebase/mmaction/mmaction.h +++ b/csrc/mmdeploy/codebase/mmaction/mmaction.h @@ -8,17 +8,19 @@ #include "mmdeploy/core/module.h" #include "mmdeploy/core/serialization.h" -namespace mmdeploy::mmaction { +namespace mmdeploy::mmaction +{ -struct Label { - int label_id; - float score; - MMDEPLOY_ARCHIVE_MEMBERS(label_id, score); -}; + struct Label + { + int label_id; + float score; + MMDEPLOY_ARCHIVE_MEMBERS(label_id, score); + }; -using Labels = std::vector
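FormatNCTHW above regroups N merged frames into M = N / clip_len clips by reshaping to {M, L, H, W, C} and then permuting with axes {0, 4, 1, 2, 3} to move channels forward. The resulting shape bookkeeping as a small sketch:

#include <array>
#include <cstdint>

// {N, H, W, C} frames -> {M, C, L, H, W} clips with M = N / L; the pass
// returns eInvalidArgument when N is not a multiple of the clip length.
std::array<int64_t, 5> ncthwShape(int64_t N, int64_t H, int64_t W, int64_t C, int64_t L)
{
    int64_t M = N / L;
    return {M, C, L, H, W};
}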