diff --git a/crates/revrt/src/cost.rs b/crates/revrt/src/cost.rs index 69770bbf..58c2cb94 100644 --- a/crates/revrt/src/cost.rs +++ b/crates/revrt/src/cost.rs @@ -72,6 +72,7 @@ impl CostFunction { /// features. pub(crate) fn compute( &self, + // For now let's restrict to f32, but we are ready to move to generic types. mut features: LazySubset, ) -> ndarray::ArrayBase, ndarray::Dim> { debug!("Calculating cost for ({})", features.subset()); @@ -117,7 +118,6 @@ impl CostFunction { let views: Vec<_> = cost.iter().map(|a| a.view()).collect(); let stack = stack(Axis(0), &views).unwrap(); //let cost = stack![Axis(3), &cost]; - trace!("Stack shape: {:?}", stack.shape()); let cost = stack.sum_axis(Axis(0)); trace!("Stack shape: {:?}", stack.shape()); diff --git a/crates/revrt/src/dataset.rs b/crates/revrt/src/dataset.rs index 114f078d..638ed371 100644 --- a/crates/revrt/src/dataset.rs +++ b/crates/revrt/src/dataset.rs @@ -7,19 +7,24 @@ use std::sync::RwLock; use tracing::{debug, trace, warn}; use zarrs::array::ArrayChunkCacheExt; +use zarrs::array_subset::ArraySubset; use zarrs::storage::{ ListableStorageTraits, ReadableListableStorage, ReadableWritableListableStorage, }; -use crate::ArrayIndex; use crate::cost::CostFunction; use crate::error::Result; +use crate::ArrayIndex; +pub(crate) use lazy_chunk::LazyChunk; pub(crate) use lazy_subset::LazySubset; +const CHUNK_SHAPE: [u64; 2] = [1_000, 1_000]; + /// Manages the features datasets and calculated total cost pub(super) struct Dataset { /// A Zarr storages with the features source: ReadableListableStorage, + dims: Vec, // Silly way to keep the tmp path alive #[allow(dead_code)] cost_path: tempfile::TempDir, @@ -45,18 +50,33 @@ impl Dataset { path: P, cost_function: CostFunction, cache_size: u64, + ) -> Result { + tracing::warn!("Deprecated: use `Dataset::new` instead"); + Self::new(path, cost_function, cache_size) + } + + pub(super) fn new>( + path: P, + cost_function: CostFunction, + cache_size: u64, ) -> Result { debug!("Opening dataset: {:?}", path.as_ref()); let filesystem = zarrs::filesystem::FilesystemStore::new(path).expect("could not open filesystem store"); let source = std::sync::Arc::new(filesystem); + // ==== Temporary solution to specify dimensions ==== + // Assume all variables have the same shape and chunk shape. + // Find the name of the first variable and use it. + let varname = source.list().unwrap()[0].to_string(); + let varname = varname.split("/").collect::>()[0]; + let tmp = zarrs::array::Array::open(source.clone(), &format!("/{varname}")).unwrap(); + let dims = tmp.shape().to_vec(); + debug_assert!(!dims.contains(&0)); + // ==== Create the swap dataset ==== let tmp_path = tempfile::TempDir::new().unwrap(); - debug!( - "Initializing a temporary swap dataset at {:?}", - tmp_path.path() - ); + debug!("Initializing a swap dataset at {:?}", tmp_path.path()); let swap: ReadableWritableListableStorage = std::sync::Arc::new( zarrs::filesystem::FilesystemStore::new(tmp_path.path()) .expect("could not open filesystem store"), @@ -80,24 +100,48 @@ impl Dataset { // ---- trace!("Creating an empty cost array"); - let array = zarrs::array::ArrayBuilder::new( + let cost = zarrs::array::ArrayBuilder::new( cost_shape.into(), zarrs::array::DataType::Float32, chunk_shape, zarrs::array::FillValue::from(zarrs::array::ZARR_NAN_F32), ) .build(swap.clone(), "/cost") - .unwrap(); - trace!("Cost shape: {:?}", array.shape().to_vec()); - trace!("Cost chunk shape: {:?}", array.chunk_grid()); - array.store_metadata().unwrap(); + .expect("Failed to create cost array"); + trace!("Cost shape: {:?}", cost.shape().to_vec()); + trace!("Cost chunk shape: {:?}", cost.chunk_grid()); + cost.store_metadata().unwrap(); + + debug!( + "Cost variable created: {:?}, shape: {:?} [{:?}]", + cost.path(), + cost.shape(), + cost.chunk_grid() + ); + // ==== Create the delta cost array ==== + trace!("Creating an empty delta cost array"); + + let delta = zarrs::array::ArrayBuilder::new( + [cost_shape, &[8]].concat(), + zarrs::array::DataType::Float32, + vec![CHUNK_SHAPE[0], CHUNK_SHAPE[1], 8].try_into().unwrap(), + // CHUNK_SHAPE .iter() .chain(&[8]) .cloned() .collect::>() .try_into() .unwrap(), + zarrs::array::FillValue::from(zarrs::array::ZARR_NAN_F32), + ) + .build(swap.clone(), "/delta") + .expect("Failed to create delta array"); + trace!("Delta shape: {:?}", delta.shape().to_vec()); + trace!("Delta chunk shape: {:?}", delta.chunk_grid()); + delta.store_metadata().unwrap(); + + // ==== trace!("Cost dataset contents: {:?}", swap.list().unwrap()); let cost_chunk_idx = ndarray::Array2::from_elem( ( - array.chunk_grid_shape().unwrap()[0] as usize, - array.chunk_grid_shape().unwrap()[1] as usize, + cost.chunk_grid_shape().unwrap()[0] as usize, + cost.chunk_grid_shape().unwrap()[1] as usize, ), false, ) @@ -112,6 +156,7 @@ impl Dataset { trace!("Dataset opened successfully"); Ok(Self { source, + dims, cost_path: tmp_path, swap, cost_chunk_idx, @@ -128,25 +173,14 @@ impl Dataset { // Get the subset according to cost's chunk let subset = variable.chunk_subset(&[ci, cj]).unwrap(); let data = LazySubset::::new(self.source.clone(), subset); - let output = self.cost_function.compute(data); trace!("Cost function: {:?}", self.cost_function); + let output = self.cost_function.compute(data); - /* - trace!("Getting '/A' variable"); - let array = zarrs::array::Array::open(self.source.clone(), "/A").unwrap(); - let value = array.retrieve_chunk_ndarray::(&[i, j]).unwrap(); - trace!("Value: {:?}", value); - trace!("Calculating cost for chunk ({}, {})", i, j); - let output = value * 10.0; - */ - - let cost = zarrs::array::Array::open(self.swap.clone(), "/cost").unwrap(); cost.store_metadata().unwrap(); let chunk_indices: Vec = vec![ci, cj]; trace!("Storing chunk at {:?}", chunk_indices); - let chunk_subset = - &zarrs::array_subset::ArraySubset::new_with_ranges(&[ci..(ci + 1), cj..(cj + 1)]); + let chunk_subset = &ArraySubset::new_with_ranges(&[ci..(ci + 1), cj..(cj + 1)]); trace!("Target chunk subset: {:?}", chunk_subset); cost.store_chunks_ndarray(chunk_subset, output).unwrap(); } @@ -310,7 +344,7 @@ mod tests { let cost_function = CostFunction::from_json(r#"{"cost_layers": [{"layer_name": "A"}]}"#).unwrap(); let dataset = - Dataset::open(path, cost_function, 250_000_000).expect("Error opening dataset"); + Dataset::new(path, cost_function, 250_000_000).expect("Error opening dataset"); let test_points = [ArrayIndex { i: 3, j: 1 }, ArrayIndex { i: 2, j: 2 }]; let array = zarrs::array::Array::open(dataset.source.clone(), "/A").unwrap(); @@ -318,8 +352,7 @@ mod tests { let results = dataset.get_3x3(&point); for (ArrayIndex { i, j }, val) in results { - let subset = - zarrs::array_subset::ArraySubset::new_with_ranges(&[i..(i + 1), j..(j + 1)]); + let subset = ArraySubset::new_with_ranges(&[i..(i + 1), j..(j + 1)]); let subset_elements: Vec = array .retrieve_array_subset_elements(&subset) .expect("Error reading zarr data"); @@ -334,7 +367,7 @@ mod tests { let path = samples::multi_variable_zarr(); let cost_function = crate::cost::sample::cost_function(); let dataset = - Dataset::open(path, cost_function, 250_000_000).expect("Error opening dataset"); + Dataset::new(path, cost_function, 250_000_000).expect("Error opening dataset"); let test_points = [ArrayIndex { i: 3, j: 1 }, ArrayIndex { i: 2, j: 2 }]; let array_a = zarrs::array::Array::open(dataset.source.clone(), "/A").unwrap(); @@ -344,8 +377,7 @@ mod tests { let results = dataset.get_3x3(&point); for (ArrayIndex { i, j }, val) in results { - let subset = - zarrs::array_subset::ArraySubset::new_with_ranges(&[i..(i + 1), j..(j + 1)]); + let subset = ArraySubset::new_with_ranges(&[i..(i + 1), j..(j + 1)]); let subset_elements_a: Vec = array_a .retrieve_array_subset_elements(&subset) .expect("Error reading zarr data"); @@ -378,7 +410,7 @@ mod tests { let cost_function = CostFunction::from_json(r#"{"cost_layers": [{"layer_name": "cost"}]}"#).unwrap(); let dataset = - Dataset::open(path, cost_function, 250_000_000).expect("Error opening dataset"); + Dataset::new(path, cost_function, 250_000_000).expect("Error opening dataset"); let results = dataset.get_3x3(&ArrayIndex { i: 0, j: 0 }); @@ -394,7 +426,7 @@ mod tests { let cost_function = CostFunction::from_json(r#"{"cost_layers": [{"layer_name": "cost"}]}"#).unwrap(); let dataset = - Dataset::open(path, cost_function, 250_000_000).expect("Error opening dataset"); + Dataset::new(path, cost_function, 250_000_000).expect("Error opening dataset"); let results = dataset.get_3x3(&ArrayIndex { i: si, j: sj }); @@ -424,7 +456,7 @@ mod tests { let cost_function = CostFunction::from_json(r#"{"cost_layers": [{"layer_name": "cost"}]}"#).unwrap(); let dataset = - Dataset::open(path, cost_function, 250_000_000).expect("Error opening dataset"); + Dataset::new(path, cost_function, 250_000_000).expect("Error opening dataset"); let results = dataset.get_3x3(&ArrayIndex { i: si, j: sj }); @@ -457,7 +489,7 @@ mod tests { let cost_function = CostFunction::from_json(r#"{"cost_layers": [{"layer_name": "cost"}]}"#).unwrap(); let dataset = - Dataset::open(path, cost_function, 250_000_000).expect("Error opening dataset"); + Dataset::new(path, cost_function, 250_000_000).expect("Error opening dataset"); let results = dataset.get_3x3(&ArrayIndex { i: si, j: sj }); @@ -470,85 +502,3 @@ mod tests { ); } } - -/// Lazy chunk of a Zarr dataset -pub(crate) struct LazyChunk { - /// Source Zarr storage - source: ReadableListableStorage, - /// Chunk index 1st dimension - ci: u64, - /// Chunk index 2nd dimension - cj: u64, - /// Data - // We know it is a 2D array of f32. We might want to simplify and strict this definition. - // data: std::collections::HashMap>, - data: std::collections::HashMap< - String, - ndarray::ArrayBase, ndarray::Dim>, - >, -} - -#[allow(dead_code)] -impl LazyChunk { - pub(super) fn ci(&self) -> u64 { - self.ci - } - - pub(super) fn cj(&self) -> u64 { - self.cj - } - - //fn get(&self, variable: &str) -> Result<&ndarray::Array2> { - pub(crate) fn get( - &mut self, - variable: &str, - ) -> Result, ndarray::Dim>> { - trace!("Getting chunk data for variable: {}", variable); - - Ok(match self.data.get(variable) { - Some(v) => { - trace!("Chunk data for variable {} already loaded", variable); - v.clone() - } - None => { - trace!("Loading chunk data for variable: {}", variable); - let array = zarrs::array::Array::open(self.source.clone(), &format!("/{variable}")) - .unwrap(); - let chunk_indices = &[self.ci, self.cj]; - let chunk_subset = zarrs::array_subset::ArraySubset::new_with_ranges(&[ - chunk_indices[0]..(chunk_indices[0] + 1), - chunk_indices[1]..(chunk_indices[1] + 1), - ]); - trace!("Storing chunk data for variable: {}", variable); - let values = array.retrieve_chunks_ndarray::(&chunk_subset).unwrap(); - // array.retrieve_chunk_ndarray::(&[ci, cj]).unwrap(); - self.data.insert(variable.to_string(), values.clone()); - values - } - }) - } -} - -#[cfg(test)] -mod chunk_tests { - use super::*; - - #[test] - fn dev() { - let path = samples::multi_variable_zarr(); - let store: zarrs::storage::ReadableListableStorage = - std::sync::Arc::new(zarrs::filesystem::FilesystemStore::new(&path).unwrap()); - - let mut chunk = LazyChunk { - source: store, - ci: 0, - cj: 0, - data: std::collections::HashMap::new(), - }; - - assert_eq!(chunk.ci, 0); - assert_eq!(chunk.cj, 0); - - let _tmp = chunk.get("A").unwrap(); - } -}