Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added preliminary support for cuda #10

Open
wants to merge 12 commits into
base: cc_build
Choose a base branch
from
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ opencl = ["mnn-sys/opencl"]

metal = ["mnn-sys/metal"]
coreml = ["mnn-sys/coreml"]
cuda = ["mnn-sys/cuda"]

vulkan = [] # This is currently unimplemented

Expand All @@ -44,7 +45,7 @@ serde = ["dep:serde"]

simd = ["mnn-sys/simd"]

default = ["simd"]
# default = ["simd"]


[dev-dependencies]
Expand Down
15 changes: 15 additions & 0 deletions benches/mnn-bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,19 @@ mod mnn_realesr_bench_with_ones {
net.wait(&session);
});
}

/// Benchmarks the `realesr.mnn` model with the forward type set to CUDA.
///
/// The interpreter and session are created once, outside the timed closure.
/// Each iteration fills the `"data"` input with ones, runs the session and
/// then calls `wait` on it.
#[cfg(feature = "cuda")]
#[divan::bench]
pub fn mnn_realesr_benchmark_cuda(bencher: Bencher) {
    let interpreter = Interpreter::from_file("tests/assets/realesr.mnn").unwrap();

    // Build the session up front so only inference is measured.
    let session = {
        let mut schedule = ScheduleConfig::new();
        schedule.set_type(ForwardType::Cuda);
        interpreter.create_session(schedule).unwrap()
    };

    bencher.bench_local(|| {
        let mut data = interpreter.input(&session, "data").unwrap();
        data.fill(1.0_f32);
        interpreter.run_session(&session).unwrap();
        interpreter.wait(&session);
    });
}
}
52 changes: 31 additions & 21 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
flake-utils.lib.eachDefaultSystem (
system: let
pkgs = import nixpkgs {
config.allowUnfree = true;
inherit system;
overlays = [
rust-overlay.overlays.default
Expand All @@ -67,33 +68,39 @@
extensions = ["rust-docs" "rust-src" "rust-analyzer"];
}
// (lib.optionalAttrs pkgs.stdenv.isDarwin {
targets = ["aarch64-apple-darwin" "x86_64-apple-darwin"];
targets = ["aarch64-apple-darwin" "x86_64-apple-darwin" "wasm32-unknown-unknown"];
}));
nightlyToolchain = pkgs.rust-bin.nightly.latest.default;
craneLib = (crane.mkLib pkgs).overrideToolchain rustToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain rustToolchainWithLLvmTools;

src = lib.sources.sourceFilesBySuffices ./. [".rs" ".toml" ".patch" ".mnn" ".h" ".cpp" ".svg" "lock"];
src = lib.sources.sourceFilesBySuffices ./. [".rs" ".toml" ".patch" ".mnn" ".h" ".cpp" ".svg" ".lock"];
MNN_SRC = pkgs.applyPatches {
name = "mnn-src";
src = mnn-src;
patches = [./mnn-sys/patches/mnn-tracing.patch];
};
commonArgs = {
inherit src MNN_SRC;
stdenv = pkgs.clangStdenv;
pname = "mnn";
doCheck = false;
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
nativeBuildInputs = with pkgs; [
cmake
llvmPackages.libclang.lib
clang
pkg-config
];
buildInputs = with pkgs;
[]
nativeBuildInputs = with pkgs;
[
pkg-config
libclang.lib
]
++ (lib.optionals pkgs.stdenv.isLinux [
cudatoolkit
]);
LIBCLANG_PATH = "${pkgs.libclang.lib}/lib";
buildInputs = with pkgs;
(lib.optionals pkgs.stdenv.isLinux [
ocl-icd
opencl-headers
(lib.getDev cudaPackages.cuda_cudart)
(lib.getLib cudaPackages.cuda_cudart)
(lib.getStatic cudaPackages.cuda_cudart)
])
++ (lib.optionals pkgs.stdenv.isDarwin [
apple-sdk_13
Expand Down Expand Up @@ -157,18 +164,13 @@
# name = "mnn-leaks";
# cargoLock = {
# lockFile = ./Cargo.lock;
# outputHashes = {
# "cmake-0.1.50" = "sha256-GM2D7dpb2i2S6qYVM4HYk5B40TwKCmGQnUPfXksyf0M=";
# };
# };
#
# buildPhase = ''
# cargo test --target aarch64-apple-darwin
# cargo test --profile rwd --target aarch64-apple-darwin
# '';
# RUSTFLAGS = "-Zsanitizer=address";
# ASAN_OPTIONS = "detect_leaks=1";
# # MNN_COMPILE = "NO";
# # MNN_LIB_DIR = "${pkgs.mnn}/lib";
# }
# );
}
Expand Down Expand Up @@ -200,10 +202,13 @@
};

devShells = {
default = pkgs.mkShell (commonArgs
// {
default = pkgs.mkShell.override {stdenv = pkgs.clangStdenv;} (
{
MNN_SRC = null;
LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver";
nativeBuildInputs = commonArgs.nativeBuildInputs;
buildInputs = commonArgs.buildInputs;
LIBCLANG_PATH = commonArgs.LIBCLANG_PATH;
packages = with pkgs;
[
cargo-audit
Expand All @@ -220,14 +225,19 @@
rust-bindgen
google-cloud-sdk
rustToolchainWithRustAnalyzer
mnn
]
++ (
lib.optionals pkgs.stdenv.isLinux [
cudatoolkit
cargo-llvm-cov
]
);
# ++ (with packages; [bencher inspect]);
});
}
// lib.optionalAttrs pkgs.stdenv.isLinux {
CUDA_PATH = "${pkgs.cudatoolkit}";
}
);
};
}
)
Expand Down
1 change: 1 addition & 0 deletions mnn-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ thiserror = "2.0.3"
[features]
opencl = []

