Skip to content

Commit 90b14fb

Browse files
authored
Merge pull request #772 from andrewwhitehead/buffer-refs
Buffer protocol updates to support object references, custom release method
2 parents 45d892a + aae57e7 commit 90b14fb

File tree

2 files changed

+123
-28
lines changed

2 files changed

+123
-28
lines changed

src/class/buffer.rs

+45-7
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,27 @@
77
use crate::callback::UnitCallbackConverter;
88
use crate::err::PyResult;
99
use crate::ffi;
10-
use crate::type_object::PyTypeInfo;
10+
use crate::{PyClass, PyClassShell};
1111
use std::os::raw::c_int;
1212

1313
/// Buffer protocol interface
1414
///
1515
/// For more information check [buffer protocol](https://docs.python.org/3/c-api/buffer.html)
1616
/// c-api
1717
#[allow(unused_variables)]
18-
pub trait PyBufferProtocol<'p>: PyTypeInfo {
19-
fn bf_getbuffer(&'p self, view: *mut ffi::Py_buffer, flags: c_int) -> Self::Result
18+
pub trait PyBufferProtocol<'p>: PyClass {
19+
fn bf_getbuffer(
20+
slf: &mut PyClassShell<Self>,
21+
view: *mut ffi::Py_buffer,
22+
flags: c_int,
23+
) -> Self::Result
2024
where
2125
Self: PyBufferGetBufferProtocol<'p>,
2226
{
2327
unimplemented!()
2428
}
2529

26-
fn bf_releasebuffer(&'p self, view: *mut ffi::Py_buffer) -> Self::Result
30+
fn bf_releasebuffer(slf: &mut PyClassShell<Self>, view: *mut ffi::Py_buffer) -> Self::Result
2731
where
2832
Self: PyBufferReleaseBufferProtocol<'p>,
2933
{
@@ -59,7 +63,7 @@ where
5963
fn tp_as_buffer() -> Option<ffi::PyBufferProcs> {
6064
Some(ffi::PyBufferProcs {
6165
bf_getbuffer: Self::cb_bf_getbuffer(),
62-
bf_releasebuffer: None,
66+
bf_releasebuffer: Self::cb_bf_releasebuffer(),
6367
..ffi::PyBufferProcs_INIT
6468
})
6569
}
@@ -94,11 +98,45 @@ where
9498
{
9599
let py = crate::Python::assume_gil_acquired();
96100
let _pool = crate::GILPool::new(py);
97-
let slf = py.mut_from_borrowed_ptr::<T>(slf);
101+
let slf = &mut *(slf as *mut PyClassShell<T>);
98102

99-
let result = slf.bf_getbuffer(arg1, arg2).into();
103+
let result = T::bf_getbuffer(slf, arg1, arg2).into();
100104
crate::callback::cb_convert(UnitCallbackConverter, py, result)
101105
}
102106
Some(wrap::<T>)
103107
}
104108
}
109+
110+
trait PyBufferReleaseBufferProtocolImpl {
111+
fn cb_bf_releasebuffer() -> Option<ffi::releasebufferproc>;
112+
}
113+
114+
impl<'p, T> PyBufferReleaseBufferProtocolImpl for T
115+
where
116+
T: PyBufferProtocol<'p>,
117+
{
118+
default fn cb_bf_releasebuffer() -> Option<ffi::releasebufferproc> {
119+
None
120+
}
121+
}
122+
123+
impl<T> PyBufferReleaseBufferProtocolImpl for T
124+
where
125+
T: for<'p> PyBufferReleaseBufferProtocol<'p>,
126+
{
127+
#[inline]
128+
fn cb_bf_releasebuffer() -> Option<ffi::releasebufferproc> {
129+
unsafe extern "C" fn wrap<T>(slf: *mut ffi::PyObject, arg1: *mut ffi::Py_buffer)
130+
where
131+
T: for<'p> PyBufferReleaseBufferProtocol<'p>,
132+
{
133+
let py = crate::Python::assume_gil_acquired();
134+
let _pool = crate::GILPool::new(py);
135+
let slf = &mut *(slf as *mut PyClassShell<T>);
136+
137+
let result = T::bf_releasebuffer(slf, arg1).into();
138+
crate::callback::cb_convert(UnitCallbackConverter, py, result);
139+
}
140+
Some(wrap::<T>)
141+
}
142+
}

tests/test_buffer_protocol.rs

+78-21
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,43 @@
1+
use pyo3::buffer::PyBuffer;
12
use pyo3::class::PyBufferProtocol;
23
use pyo3::exceptions::BufferError;
34
use pyo3::ffi;
45
use pyo3::prelude::*;
56
use pyo3::types::IntoPyDict;
7+
use pyo3::{AsPyPointer, PyClassShell};
68
use std::ffi::CStr;
79
use std::os::raw::{c_int, c_void};
810
use std::ptr;
11+
use std::sync::atomic::{AtomicBool, Ordering};
12+
use std::sync::Arc;
913

1014
#[pyclass]
11-
struct TestClass {
15+
struct TestBufferClass {
1216
vec: Vec<u8>,
17+
drop_called: Arc<AtomicBool>,
1318
}
1419

1520
#[pyproto]
16-
impl PyBufferProtocol for TestClass {
17-
fn bf_getbuffer(&self, view: *mut ffi::Py_buffer, flags: c_int) -> PyResult<()> {
21+
impl PyBufferProtocol for TestBufferClass {
22+
fn bf_getbuffer(
23+
slf: &mut PyClassShell<Self>,
24+
view: *mut ffi::Py_buffer,
25+
flags: c_int,
26+
) -> PyResult<()> {
1827
if view.is_null() {
1928
return Err(BufferError::py_err("View is null"));
2029
}
2130

22-
unsafe {
23-
(*view).obj = ptr::null_mut();
24-
}
25-
2631
if (flags & ffi::PyBUF_WRITABLE) == ffi::PyBUF_WRITABLE {
2732
return Err(BufferError::py_err("Object is not writable"));
2833
}
2934

30-
let bytes = &self.vec;
35+
unsafe {
36+
(*view).obj = slf.as_ptr();
37+
ffi::Py_INCREF((*view).obj);
38+
}
39+
40+
let bytes = &slf.vec;
3141

3242
unsafe {
3343
(*view).buf = bytes.as_ptr() as *mut c_void;
@@ -58,21 +68,68 @@ impl PyBufferProtocol for TestClass {
5868

5969
Ok(())
6070
}
71+
72+
fn bf_releasebuffer(_slf: &mut PyClassShell<Self>, _view: *mut ffi::Py_buffer) -> PyResult<()> {
73+
Ok(())
74+
}
75+
}
76+
77+
impl Drop for TestBufferClass {
78+
fn drop(&mut self) {
79+
print!("dropped");
80+
self.drop_called.store(true, Ordering::Relaxed);
81+
}
6182
}
6283

6384
#[test]
6485
fn test_buffer() {
65-
let gil = Python::acquire_gil();
66-
let py = gil.python();
67-
68-
let t = Py::new(
69-
py,
70-
TestClass {
71-
vec: vec![b' ', b'2', b'3'],
72-
},
73-
)
74-
.unwrap();
75-
76-
let d = [("ob", t)].into_py_dict(py);
77-
py.run("assert bytes(ob) == b' 23'", None, Some(d)).unwrap();
86+
let drop_called = Arc::new(AtomicBool::new(false));
87+
88+
{
89+
let gil = Python::acquire_gil();
90+
let py = gil.python();
91+
let instance = Py::new(
92+
py,
93+
TestBufferClass {
94+
vec: vec![b' ', b'2', b'3'],
95+
drop_called: drop_called.clone(),
96+
},
97+
)
98+
.unwrap();
99+
let env = [("ob", instance)].into_py_dict(py);
100+
py.run("assert bytes(ob) == b' 23'", None, Some(env))
101+
.unwrap();
102+
}
103+
104+
assert!(drop_called.load(Ordering::Relaxed));
105+
}
106+
107+
#[test]
108+
fn test_buffer_referenced() {
109+
let drop_called = Arc::new(AtomicBool::new(false));
110+
111+
let buf = {
112+
let input = vec![b' ', b'2', b'3'];
113+
let gil = Python::acquire_gil();
114+
let py = gil.python();
115+
let instance: PyObject = TestBufferClass {
116+
vec: input.clone(),
117+
drop_called: drop_called.clone(),
118+
}
119+
.into_py(py);
120+
121+
let buf = PyBuffer::get(py, instance.as_ref(py)).unwrap();
122+
assert_eq!(buf.to_vec::<u8>(py).unwrap(), input);
123+
drop(instance);
124+
buf
125+
};
126+
127+
assert!(!drop_called.load(Ordering::Relaxed));
128+
129+
{
130+
let _py = Python::acquire_gil().python();
131+
drop(buf);
132+
}
133+
134+
assert!(drop_called.load(Ordering::Relaxed));
78135
}

0 commit comments

Comments
 (0)