Skip to content

Matrix allocation optimization #321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: development
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 136 additions & 20 deletions src/linalg/basic/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::any::{Any, TypeId};
use std::fmt;
use std::fmt::{Debug, Display};
use std::ops::Range;
Expand All @@ -18,19 +19,119 @@
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};

use crate::error::Failed;

thread_local! {
static MATRIX_POOL: RefCell<MatrixMemoryPool> = RefCell::new(MatrixMemoryPool::new());
}

struct MatrixMemoryPool {
vectors: HashMap<TypeId, BTreeMap<usize, Vec<Box<dyn Any>>>>
}

impl MatrixMemoryPool {
fn new() -> Self {
MatrixMemoryPool {
vectors: HashMap::new(),
}
}

fn release_vec<T: 'static>(&mut self, mut vec: Vec<T>) {
let capacity = vec.capacity();
vec.clear();
let type_id = TypeId::of::<T>();
self.vectors
.entry(type_id)
.or_insert_with(BTreeMap::new)
.entry(capacity)
.or_insert_with(Vec::new)
.push(Box::new(vec));
}

fn acquire_vec<T: 'static>(&mut self, min_size: usize) -> Option<Vec<T>> {
let type_id = TypeId::of::<T>();
if let Some(btree) = self.vectors.get_mut(&type_id) {
let best_fit = btree.range_mut(min_size..).next();
let mut key_to_remove = None;
if let Some((key, vectors)) = best_fit {
if !vectors.is_empty() {
let vec: Vec<T> = *vectors.pop().unwrap().downcast().unwrap();
return Some(vec);
}
if vectors.is_empty() {
key_to_remove = Some(*key);
}
}
if let Some(key) = key_to_remove {
btree.remove(&key);
}
}
None
}
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct PooledVec<T: 'static + Clone> {
inner: Option<Vec<T>>,
is_pooled: bool,
}

impl<T: 'static + Clone> PooledVec<T> {
fn new(vec: Vec<T>, is_pooled: bool) -> Self {
PooledVec {
inner: Some(vec),
is_pooled,
}
}
}

impl<T: 'static + Clone> Drop for PooledVec<T> {
fn drop(&mut self) {
if self.is_pooled {
if let Some(vec) = self.inner.take() {
MATRIX_POOL.with(|pool| {
pool.borrow_mut().release_vec(vec);
});
}
}
}
}

impl<T: 'static + Clone> Clone for PooledVec<T> {
fn clone(&self) -> Self {
PooledVec::new(self.inner.as_ref().unwrap().clone(), false)
}
}

impl<T: 'static + Clone> std::ops::Deref for PooledVec<T> {
type Target = Vec<T>;

fn deref(&self) -> &Vec<T> {
self.inner.as_ref().unwrap()
}
}

impl<T: 'static + Clone> std::ops::DerefMut for PooledVec<T> {
fn deref_mut(&mut self) -> &mut Vec<T> {
self.inner.as_mut().unwrap()
}
}

/// Dense matrix
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DenseMatrix<T> {
pub struct DenseMatrix<T: 'static + Clone> {
ncols: usize,
nrows: usize,
values: Vec<T>,
values: PooledVec<T>,
column_major: bool,
}


/// View on dense matrix
#[derive(Debug, Clone)]
pub struct DenseMatrixView<'a, T: Debug + Display + Copy + Sized> {
Expand Down Expand Up @@ -196,6 +297,7 @@
"The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
)))
} else {
let values = PooledVec::new(values, true);
Ok(DenseMatrix {
ncols,
nrows,
Expand Down Expand Up @@ -280,6 +382,20 @@
};
(start, end, stride)
}

pub fn zeros_(nrows: usize, ncols: usize) -> Vec<f64> {

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / check_features (ubuntu)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / check_features (ubuntu, --features datasets)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, i686-unknown-linux-gnu)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, i686-unknown-linux-gnu)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, x86_64-unknown-linux-gnu)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, x86_64-unknown-linux-gnu)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (macos, aarch64-apple-darwin)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (macos, aarch64-apple-darwin)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (macos, aarch64-apple-darwin)

missing documentation for an associated function

Check warning on line 386 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / check_features (ubuntu, --features serde)

missing documentation for an associated function
if let Some(vector) = MATRIX_POOL.with(|pool| {
pool.borrow_mut().acquire_vec::<f64>(nrows * ncols)
}) {
let mut vector = vector;
for _ in (0..nrows * ncols) {

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / check_features (ubuntu)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / check_features (ubuntu, --features datasets)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, i686-unknown-linux-gnu)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, i686-unknown-linux-gnu)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, x86_64-unknown-linux-gnu)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, x86_64-unknown-linux-gnu)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, x86_64-unknown-linux-gnu)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (macos, aarch64-apple-darwin)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (macos, aarch64-apple-darwin)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (macos, aarch64-apple-darwin)

