Skip to content

Commit c7cdf52

Browse files
committed
Impl a lifetime-relaxed broadcast for ArrayView
ArrayView::broadcast has a lifetime that depends on &self instead of its internal buffer. This prevents writing some types of functions in an allocation-free way. For instance, take the numpy `meshgrid` function: It could be implemented like so: ```rust fn meshgrid_2d<'a, 'b>(coords_x: ArrayView1<'a, X>, coords_y: ArrayView1<'b, X>) -> (ArrayView2<'a, X>, ArrayView2<'b, X>) { let x_len = coords_x.shape()[0]; let y_len = coords_y.shape()[0]; let coords_x_s = coords_x.into_shape((1, y_len)).unwrap(); let coords_x_b = coords_x_s.broadcast((x_len, y_len)).unwrap(); let coords_y_s = coords_y.into_shape((x_len, 1)).unwrap(); let coords_y_b = coords_y_s.broadcast((x_len, y_len)).unwrap(); (coords_x_b, coords_y_b) } ``` Unfortunately, this doesn't work, because `coords_x_b` is bound to the lifetime of `coord_x_s`, instead of being bound to 'a. This commit introduces a new function, broadcast_ref, that does just that.
1 parent e080d62 commit c7cdf52

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

src/impl_views/methods.rs

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright 2014-2016 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use crate::imp_prelude::*;
10+
use crate::dimension::IntoDimension;
11+
use crate::dimension::size_of_shape_checked;
12+
13+
impl<'a, A, D> ArrayView<'a, A, D>
14+
where
15+
D: Dimension,
16+
{
17+
/// Broadcasts an arrayview.
18+
pub fn broadcast_ref<E>(&self, dim: E) -> Option<ArrayView<'a, A, E::Dim>>
19+
where
20+
E: IntoDimension,
21+
{
22+
/// Return new stride when trying to grow `from` into shape `to`
23+
///
24+
/// Broadcasting works by returning a "fake stride" where elements
25+
/// to repeat are in axes with 0 stride, so that several indexes point
26+
/// to the same element.
27+
///
28+
/// **Note:** Cannot be used for mutable iterators, since repeating
29+
/// elements would create aliasing pointers.
30+
fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
31+
// Make sure the product of non-zero axis lengths does not exceed
32+
// `isize::MAX`. This is the only safety check we need to perform
33+
// because all the other constraints of `ArrayBase` are guaranteed
34+
// to be met since we're starting from a valid `ArrayBase`.
35+
let _ = size_of_shape_checked(to).ok()?;
36+
37+
let mut new_stride = to.clone();
38+
// begin at the back (the least significant dimension)
39+
// size of the axis has to either agree or `from` has to be 1
40+
if to.ndim() < from.ndim() {
41+
return None;
42+
}
43+
44+
{
45+
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
46+
for ((er, es), dr) in from
47+
.slice()
48+
.iter()
49+
.rev()
50+
.zip(stride.slice().iter().rev())
51+
.zip(new_stride_iter.by_ref())
52+
{
53+
/* update strides */
54+
if *dr == *er {
55+
/* keep stride */
56+
*dr = *es;
57+
} else if *er == 1 {
58+
/* dead dimension, zero stride */
59+
*dr = 0
60+
} else {
61+
return None;
62+
}
63+
}
64+
65+
/* set remaining strides to zero */
66+
for dr in new_stride_iter {
67+
*dr = 0;
68+
}
69+
}
70+
Some(new_stride)
71+
}
72+
let dim = dim.into_dimension();
73+
74+
// Note: zero strides are safe precisely because we return an read-only view
75+
let broadcast_strides = match upcast(&dim, &self.dim, &self.strides) {
76+
Some(st) => st,
77+
None => return None,
78+
};
79+
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
80+
}
81+
}

src/impl_views/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod constructors;
22
mod conversions;
33
mod indexing;
4+
mod methods;
45
mod splitting;
56

67
pub use constructors::*;

0 commit comments

Comments
 (0)