cuda = []
metal = []
coreml = ["metal"]
vulkan = []
Expand Down
109 changes: 98 additions & 11 deletions mnn-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ pub fn mnn_c_bindgen(vendor: impl AsRef<Path>, out: impl AsRef<Path>) -> Result<
.clang_arg(CxxOption::METAL.cxx())
.clang_arg(CxxOption::COREML.cxx())
.clang_arg(CxxOption::OPENCL.cxx())
.clang_arg(CxxOption::CUDA.cxx())
.pipe(|builder| {
if is_emscripten() {
println!("cargo:rustc-cdylib-link-arg=-fvisibility=default");
Expand Down Expand Up @@ -314,6 +315,7 @@ pub fn mnn_cpp_bindgen(vendor: impl AsRef<Path>, out: impl AsRef<Path>) -> Resul
.clang_arg(CxxOption::METAL.cxx())
.clang_arg(CxxOption::COREML.cxx())
.clang_arg(CxxOption::OPENCL.cxx())
.clang_arg(CxxOption::CUDA.cxx())
.clang_arg(format!("-I{}", vendor.join("include").to_string_lossy()))
.generate_cstr(true)
.generate_inline_functions(true)
Expand All @@ -327,9 +329,12 @@ pub fn mnn_cpp_bindgen(vendor: impl AsRef<Path>, out: impl AsRef<Path>) -> Resul
.join("Interpreter.hpp")
.to_string_lossy(),
)
// .header(
// vendor
// .join("include/MNN/MNNSharedContext.h")
// .to_string_lossy(),
// )
.allowlist_item(".*SessionInfoCode.*");
// let cmd = bindings.command_line_flags().join(" ");
// println!("cargo:warn=bindgen: {}", cmd);
let bindings = bindings.generate().change_context(Error)?;
bindings
.write_to_file(out.as_ref().join("mnn_cpp.rs"))
Expand All @@ -351,19 +356,17 @@ pub fn mnn_c_build(path: impl AsRef<Path>, vendor: impl AsRef<Path>) -> Result<(
let vendor = vendor.as_ref();
cc::Build::new()
.include(vendor.join("include"))
// .includes(vulkan_includes(vendor))
.pipe(|config| {
#[cfg(feature = "vulkan")]
config.define("MNN_VULKAN", "1");
#[cfg(feature = "metal")]
config.define("MNN_METAL", "1");
#[cfg(feature = "coreml")]
config.define("MNN_COREML", "1");
#[cfg(feature = "opencl")]
config.define("MNN_OPENCL", "ON");
CxxOption::COREML.define(config);
CxxOption::CUDA.define(config);
CxxOption::METAL.define(config);
CxxOption::OPENCL.define(config);
CxxOption::VULKAN.define(config);
if is_emscripten() {
config.compiler("emcc");
// We can't compile wasm32-unknown-unknown with emscripten
// emscripten works with cpu backend only so we are not sure if it would work with
// others at all
config.target("wasm32-unknown-emscripten");
config.cpp_link_stdlib("c++-noexcept");
}
Expand Down Expand Up @@ -463,6 +466,7 @@ impl CxxOption {
cxx_option_from_features! {
VULKAN => "vulkan", "MNN_VULKAN",
METAL => "metal", "MNN_METAL",
CUDA => "cuda", "MNN_CUDA",
COREML => "coreml", "MNN_COREML",
OPENCL => "opencl", "MNN_OPENCL",
CRT_STATIC => "crt_static", "MNN_WIN_RUNTIME_MT",
Expand Down Expand Up @@ -621,6 +625,7 @@ pub fn mnn_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {

// CxxOption::VULKAN.define(&mut build);
// CxxOption::COREML.define(&mut build);
CxxOption::CUDA.define(&mut build);
CxxOption::METAL.define(&mut build);
CxxOption::OPENCL.define(&mut build);
CxxOption::CRT_STATIC.define(&mut build);
Expand Down Expand Up @@ -697,6 +702,8 @@ pub fn mnn_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {
let build = opencl(build, vendor).change_context(Error)?;
#[cfg(feature = "metal")]
let build = metal(build, vendor).change_context(Error)?;
#[cfg(feature = "cuda")]
let build = cuda(build, vendor).change_context(Error)?;

build
.try_compile("mnn")
Expand Down Expand Up @@ -1037,3 +1044,83 @@ pub fn cc_builder() -> cc::Build {
.std("c++11")
.to_owned()
}

/// Compiles MNN's CUDA backend into the `MNNCuda` library and returns `build`
/// with the CUDA option defined on it.
///
/// Sources are collected from `source/backend/cuda/core` and
/// `source/backend/cuda/execution` below `vendor`. Any file whose path contains
/// a component named `plugin` or `weight_only_quant` is skipped. `.cu` files
/// are compiled to intermediate objects by a CUDA-enabled `cc::Build`; the
/// `.cpp` files plus `Register.cpp` are then compiled by [`cc_builder`] and
/// archived together with those objects.
///
/// # Errors
///
/// Returns [`Error`] with an attached message when either the device-code
/// compilation or the final `MNNCuda` library build fails.
pub fn cuda(mut build: cc::Build, vendor: impl AsRef<Path>) -> Result<cc::Build> {
    /// `(compute capability, whether to define MNN_CUDA_ENABLE_SM<capability>)`.
    /// A `-gencode` pair is emitted for every entry regardless of the flag.
    const GENCODE_TARGETS: [(u8, bool); 9] = [
        (60, false),
        (61, false),
        (62, false),
        (70, false),
        (72, false),
        (75, true),
        (80, true),
        (86, true),
        (89, true),
    ];

    let vendor: &Path = vendor.as_ref();
    let cuda_dir = vendor.join("source/backend/cuda");
    let cutlass_include = vendor.join("3rd_party/cutlass/v2_9_0/include");

    // Split the backend sources into device code (.cu) and host code (.cpp).
    let (cuda_files_cu, cuda_files_cpp): (Vec<_>, Vec<_>) =
        ignore::WalkBuilder::new(cuda_dir.join("core"))
            .add(cuda_dir.join("execution"))
            .build()
            .flatten()
            .filter(|entry| entry.path().has_extension(["cpp", "cu"]))
            .map(|entry| entry.into_path())
            .filter(|path| {
                !path.components().any(|component| {
                    let name = component.as_os_str();
                    name.eq("plugin") || name.eq("weight_only_quant")
                })
            })
            .partition(|path| path.has_extension(["cu"]));

    // Device code: compiled to object files only, archived further below.
    let mut nvcc = cc::Build::new();
    nvcc.cuda(true)
        .cudart("static")
        .flag("-m64")
        .flag("--std")
        .flag("c++11")
        .flag("-w")
        .flag("-O3")
        .flag("-g")
        .define("MNN_Cuda_Main_EXPORTS", None)
        .includes(mnn_includes(vendor))
        .include(&cutlass_include)
        .include(&cuda_dir);
    if *TARGET_OS == "windows" {
        nvcc.flag("-Xcompiler").flag("/FS");
    }
    for (version, enable) in GENCODE_TARGETS {
        if enable {
            nvcc.define(&format!("MNN_CUDA_ENABLE_SM{version}"), None);
        }
        nvcc.flag("-gencode")
            .flag(&format!("arch=compute_{version},code=sm_{version}"));
    }
    let cuda_objects = nvcc
        .files(cuda_files_cu)
        .try_compile_intermediates()
        .change_context(Error)
        .attach_printable("Failed to compile the MNN CUDA kernels (.cu sources)")?;

    // Host code plus the precompiled device objects form the MNNCuda library.
    // NOTE(review): the object-only step above presumably emits no cargo link
    // directives, so confirm the CUDA runtime ends up on the final link line.
    cc_builder()
        .includes(mnn_includes(vendor))
        .include(&cutlass_include)
        .include(&cuda_dir)
        .file(cuda_dir.join("Register.cpp"))
        .files(cuda_files_cpp)
        .objects(cuda_objects)
        .cargo_debug(true)
        .try_compile("MNNCuda")
        .change_context(Error)
        .attach_printable("Failed to compile the MNNCuda library (host sources and CUDA objects)")?;

    CxxOption::CUDA.define(&mut build);
    Ok(build)
}
2 changes: 1 addition & 1 deletion mnn-sys/vendor
Submodule vendor updated 263 files
Loading