diff --git a/crates/cust/src/memory/array.rs b/crates/cust/src/memory/array.rs index 36525b70..2d54f024 100644 --- a/crates/cust/src/memory/array.rs +++ b/crates/cust/src/memory/array.rs @@ -39,10 +39,10 @@ pub enum ArrayFormat { I16, /// Signed 32-bit integer I32, - /// Half-precision floating point number + /// Half-precision floating point number (f16) + F16, + /// Single-precision floating point number (f32) F32, - /// Single-precision floating point number - F64, } impl ArrayFormat { @@ -52,9 +52,8 @@ impl ArrayFormat { match self { U8 | I8 => 1, - U16 | I16 => 2, + U16 | I16 | F16 => 2, U32 | I32 | F32 => 4, - F64 => 8, } } } @@ -74,7 +73,6 @@ impl private::Sealed for i8 {} impl private::Sealed for i16 {} impl private::Sealed for i32 {} impl private::Sealed for f32 {} -impl private::Sealed for f64 {} impl ArrayPrimitive for u8 { fn array_format() -> ArrayFormat { @@ -118,12 +116,6 @@ impl ArrayPrimitive for f32 { } } -impl ArrayPrimitive for f64 { - fn array_format() -> ArrayFormat { - ArrayFormat::F64 - } -} - impl ArrayFormat { /// Creates ArrayFormat from the CUDA Driver API enum pub fn from_raw(raw: CUarray_format) -> Self { @@ -134,8 +126,8 @@ impl ArrayFormat { CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT8 => ArrayFormat::I8, CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT16 => ArrayFormat::I16, CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT32 => ArrayFormat::I32, - CUarray_format_enum::CU_AD_FORMAT_HALF => ArrayFormat::F32, - CUarray_format_enum::CU_AD_FORMAT_FLOAT => ArrayFormat::F64, + CUarray_format_enum::CU_AD_FORMAT_HALF => ArrayFormat::F16, + CUarray_format_enum::CU_AD_FORMAT_FLOAT => ArrayFormat::F32, // there are literally no docs on what nv12 is??? // it seems to be something with multiplanar arrays, needs some investigation CUarray_format_enum::CU_AD_FORMAT_NV12 => panic!("nv12 is not supported yet"), @@ -152,8 +144,8 @@ impl ArrayFormat { ArrayFormat::I8 => CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT8, ArrayFormat::I16 => CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT16, ArrayFormat::I32 => CUarray_format_enum::CU_AD_FORMAT_SIGNED_INT32, - ArrayFormat::F32 => CUarray_format_enum::CU_AD_FORMAT_HALF, - ArrayFormat::F64 => CUarray_format_enum::CU_AD_FORMAT_FLOAT, + ArrayFormat::F16 => CUarray_format_enum::CU_AD_FORMAT_HALF, + ArrayFormat::F32 => CUarray_format_enum::CU_AD_FORMAT_FLOAT, } } } @@ -921,11 +913,11 @@ mod test { fn descriptor_round_trip() { let _context = crate::quick_init().unwrap(); - let obj = ArrayObject::new([1, 2, 3], ArrayFormat::F64, 2).unwrap(); + let obj = ArrayObject::new([1, 2, 3], ArrayFormat::F32, 2).unwrap(); let descriptor = obj.descriptor().unwrap(); assert_eq!([1, 2, 3], descriptor.dims()); - assert_eq!(ArrayFormat::F64, descriptor.format()); + assert_eq!(ArrayFormat::F32, descriptor.format()); assert_eq!(2, descriptor.num_channels()); assert_eq!(ArrayObjectFlags::default(), descriptor.flags()); } @@ -934,7 +926,7 @@ mod test { fn allow_1d_arrays() { let _context = crate::quick_init().unwrap(); - let obj = ArrayObject::new([10, 0, 0], ArrayFormat::F64, 1).unwrap(); + let obj = ArrayObject::new([10, 0, 0], ArrayFormat::F32, 1).unwrap(); let descriptor = obj.descriptor().unwrap(); assert_eq!([10, 0, 0], descriptor.dims()); @@ -944,7 +936,7 @@ mod test { fn allow_2d_arrays() { let _context = crate::quick_init().unwrap(); - let obj = ArrayObject::new([10, 20, 0], ArrayFormat::F64, 1).unwrap(); + let obj = ArrayObject::new([10, 20, 0], ArrayFormat::F32, 1).unwrap(); let descriptor = obj.descriptor().unwrap(); assert_eq!([10, 20, 0], descriptor.dims()); @@ -954,7 +946,7 @@ mod test { fn allow_1d_layered_arrays() { let _context = crate::quick_init().unwrap(); - let obj = ArrayObject::new_layered([10, 0], 20, ArrayFormat::F64, 1).unwrap(); + let obj = ArrayObject::new_layered([10, 0], 20, ArrayFormat::F32, 1).unwrap(); let descriptor = obj.descriptor().unwrap(); assert_eq!([10, 0, 20], descriptor.dims()); @@ -965,7 +957,7 @@ mod test { fn allow_cubemaps() { let _context = crate::quick_init().unwrap(); - let obj = ArrayObject::new_cubemap(4, ArrayFormat::F64, 1).unwrap(); + let obj = ArrayObject::new_cubemap(4, ArrayFormat::F32, 1).unwrap(); let descriptor = obj.descriptor().unwrap(); assert_eq!([4, 4, 6], descriptor.dims()); @@ -976,7 +968,7 @@ mod test { fn allow_layered_cubemaps() { let _context = crate::quick_init().unwrap(); - let obj = ArrayObject::new_layered_cubemap(4, 4, ArrayFormat::F64, 1).unwrap(); + let obj = ArrayObject::new_layered_cubemap(4, 4, ArrayFormat::F32, 1).unwrap(); let descriptor = obj.descriptor().unwrap(); assert_eq!([4, 4, 24], descriptor.dims()); @@ -991,7 +983,7 @@ mod test { fn fail_on_zero_width_1d_array() { let _context = crate::quick_init().unwrap(); - let _ = ArrayObject::new_1d(0, ArrayFormat::F64, 1).unwrap(); + let _ = ArrayObject::new_1d(0, ArrayFormat::F32, 1).unwrap(); } #[test] @@ -999,7 +991,7 @@ mod test { fn fail_on_zero_size_widths() { let _context = crate::quick_init().unwrap(); - let _ = ArrayObject::new([0, 10, 20], ArrayFormat::F64, 1).unwrap(); + let _ = ArrayObject::new([0, 10, 20], ArrayFormat::F32, 1).unwrap(); } #[test] @@ -1007,7 +999,7 @@ mod test { fn fail_cubemaps_with_unmatching_width_height() { let _context = crate::quick_init().unwrap(); - let mut descriptor = ArrayDescriptor::from_dims_format([2, 3, 6], ArrayFormat::F64); + let mut descriptor = ArrayDescriptor::from_dims_format([2, 3, 6], ArrayFormat::F32); descriptor.set_flags(ArrayObjectFlags::CUBEMAP); let _ = ArrayObject::from_descriptor(&descriptor).unwrap(); @@ -1018,7 +1010,7 @@ mod test { fn fail_cubemaps_with_non_six_depth() { let _context = crate::quick_init().unwrap(); - let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 5], ArrayFormat::F64); + let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 5], ArrayFormat::F32); descriptor.set_flags(ArrayObjectFlags::CUBEMAP); let _ = ArrayObject::from_descriptor(&descriptor).unwrap(); @@ -1029,7 +1021,7 @@ mod test { fn fail_cubemaps_with_non_six_multiple_depth() { let _context = crate::quick_init().unwrap(); - let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 10], ArrayFormat::F64); + let mut descriptor = ArrayDescriptor::from_dims_format([4, 4, 10], ArrayFormat::F32); descriptor.set_flags(ArrayObjectFlags::LAYERED | ArrayObjectFlags::CUBEMAP); let _ = ArrayObject::from_descriptor(&descriptor).unwrap(); @@ -1040,7 +1032,7 @@ mod test { fn fail_with_depth_without_height() { let _context = crate::quick_init().unwrap(); - let _ = ArrayObject::new([10, 0, 20], ArrayFormat::F64, 1).unwrap(); + let _ = ArrayObject::new([10, 0, 20], ArrayFormat::F32, 1).unwrap(); } #[test] @@ -1048,6 +1040,6 @@ mod test { fn fails_on_invalid_num_channels() { let _context = crate::quick_init().unwrap(); - let _ = ArrayObject::new([1, 2, 3], ArrayFormat::F64, 3).unwrap(); + let _ = ArrayObject::new([1, 2, 3], ArrayFormat::F32, 3).unwrap(); } } diff --git a/crates/cust/src/texture.rs b/crates/cust/src/texture.rs index 60a918dd..47295726 100644 --- a/crates/cust/src/texture.rs +++ b/crates/cust/src/texture.rs @@ -238,11 +238,6 @@ impl ResourceViewFormat { format_impl!(num_channels, I16, I16x1, I16x2, I16x4); format_impl!(num_channels, I32, I32x1, I32x2, I32x4); format_impl!(num_channels, F32, F32x1, F32x2, F32x4); - assert_ne!( - format, - ArrayFormat::F64, - "CUDA Does not have 64 bit float textures, you can instead use int textures with 2 channels then cast the ints to a double in the kernel" - ); unreachable!() } }