diff --git a/src/buffer.rs b/src/buffer.rs index dedea35c582..7e767fac670 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -25,36 +25,55 @@ use std::ffi::{ c_ushort, c_void, }; use std::marker::PhantomData; +use std::ptr::NonNull; use std::{cell, mem, slice}; use std::{ffi::CStr, fmt::Debug}; +/// A typed form of [`PyUntypedBuffer`]. +#[repr(transparent)] +pub struct PyBuffer(PyUntypedBuffer, PhantomData<[T]>); + /// Allows access to the underlying buffer used by a python object such as `bytes`, `bytearray` or `array.array`. #[repr(transparent)] -pub struct PyBuffer(Box, PhantomData); +pub struct PyUntypedBuffer(Box); -// PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists. -// Accessing the buffer contents is protected using the GIL. -unsafe impl Send for PyBuffer {} -unsafe impl Sync for PyBuffer {} +// PyBuffer send & sync guarantees are upheld by Python. +unsafe impl Send for PyUntypedBuffer {} +unsafe impl Sync for PyUntypedBuffer {} impl Debug for PyBuffer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PyBuffer") - .field("buf", &self.0.buf) - .field("obj", &self.0.obj) - .field("len", &self.0.len) - .field("itemsize", &self.0.itemsize) - .field("readonly", &self.0.readonly) - .field("ndim", &self.0.ndim) - .field("format", &self.format()) - .field("shape", &self.shape()) - .field("strides", &self.strides()) - .field("suboffsets", &self.suboffsets()) - .field("internal", &self.0.internal) - .finish() + debug_buffer("PyBuffer", &self.0, f) + } +} + +impl Debug for PyUntypedBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_buffer("PyUntypedBuffer", self, f) } } +fn debug_buffer( + name: &str, + b: &PyUntypedBuffer, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + let raw = &b.0; + f.debug_struct(name) + .field("buf", &raw.buf) + .field("obj", &raw.obj) + .field("len", &raw.len) + .field("itemsize", &raw.itemsize) + .field("readonly", &raw.readonly) + .field("ndim", &raw.ndim) + .field("format", &b.format()) + .field("shape", &b.shape()) + .field("strides", &b.strides()) + .field("suboffsets", &b.suboffsets()) + .field("internal", &raw.internal) + .finish() +} + /// Represents the type of a Python buffer element. #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ElementType { @@ -193,42 +212,19 @@ impl FromPyObject<'_, '_> for PyBuffer { impl PyBuffer { /// Gets the underlying buffer from the specified python object. - pub fn get(obj: &Bound<'_, PyAny>) -> PyResult> { - // TODO: use nightly API Box::new_uninit() once our MSRV is 1.82 - let mut buf = Box::new(mem::MaybeUninit::uninit()); - let buf: Box = { - err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer(obj.as_ptr(), buf.as_mut_ptr(), ffi::PyBUF_FULL_RO) - })?; - // Safety: buf is initialized by PyObject_GetBuffer. - // TODO: use nightly API Box::assume_init() once our MSRV is 1.82 - unsafe { mem::transmute(buf) } - }; - // Create PyBuffer immediately so that if validation checks fail, the PyBuffer::drop code - // will call PyBuffer_Release (thus avoiding any leaks). - let buf = PyBuffer(buf, PhantomData); + pub fn get(obj: &Bound<'_, PyAny>) -> PyResult { + PyUntypedBuffer::get(obj)?.into_typed() + } - if buf.0.shape.is_null() { - Err(PyBufferError::new_err("shape is null")) - } else if buf.0.strides.is_null() { - Err(PyBufferError::new_err("strides is null")) - } else if mem::size_of::() != buf.item_size() || !T::is_compatible_format(buf.format()) { - Err(PyBufferError::new_err(format!( - "buffer contents are not compatible with {}", - std::any::type_name::() - ))) - } else if buf.0.buf.align_offset(mem::align_of::()) != 0 { - Err(PyBufferError::new_err(format!( - "buffer contents are insufficiently aligned for {}", - std::any::type_name::() - ))) - } else { - Ok(buf) - } + /// Releases the buffer object, freeing the reference to the Python object + /// which owns the buffer. + /// + /// This will automatically be called on drop. + #[inline] + pub fn release(self, py: Python<'_>) { + self.0.release(py) } -} -impl PyBuffer { /// Gets the pointer to the start of the buffer memory. /// /// Warning: the buffer memory can be mutated by other code (including @@ -240,54 +236,34 @@ impl PyBuffer { /// [this blog post]: https://alexgaynor.net/2022/oct/23/buffers-on-the-edge/ #[inline] pub fn buf_ptr(&self) -> *mut c_void { - self.0.buf + self.0.buf_ptr() } /// Gets a pointer to the specified item. /// /// If `indices.len() < self.dimensions()`, returns the start address of the sub-array at the specified dimension. + #[inline] pub fn get_ptr(&self, indices: &[usize]) -> *mut c_void { - let shape = &self.shape()[..indices.len()]; - for i in 0..indices.len() { - assert!(indices[i] < shape[i]); - } - unsafe { - ffi::PyBuffer_GetPointer( - #[cfg(Py_3_11)] - &*self.0, - #[cfg(not(Py_3_11))] - { - &*self.0 as *const ffi::Py_buffer as *mut ffi::Py_buffer - }, - #[cfg(Py_3_11)] - { - indices.as_ptr().cast() - }, - #[cfg(not(Py_3_11))] - { - indices.as_ptr() as *mut ffi::Py_ssize_t - }, - ) - } + self.0.get_ptr(indices) } /// Gets whether the underlying buffer is read-only. #[inline] pub fn readonly(&self) -> bool { - self.0.readonly != 0 + self.0.readonly() } /// Gets the size of a single element, in bytes. /// Important exception: when requesting an unformatted buffer, item_size still has the value #[inline] pub fn item_size(&self) -> usize { - self.0.itemsize as usize + self.0.item_size() } /// Gets the total number of items. #[inline] pub fn item_count(&self) -> usize { - (self.0.len as usize) / (self.0.itemsize as usize) + self.0.item_count() } /// `item_size() * item_count()`. @@ -295,7 +271,7 @@ impl PyBuffer { /// For non-contiguous arrays, it is the length that the logical structure would have if it were copied to a contiguous representation. #[inline] pub fn len_bytes(&self) -> usize { - self.0.len as usize + self.0.len_bytes() } /// Gets the number of dimensions. @@ -303,7 +279,7 @@ impl PyBuffer { /// May be 0 to indicate a single scalar value. #[inline] pub fn dimensions(&self) -> usize { - self.0.ndim as usize + self.0.dimensions() } /// Returns an array of length `dimensions`. `shape()[i]` is the length of the array in dimension number `i`. @@ -315,7 +291,7 @@ impl PyBuffer { /// However, dimensions of length 0 are possible and might need special attention. #[inline] pub fn shape(&self) -> &[usize] { - unsafe { slice::from_raw_parts(self.0.shape.cast(), self.0.ndim as usize) } + self.0.shape() } /// Returns an array that holds, for each dimension, the number of bytes to skip to get to the next element in the dimension. @@ -324,7 +300,7 @@ impl PyBuffer { /// but a consumer MUST be able to handle the case `strides[n] <= 0`. #[inline] pub fn strides(&self) -> &[isize] { - unsafe { slice::from_raw_parts(self.0.strides, self.0.ndim as usize) } + self.0.strides() } /// An array of length ndim. @@ -334,42 +310,27 @@ impl PyBuffer { /// If all suboffsets are negative (i.e. no de-referencing is needed), then this field must be NULL (the default value). #[inline] pub fn suboffsets(&self) -> Option<&[isize]> { - unsafe { - if self.0.suboffsets.is_null() { - None - } else { - Some(slice::from_raw_parts( - self.0.suboffsets, - self.0.ndim as usize, - )) - } - } + self.0.suboffsets() } - /// A NUL terminated string in struct module style syntax describing the contents of a single item. + /// A string in struct module style syntax describing the contents of a single item. #[inline] pub fn format(&self) -> &CStr { - if self.0.format.is_null() { - ffi::c_str!("B") - } else { - unsafe { CStr::from_ptr(self.0.format) } - } + self.0.format() } /// Gets whether the buffer is contiguous in C-style order (last index varies fastest when visiting items in order of memory address). #[inline] pub fn is_c_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(&*self.0, b'C' as std::ffi::c_char) != 0 } + self.0.is_c_contiguous() } /// Gets whether the buffer is contiguous in Fortran-style order (first index varies fastest when visiting items in order of memory address). #[inline] pub fn is_fortran_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(&*self.0, b'F' as std::ffi::c_char) != 0 } + self.0.is_fortran_contiguous() } -} -impl PyBuffer { /// Gets the buffer memory as a slice. /// /// This function succeeds if: @@ -383,7 +344,7 @@ impl PyBuffer { if self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.buf as *mut ReadOnlyCell, + self.0 .0.buf as *mut ReadOnlyCell, self.item_count(), )) } @@ -406,7 +367,7 @@ impl PyBuffer { if !self.readonly() && self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.buf as *mut cell::Cell, + self.0 .0.buf as *mut cell::Cell, self.item_count(), )) } @@ -428,7 +389,7 @@ impl PyBuffer { if mem::size_of::() == self.item_size() && self.is_fortran_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.buf as *mut ReadOnlyCell, + self.0 .0.buf as *mut ReadOnlyCell, self.item_count(), )) } @@ -451,7 +412,7 @@ impl PyBuffer { if !self.readonly() && self.is_fortran_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.buf as *mut cell::Cell, + self.0 .0.buf as *mut cell::Cell, self.item_count(), )) } @@ -499,12 +460,12 @@ impl PyBuffer { ffi::PyBuffer_ToContiguous( target.as_mut_ptr().cast(), #[cfg(Py_3_11)] - &*self.0, + &*self.0 .0, #[cfg(not(Py_3_11))] { &*self.0 as *const ffi::Py_buffer as *mut ffi::Py_buffer }, - self.0.len, + self.0 .0.len, fort as std::ffi::c_char, ) }) @@ -536,12 +497,12 @@ impl PyBuffer { ffi::PyBuffer_ToContiguous( vec.as_ptr() as *mut c_void, #[cfg(Py_3_11)] - &*self.0, + &*self.0 .0, #[cfg(not(Py_3_11))] { &*self.0 as *const ffi::Py_buffer as *mut ffi::Py_buffer }, - self.0.len, + self.0 .0.len, fort as std::ffi::c_char, ) })?; @@ -592,7 +553,7 @@ impl PyBuffer { err::error_on_minusone(py, unsafe { ffi::PyBuffer_FromContiguous( #[cfg(Py_3_11)] - &*self.0, + &*self.0 .0, #[cfg(not(Py_3_11))] { &*self.0 as *const ffi::Py_buffer as *mut ffi::Py_buffer @@ -605,25 +566,200 @@ impl PyBuffer { { source.as_ptr() as *mut c_void }, - self.0.len, + self.0 .0.len, fort as std::ffi::c_char, ) }) } +} - /// Releases the buffer object, freeing the reference to the Python object - /// which owns the buffer. - /// - /// This will automatically be called on drop. +impl std::ops::Deref for PyBuffer { + type Target = PyUntypedBuffer; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PyUntypedBuffer { + /// See [`PyBuffer::get()`]. + pub fn get(obj: &Bound<'_, PyAny>) -> PyResult { + let mut buf = mem::MaybeUninit::uninit(); + let buf: Box = { + err::error_on_minusone(obj.py(), unsafe { + // TODO: add PyBufferRequest type which allows for controlling the buffer request? + // - require writable ? + // - require contiguous ? + // - is there ever a case that we need to handle producers which can't fill strides or shape? + ffi::PyObject_GetBuffer(obj.as_ptr(), buf.as_mut_ptr(), ffi::PyBUF_FULL_RO) + })?; + // SAFETY: PyObject_GetBuffer has initialized the buffer on success + Box::new(unsafe { buf.assume_init() }) + }; + // Create PyBuffer immediately so that if validation checks fail, the PyBuffer::drop code + // will call PyBuffer_Release (thus avoiding any leaks). + let buf = Self(buf); + let raw = &buf.0; + + if raw.shape.is_null() { + Err(PyBufferError::new_err("shape is null")) + } else if raw.strides.is_null() { + Err(PyBufferError::new_err("strides is null")) + } else { + Ok(buf) + } + } + + /// Returns a `[PyBuffer]` instance if the buffer can be interpreted as containing elements of type `T`. + pub fn into_typed(self) -> PyResult> { + self.ensure_compatible_with::()?; + Ok(PyBuffer(self, PhantomData)) + } + + /// Non-owning equivalent of [`into_typed()`][Self::into_typed]. + pub fn as_typed(&self) -> PyResult<&PyBuffer> { + self.ensure_compatible_with::()?; + // SAFETY: PyBuffer is repr(transparent) around PyUntypedBuffer + Ok(unsafe { NonNull::from(self).cast::>().as_ref() }) + } + + fn ensure_compatible_with(&self) -> PyResult<()> { + if mem::size_of::() != self.item_size() || !T::is_compatible_format(self.format()) { + Err(PyBufferError::new_err(format!( + "buffer contents are not compatible with {}", + std::any::type_name::() + ))) + } else if self.0.buf.align_offset(mem::align_of::()) != 0 { + Err(PyBufferError::new_err(format!( + "buffer contents are insufficiently aligned for {}", + std::any::type_name::() + ))) + } else { + Ok(()) + } + } + + /// See [`PyBuffer::release()`]. pub fn release(self, _py: Python<'_>) { // SAFETY: Self is `repr(transparent)` around a Box let mut inner: Box = unsafe { std::mem::transmute(self) }; // SAFETY: the ffi::Py_buffer structure is valid until this release call unsafe { ffi::PyBuffer_Release(&mut *inner) }; } + + /// See [`PyBuffer::buf_ptr()`]. + #[inline] + pub fn buf_ptr(&self) -> *mut c_void { + self.0.buf + } + + /// See [`PyBuffer::get_ptr()`]. + pub fn get_ptr(&self, indices: &[usize]) -> *mut c_void { + let shape = &self.shape()[..indices.len()]; + for i in 0..indices.len() { + assert!(indices[i] < shape[i]); + } + unsafe { + ffi::PyBuffer_GetPointer( + #[cfg(Py_3_11)] + &*self.0, + #[cfg(not(Py_3_11))] + { + &*self.0 as *const ffi::Py_buffer as *mut ffi::Py_buffer + }, + #[cfg(Py_3_11)] + { + indices.as_ptr().cast() + }, + #[cfg(not(Py_3_11))] + { + indices.as_ptr() as *mut ffi::Py_ssize_t + }, + ) + } + } + + /// See [`PyBuffer::readonly()`]. + #[inline] + pub fn readonly(&self) -> bool { + self.0.readonly != 0 + } + + /// See [`PyBuffer::item_size()`]. + #[inline] + pub fn item_size(&self) -> usize { + self.0.itemsize as usize + } + + /// See [`PyBuffer::item_count()`]. + #[inline] + pub fn item_count(&self) -> usize { + (self.0.len as usize) / (self.0.itemsize as usize) + } + + /// See [`PyBuffer::len_bytes()`]. + #[inline] + pub fn len_bytes(&self) -> usize { + self.0.len as usize + } + + /// See [`PyBuffer::dimensions()`]. + #[inline] + pub fn dimensions(&self) -> usize { + self.0.ndim as usize + } + + /// See [`PyBuffer::shape()`]. + #[inline] + pub fn shape(&self) -> &[usize] { + unsafe { slice::from_raw_parts(self.0.shape.cast(), self.0.ndim as usize) } + } + + /// See [`PyBuffer::strides()`]. + #[inline] + pub fn strides(&self) -> &[isize] { + unsafe { slice::from_raw_parts(self.0.strides, self.0.ndim as usize) } + } + + /// See [`PyBuffer::suboffsets()`]. + #[inline] + pub fn suboffsets(&self) -> Option<&[isize]> { + unsafe { + if self.0.suboffsets.is_null() { + None + } else { + Some(slice::from_raw_parts( + self.0.suboffsets, + self.0.ndim as usize, + )) + } + } + } + + /// See [`PyBuffer::format()`]. + #[inline] + pub fn format(&self) -> &CStr { + if self.0.format.is_null() { + ffi::c_str!("B") + } else { + unsafe { CStr::from_ptr(self.0.format) } + } + } + + /// See [`PyBuffer::is_c_contiguous()`]. + #[inline] + pub fn is_c_contiguous(&self) -> bool { + unsafe { ffi::PyBuffer_IsContiguous(&*self.0, b'C' as std::ffi::c_char) != 0 } + } + + /// See [`PyBuffer::is_fortran_contiguous()`]. + #[inline] + pub fn is_fortran_contiguous(&self) -> bool { + unsafe { ffi::PyBuffer_IsContiguous(&*self.0, b'F' as std::ffi::c_char) != 0 } + } } -impl Drop for PyBuffer { +impl Drop for PyUntypedBuffer { fn drop(&mut self) { fn inner(buf: &mut Box) { if Python::try_attach(|_| unsafe { ffi::PyBuffer_Release(buf.as_mut()) }).is_none() @@ -711,7 +847,9 @@ mod tests { "ndim: 1, format: \"B\", shape: [5], ", "strides: [1], suboffsets: None, internal: {:?} }}", ), - buffer.0.buf, buffer.0.obj, buffer.0.internal + buffer.buf_ptr(), + buffer.0 .0.obj, + buffer.0 .0.internal ); let debug_repr = format!("{:?}", buffer); assert_eq!(debug_repr, expected); @@ -941,4 +1079,22 @@ mod tests { assert_eq!(buffer.to_fortran_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]); }); } + + #[test] + fn test_untyped_buffer() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + let untyped = PyUntypedBuffer::get(&bytes).unwrap(); + assert_eq!(untyped.dimensions(), 1); + assert_eq!(untyped.item_count(), 5); + assert_eq!(untyped.format().to_str().unwrap(), "B"); + assert_eq!(untyped.shape(), [5]); + + let typed: &PyBuffer = untyped.as_typed().unwrap(); + assert_eq!(typed.dimensions(), 1); + assert_eq!(typed.item_count(), 5); + assert_eq!(typed.format().to_str().unwrap(), "B"); + assert_eq!(typed.shape(), [5]); + }); + } }