Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZLUDA v3.8.7 #66

Merged
merged 3 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

131 changes: 66 additions & 65 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,65 +1,66 @@
[workspace]

resolver = "2"

# Remember to also update the project's Cargo.toml
# if it's a top-level project
members = [
"atiadlxx-sys",
"comgr",
"cuda_base",
"cuda_types",
"detours-sys",
"ext/llvm-sys.rs",
"hip_common",
"hip_runtime-sys",
"hipblaslt-sys",
"hipfft-sys",
"hiprt-sys",
"miopen-sys",
"offline_compiler",
"optix_base",
"optix_dump",
"process_address_table",
"ptx",
"rocblas-sys",
"rocm_smi-sys",
"rocsparse-sys",
"xtask",
"zluda",
"zluda_api",
"zluda_blas",
"zluda_blaslt",
"zluda_ccl",
"zluda_dark_api",
"zluda_dnn",
"zluda_dump",
"zluda_fft",
"zluda_inject",
"zluda_lib",
"zluda_llvm",
"zluda_ml",
"zluda_redirect",
"zluda_rt",
"zluda_rtc",
"zluda_runtime",
"zluda_sparse",
]

# Cargo does not support OS-specific or profile-specific
# targets. We keep list here to bare minimum and rely on xtask
default-members = [
"zluda_lib",
"zluda_ml",
"zluda_inject",
"zluda_redirect"
]

[profile.dev.package.blake3]
opt-level = 3

[profile.dev.package.lz4-sys]
opt-level = 3

[profile.dev.package.xtask]
opt-level = 2
[workspace]

resolver = "2"

# Remember to also update the project's Cargo.toml
# if it's a top-level project
members = [
"atiadlxx-sys",
"comgr",
"cuda_base",
"cuda_types",
"detours-sys",
"ext/llvm-sys.rs",
"hip_common",
"hip_runtime-sys",
"hipblaslt-sys",
"hipfft-sys",
"hiprt-sys",
"miopen-sys",
"offline_compiler",
"optix_base",
"optix_dump",
"process_address_table",
"ptx",
"rocblas-sys",
"rocm_smi-sys",
"rocsparse-sys",
"xtask",
"zluda",
"zluda_api",
"zluda_blas",
"zluda_blaslt",
"zluda_ccl",
"zluda_dark_api",
"zluda_dnn",
"zluda_dump",
"zluda_fft",
"zluda_fftw",
"zluda_inject",
"zluda_lib",
"zluda_llvm",
"zluda_ml",
"zluda_redirect",
"zluda_rt",
"zluda_rtc",
"zluda_runtime",
"zluda_sparse",
]

# Cargo does not support OS-specific or profile-specific
# targets. We keep list here to bare minimum and rely on xtask
default-members = [
"zluda_lib",
"zluda_ml",
"zluda_inject",
"zluda_redirect"
]

[profile.dev.package.blake3]
opt-level = 3

[profile.dev.package.lz4-sys]
opt-level = 3

[profile.dev.package.xtask]
opt-level = 2
1 change: 1 addition & 0 deletions hipblaslt-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ impl hipblasOperation_t {
impl hipblasOperation_t {
pub const HIPBLAS_OP_C: hipblasOperation_t = hipblasOperation_t(113);
}
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct hipblasOperation_t(pub ::std::os::raw::c_int);
107 changes: 12 additions & 95 deletions zluda_blas/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![allow(warnings)]
#[allow(warnings)]
mod common;
#[allow(warnings)]
mod cublas;
#[allow(warnings)]
mod cublasxt;

pub use common::*;
Expand All @@ -13,7 +15,7 @@ use rocsolver_sys::{
rocsolver_cgetrf_batched, rocsolver_cgetri_outofplace_batched, rocsolver_dgetrs_batched,
rocsolver_sgetrs_batched, rocsolver_zgetrf_batched, rocsolver_zgetri_outofplace_batched,
};
use std::{mem, ptr};
use std::ptr;

#[cfg(debug_assertions)]
pub(crate) fn unsupported() -> cublasStatus_t {
Expand Down Expand Up @@ -223,61 +225,20 @@ unsafe fn set_stream(handle: cublasHandle_t, stream_id: cudaStream_t) -> cublasS
) -> CUresult>(b"cuGetExportTable\0")
.unwrap();
let mut export_table = ptr::null();
(cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID);
assert_eq!(
(cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID),
CUresult::CUDA_SUCCESS
);
let zluda_ext = zluda_dark_api::ZludaExt::new(export_table);
let stream: Result<_, _> = zluda_ext.get_hip_stream(stream_id as _).into();
to_cuda(rocblas_set_stream(handle as _, stream.unwrap() as _))
}