unnecessary parentheses around `for` iterator expression

Check warning on line 391 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / check_features (ubuntu, --features serde)

unnecessary parentheses around `for` iterator expression
vector.push(0.0);
}
return vector;
} else {
vec![0.0; nrows * ncols]
}
}
}

impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
Expand Down Expand Up @@ -704,7 +820,7 @@

assert_eq!(
vec![4, 5, 6],
DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
*DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
);
let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).copied().collect();
assert_eq!(vec![4, 5, 6], second_row);
Expand All @@ -716,35 +832,35 @@
fn test_iter_mut() {
let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();

assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], *x.values);
// add +2 to some elements
x.slice_mut(1..2, 0..3)
.iterator_mut(0)
.for_each(|v| *v += 2);
assert_eq!(vec![1, 6, 7, 2, 7, 8, 3, 8, 9], x.values);
assert_eq!(vec![1, 6, 7, 2, 7, 8, 3, 8, 9], *x.values);
// add +1 to some others
x.slice_mut(0..3, 1..2)
.iterator_mut(0)
.for_each(|v| *v += 1);
assert_eq!(vec![1, 6, 7, 3, 8, 9, 3, 8, 9], x.values);
assert_eq!(vec![1, 6, 7, 3, 8, 9, 3, 8, 9], *x.values);

// rewrite matrix as indices of values per axis 1 (row-wise)
x.iterator_mut(1).enumerate().for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8], x.values);
assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8], *x.values);
// rewrite matrix as indices of values per axis 0 (column-wise)
x.iterator_mut(0).enumerate().for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 3, 6, 1, 4, 7, 2, 5, 8], x.values);
assert_eq!(vec![0, 3, 6, 1, 4, 7, 2, 5, 8], *x.values);
// rewrite some by slice
x.slice_mut(0..3, 0..2)
.iterator_mut(0)
.enumerate()
.for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 2, 4, 1, 3, 5, 2, 5, 8], x.values);
assert_eq!(vec![0, 2, 4, 1, 3, 5, 2, 5, 8], *x.values);
x.slice_mut(0..2, 0..3)
.iterator_mut(1)
.enumerate()
.for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 1, 4, 2, 3, 5, 4, 5, 8], x.values);
assert_eq!(vec![0, 1, 4, 2, 3, 5, 4, 5, 8], *x.values);
}

#[test]
Expand All @@ -753,24 +869,24 @@
DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
.unwrap();

assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], *x.values);
x.iterator_mut(0).for_each(|v| *v = "str");
assert_eq!(
vec!["str", "str", "str", "str", "str", "str", "str", "str", "str"],
x.values
*x.values
);
}

#[test]
fn test_transpose() {
let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();

assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], *x.values);
assert!(x.column_major);

// transpose
let x = x.transpose();
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], *x.values);
assert!(!x.column_major); // should change column_major
}

Expand All @@ -778,7 +894,7 @@
fn test_from_iterator() {
let data = [1, 2, 3, 4, 5, 6];

let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);

Check failure on line 897 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (ubuntu, wasm32-unknown-unknown)

`data` does not live long enough

Check failure on line 897 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / tests (macos, aarch64-apple-darwin)

`data` does not live long enough

Check failure on line 897 in src/linalg/basic/matrix.rs

View workflow job for this annotation

GitHub Actions / coverage

`data` does not live long enough

// make a vector into a 2x3 matrix.
assert_eq!(
Expand All @@ -795,21 +911,21 @@

println!("{a}");
// take column 0 and 2
assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values);
assert_eq!(vec![1, 3, 4, 6], *a.take(&[0, 2], 1).values);
println!("{b}");
// take rows 0 and 2
assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values);
assert_eq!(vec![1, 2, 5, 6], *b.take(&[0, 2], 0).values);
}

#[test]
fn test_mut() {
let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();

let a = a.abs();
assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);
assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], *a.values);

let a = a.neg();
assert_eq!(vec![-1.3, -4.0, -2.1, -5.3, -3.4, -6.1], a.values);
assert_eq!(vec![-1.3, -4.0, -2.1, -5.3, -3.4, -6.1], *a.values);
}

#[test]
Expand All @@ -818,11 +934,11 @@
.unwrap();

let a = a.reshape(2, 6, 0);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], *a.values);
assert!(a.ncols == 6 && a.nrows == 2 && !a.column_major);

let a = a.reshape(3, 4, 1);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], *a.values);
assert!(a.ncols == 4 && a.nrows == 3 && a.column_major);
}

Expand Down
Loading