Skip to content

Commit 3ba4a05

Browse files
committed
Add support for arbitrary arrays
1 parent 4713b46 commit 3ba4a05

File tree

9 files changed

+306
-108
lines changed

9 files changed

+306
-108
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1212
- Extend `hashbrown` optional dependency supported versions to include 0.11. [#1496](https://github.com/PyO3/pyo3/pull/1496)
1313

1414
### Added
15+
- Add conversion for `[T; N]` for all `N` on Rust 1.51 and up. [#1128](https://github.com/PyO3/pyo3/pull/1128)
1516
- Add conversions between `OsStr`/`OsString`/`Path`/`PathBuf` and Python strings. [#1379](https://github.com/PyO3/pyo3/pull/1379)
1617
- Add `#[pyo3(from_py_with = "...")]` attribute for function arguments and struct fields to override the default from-Python conversion. [#1411](https://github.com/PyO3/pyo3/pull/1411)
1718
- Add FFI definition `PyCFunction_CheckExact` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425)

build.rs

+19
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,23 @@ fn abi3_without_interpreter() -> Result<()> {
850850
Ok(())
851851
}
852852

853+
fn rustc_minor_version() -> Option<u32> {
854+
let rustc = env::var_os("RUSTC")?;
855+
let output = Command::new(rustc).arg("--version").output().ok()?;
856+
let version = core::str::from_utf8(&output.stdout).ok()?;
857+
let mut pieces = version.split('.');
858+
if pieces.next() != Some("rustc 1") {
859+
return None;
860+
}
861+
pieces.next()?.parse().ok()
862+
}
863+
864+
fn manage_min_const_generics() {
865+
if rustc_minor_version().unwrap_or(0) >= 51 {
866+
println!("cargo:rustc-cfg=min_const_generics");
867+
}
868+
}
869+
853870
fn main_impl() -> Result<()> {
854871
// If PYO3_NO_PYTHON is set with abi3, we can build PyO3 without calling Python.
855872
// We only check for the abi3-py3{ABI3_MAX_MINOR} because lower versions depend on it.
@@ -916,6 +933,8 @@ fn main_impl() -> Result<()> {
916933
println!("cargo:rustc-cfg=__pyo3_ci");
917934
}
918935

936+
manage_min_const_generics();
937+
919938
Ok(())
920939
}
921940

src/buffer.rs

+3-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
//! `PyBuffer` implementation
2020
use crate::err::{self, PyResult};
21+
use crate::utils::invalid_sequence_length;
2122
use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python};
2223
use std::ffi::CStr;
2324
use std::marker::PhantomData;
@@ -441,9 +442,7 @@ impl<T: Element> PyBuffer<T> {
441442

442443
fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> {
443444
if mem::size_of_val(target) != self.len_bytes() {
444-
return Err(exceptions::PyBufferError::new_err(
445-
"Slice length does not match buffer length.",
446-
));
445+
return Err(invalid_sequence_length(self.item_count(), target.len()));
447446
}
448447
unsafe {
449448
err::error_on_minusone(
@@ -528,9 +527,7 @@ impl<T: Element> PyBuffer<T> {
528527
return buffer_readonly_error();
529528
}
530529
if mem::size_of_val(source) != self.len_bytes() {
531-
return Err(exceptions::PyBufferError::new_err(
532-
"Slice length does not match buffer length.",
533-
));
530+
return Err(invalid_sequence_length(source.len(), self.item_count()));
534531
}
535532
unsafe {
536533
err::error_on_minusone(

src/conversions/array.rs

+273
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
use crate::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject};
2+
3+
#[cfg(not(min_const_generics))]
4+
macro_rules! array_impls {
5+
($($N:expr),+) => {
6+
$(
7+
impl<'a, T> FromPyObject<'a> for [T; $N]
8+
where
9+
T: Copy + Default + FromPyObject<'a>,
10+
{
11+
#[cfg(not(feature = "nightly"))]
12+
fn extract(obj: &'a PyAny) -> PyResult<Self> {
13+
let mut array = [T::default(); $N];
14+
_extract_sequence_into_slice(obj, &mut array)?;
15+
Ok(array)
16+
}
17+
18+
#[cfg(feature = "nightly")]
19+
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
20+
let mut array = [T::default(); $N];
21+
_extract_sequence_into_slice(obj, &mut array)?;
22+
Ok(array)
23+
}
24+
}
25+
26+
#[cfg(feature = "nightly")]
27+
impl<'source, T> FromPyObject<'source> for [T; $N]
28+
where
29+
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
30+
{
31+
fn extract(obj: &'source PyAny) -> PyResult<Self> {
32+
let mut array = [T::default(); $N];
33+
// first try buffer protocol
34+
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
35+
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
36+
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
37+
buf.release(obj.py());
38+
return Ok(array);
39+
}
40+
buf.release(obj.py());
41+
}
42+
}
43+
// fall back to sequence protocol
44+
_extract_sequence_into_slice(obj, &mut array)?;
45+
Ok(array)
46+
}
47+
}
48+
)+
49+
}
50+
}
51+
52+
#[cfg(not(min_const_generics))]
53+
array_impls!(
54+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
55+
26, 27, 28, 29, 30, 31, 32
56+
);
57+
58+
#[cfg(min_const_generics)]
59+
impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
60+
where
61+
T: FromPyObject<'a>,
62+
{
63+
#[cfg(not(feature = "nightly"))]
64+
fn extract(obj: &'a PyAny) -> PyResult<Self> {
65+
create_array_from_obj(obj)
66+
}
67+
68+
#[cfg(feature = "nightly")]
69+
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
70+
create_array_from_obj(obj)
71+
}
72+
}
73+
74+
#[cfg(not(min_const_generics))]
75+
macro_rules! array_impls {
76+
($($N:expr),+) => {
77+
$(
78+
impl<T> IntoPy<PyObject> for [T; $N]
79+
where
80+
T: ToPyObject
81+
{
82+
fn into_py(self, py: Python) -> PyObject {
83+
self.as_ref().to_object(py)
84+
}
85+
}
86+
)+
87+
}
88+
}
89+
90+
#[cfg(not(min_const_generics))]
91+
array_impls!(
92+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
93+
26, 27, 28, 29, 30, 31, 32
94+
);
95+
96+
#[cfg(min_const_generics)]
97+
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
98+
where
99+
T: ToPyObject,
100+
{
101+
fn into_py(self, py: Python) -> PyObject {
102+
self.as_ref().to_object(py)
103+
}
104+
}
105+
106+
#[cfg(all(min_const_generics, feature = "nightly"))]
107+
impl<'source, T, const N: usize> FromPyObject<'source> for [T; N]
108+
where
109+
for<'a> T: FromPyObject<'a> + crate::buffer::Element,
110+
{
111+
fn extract(obj: &'source PyAny) -> PyResult<Self> {
112+
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
113+
// first try buffer protocol
114+
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
115+
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
116+
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
117+
buf.release(obj.py());
118+
// SAFETY: The array should be fully filled by `copy_to_slice`
119+
return Ok(unsafe { array.assume_init() });
120+
}
121+
buf.release(obj.py());
122+
}
123+
}
124+
// fall back to sequence protocol
125+
_extract_sequence_into_slice(obj, &mut array)?;
126+
// SAFETY: The array should be fully filled by `_extract_sequence_into_slice`
127+
Ok(unsafe { array.assume_init() })
128+
}
129+
}
130+
131+
#[cfg(min_const_generics)]
132+
fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]>
133+
where
134+
T: FromPyObject<'s>,
135+
{
136+
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
137+
let expected_len = seq.len()? as usize;
138+
let mut counter = 0;
139+
try_create_array(&mut counter, |idx| {
140+
seq.get_item(idx as isize)
141+
.map_err(|_| crate::utils::invalid_sequence_length(expected_len, idx + 1))?
142+
.extract::<T>()
143+
})
144+
}
145+
146+
fn _extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
147+
where
148+
T: FromPyObject<'s>,
149+
{
150+
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
151+
let expected_len = seq.len()? as usize;
152+
if expected_len != slice.len() {
153+
return Err(crate::utils::invalid_sequence_length(
154+
expected_len,
155+
slice.len(),
156+
));
157+
}
158+
for (value, item) in slice.iter_mut().zip(seq.iter()?) {
159+
*value = item?.extract::<T>()?;
160+
}
161+
Ok(())
162+
}
163+
164+
#[cfg(min_const_generics)]
165+
fn try_create_array<E, F, T, const N: usize>(counter: &mut usize, mut cb: F) -> Result<[T; N], E>
166+
where
167+
F: FnMut(usize) -> Result<T, E>,
168+
{
169+
// Helper to safely create arrays since the standard library doesn't
170+
// provide one yet. Shouldn't be necessary in the future.
171+
struct ArrayGuard<'a, T, const N: usize> {
172+
dst: *mut T,
173+
initialized: &'a mut usize,
174+
}
175+
176+
impl<T, const N: usize> Drop for ArrayGuard<'_, T, N> {
177+
fn drop(&mut self) {
178+
debug_assert!(*self.initialized <= N);
179+
let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, *self.initialized);
180+
unsafe {
181+
core::ptr::drop_in_place(initialized_part);
182+
}
183+
}
184+
}
185+
186+
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
187+
let guard: ArrayGuard<T, N> = ArrayGuard {
188+
dst: array.as_mut_ptr() as _,
189+
initialized: counter,
190+
};
191+
unsafe {
192+
for (idx, value_ptr) in (&mut *array.as_mut_ptr()).iter_mut().enumerate() {
193+
core::ptr::write(value_ptr, cb(idx)?);
194+
*guard.initialized += 1;
195+
}
196+
core::mem::forget(guard);
197+
Ok(array.assume_init())
198+
}
199+
}
200+
201+
#[cfg(test)]
202+
mod test {
203+
use crate::Python;
204+
#[cfg(min_const_generics)]
205+
use std::{
206+
panic,
207+
sync::{Arc, Mutex},
208+
thread::sleep,
209+
time,
210+
};
211+
212+
#[cfg(min_const_generics)]
213+
#[test]
214+
fn try_create_array() {
215+
#[allow(clippy::mutex_atomic)]
216+
let counter = Arc::new(Mutex::new(0));
217+
let counter_unwind = Arc::clone(&counter);
218+
let _ = catch_unwind_silent(move || {
219+
let mut locked = counter_unwind.lock().unwrap();
220+
let _: Result<[i32; 4], _> = super::try_create_array(&mut *locked, |idx| {
221+
if idx == 2 {
222+
panic!("peek a boo");
223+
}
224+
Ok::<_, ()>(1)
225+
});
226+
});
227+
sleep(time::Duration::from_secs(2));
228+
assert_eq!(*counter.lock().unwrap_err().into_inner(), 2);
229+
}
230+
231+
#[cfg(not(min_const_generics))]
232+
#[test]
233+
fn test_extract_bytearray_to_array() {
234+
let gil = Python::acquire_gil();
235+
let py = gil.python();
236+
let v: [u8; 3] = py
237+
.eval("bytearray(b'abc')", None, None)
238+
.unwrap()
239+
.extract()
240+
.unwrap();
241+
assert!(&v == b"abc");
242+
}
243+
244+
#[cfg(min_const_generics)]
245+
#[test]
246+
fn test_extract_bytearray_to_array() {
247+
let gil = Python::acquire_gil();
248+
let py = gil.python();
249+
let v: [u8; 33] = py
250+
.eval(
251+
"bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
252+
None,
253+
None,
254+
)
255+
.unwrap()
256+
.extract()
257+
.unwrap();
258+
assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
259+
}
260+
261+
// https://stackoverflow.com/a/59211505
262+
#[cfg(min_const_generics)]
263+
fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
264+
where
265+
F: FnOnce() -> R + panic::UnwindSafe,
266+
{
267+
let prev_hook = panic::take_hook();
268+
panic::set_hook(Box::new(|_| {}));
269+
let result = panic::catch_unwind(f);
270+
panic::set_hook(prev_hook);
271+
result
272+
}
273+
}

src/conversions/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! This module contains conversions between non-String Rust object and their string representation
22
//! in Python
33
4+
mod array;
45
mod osstr;
56
mod path;

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ pub mod pyclass_slots;
209209
mod python;
210210
pub mod type_object;
211211
pub mod types;
212+
mod utils;
212213

213214
#[cfg(feature = "serde")]
214215
pub mod serde;

src/types/list.rs

-20
Original file line numberDiff line numberDiff line change
@@ -178,26 +178,6 @@ where
178178
}
179179
}
180180

181-
macro_rules! array_impls {
182-
($($N:expr),+) => {
183-
$(
184-
impl<T> IntoPy<PyObject> for [T; $N]
185-
where
186-
T: ToPyObject
187-
{
188-
fn into_py(self, py: Python) -> PyObject {
189-
self.as_ref().to_object(py)
190-
}
191-
}
192-
)+
193-
}
194-
}
195-
196-
array_impls!(
197-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
198-
26, 27, 28, 29, 30, 31, 32
199-
);
200-
201181
impl<T> ToPyObject for Vec<T>
202182
where
203183
T: ToPyObject,

0 commit comments

Comments
 (0)