fn set_math_mode(handle: cublasHandle_t, mode: cublasMath_t) -> cublasStatus_t {
fn set_math_mode(_handle: cublasHandle_t, _mode: cublasMath_t) -> cublasStatus_t {
// llama.cpp uses CUBLAS_TF32_TENSOR_OP_MATH
cublasStatus_t::CUBLAS_STATUS_SUCCESS
}

unsafe fn sgemm(
transa: std::ffi::c_char,
transb: std::ffi::c_char,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: *const f32,
lda: i32,
b: *const f32,
ldb: i32,
beta: f32,
c: *mut f32,
ldc: i32,
) -> cublasStatus_t {
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_sgemm(
handle.cast(),
transa,
transb,
m,
n,
k,
&alpha,
a,
lda,
b,
ldb,
&beta,
c,
ldc,
));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
to_cuda(rocblas_destroy_handle(*handle))
}

unsafe fn sgemm_v2(
handle: cublasHandle_t,
transa: cublasOperation_t,
Expand Down Expand Up @@ -495,7 +456,7 @@ unsafe fn gemm_ex(
))
}

fn to_algo(algo: cublasGemmAlgo_t) -> rocblas_gemm_algo_ {
fn to_algo(_algo: cublasGemmAlgo_t) -> rocblas_gemm_algo_ {
// only option
rocblas_gemm_algo::rocblas_gemm_algo_standard
}
Expand Down Expand Up @@ -807,7 +768,7 @@ unsafe fn sgetrs_batched(
dev_ipiv: *const i32,
b: *const *mut f32,
ldb: i32,
info: *mut i32,
_info: *mut i32,
batch_size: i32,
) -> cublasStatus_t {
let trans = op_from_cuda_for_solver(trans);
Expand Down Expand Up @@ -837,7 +798,7 @@ unsafe fn dgetrs_batched(
dev_ipiv: *const i32,
b: *const *mut f64,
ldb: i32,
info: *mut i32,
_info: *mut i32,
batch_size: i32,
) -> cublasStatus_t {
let trans = op_from_cuda_for_solver(trans);
Expand Down Expand Up @@ -1048,50 +1009,6 @@ unsafe fn dger(
))
}

unsafe fn dgemm(
transa: std::ffi::c_char,
transb: std::ffi::c_char,
m: i32,
n: i32,
k: i32,
alpha: f64,
a: *const f64,
lda: i32,
b: *const f64,
ldb: i32,
beta: f64,
c: *mut f64,
ldc: i32,
) -> cublasStatus_t {
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_dgemm(
handle.cast(),
transa,
transb,
m,
n,
k,
&alpha,
a,
lda,
b,
ldb,
&beta,
c,
ldc,
));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
to_cuda(rocblas_destroy_handle(*handle))
}

unsafe fn dgemm_v2(
handle: *mut cublasContext,
transa: cublasOperation_t,
Expand Down
4 changes: 2 additions & 2 deletions zluda_fft/src/cufft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,15 @@ pub unsafe extern "system" fn cufftSetWorkArea(
plan: cufftHandle,
workArea: *mut ::std::os::raw::c_void,
) -> cufftResult {
crate::unsupported()
crate::set_work_area(plan, workArea)
}

#[no_mangle]
pub unsafe extern "system" fn cufftSetAutoAllocation(
plan: cufftHandle,
autoAllocate: ::std::os::raw::c_int,
) -> cufftResult {
crate::unsupported()
crate::set_auto_allocation(plan, autoAllocate)
}

#[no_mangle]
Expand Down
19 changes: 17 additions & 2 deletions zluda_fft/src/cufftxt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,22 @@ pub unsafe extern "system" fn cufftXtMakePlanMany(
workSize: *mut usize,
executiontype: cudaDataType,
) -> cufftResult {
crate::unsupported()
crate::xt_make_plan_many(
plan,
rank,
n,
inembed,
istride,
idist,
inputtype,
onembed,
ostride,
odist,
outputtype,
batch,
workSize,
executiontype,
)
}

#[no_mangle]
Expand Down Expand Up @@ -406,7 +421,7 @@ pub unsafe extern "system" fn cufftXtExec(
output: *mut ::std::os::raw::c_void,
direction: ::std::os::raw::c_int,
) -> cufftResult {
crate::unsupported()
crate::xt_exec(plan, input, output, direction)
}

#[no_mangle]
Expand Down
Loading