Skip to content
132 changes: 92 additions & 40 deletions crates/revrt/src/cost.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Cost fuction

use derive_builder::Builder;
use ndarray::{Axis, stack};
use ndarray::{ArrayD, Axis, IxDyn, stack};
use std::convert::TryFrom;
use tracing::{debug, trace};

use crate::dataset::LazySubset;
Expand All @@ -25,13 +26,18 @@ pub(crate) struct CostFunction {
/// operating on input features. Following the original `revX` structure,
/// the possible compositions are limited to combinations of the relation
/// `weight * layer_name * multiplier_layer`, where the `weight` and the
/// `multiplier_layer` are optional.
/// `multiplier_layer` are optional. Each layer can also be marked as invariant,
/// meaning that it's value does not get scaled by the distance traveled
/// through the cell. Instead, the value of the layer is added once, right
/// when the path enters the cell.
struct CostLayer {
layer_name: String,
#[builder(setter(strip_option), default)]
multiplier_scalar: Option<f32>,
#[builder(setter(strip_option, into), default)]
multiplier_layer: Option<String>,
#[builder(setter(strip_option), default)]
is_invariant: Option<bool>,
}

impl CostFunction {
Expand Down Expand Up @@ -66,52 +72,35 @@ impl CostFunction {
///
/// # Arguments
/// `features`: A lazy collection of input features.
/// `is_invariant`: If true, only invariant layers contribute.
///
/// # Returns
/// A 2D array containing the cost for the subset covered by the input
/// features.
pub(crate) fn compute(
&self,
mut features: LazySubset<f32>,
features: &mut LazySubset<f32>,
is_invariant: bool,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
debug!("Calculating cost for ({})", features.subset());
debug!(
"Calculating (is_invariant={}) cost for ({})",
is_invariant,
features.subset()
);

let cost = self
let layers: Vec<&CostLayer> = self
.cost_layers
.iter()
.map(|layer| {
let layer_name = &layer.layer_name;
trace!("Layer name: {}", layer_name);

let mut cost = features
.get(layer_name)
.expect("Layer not found in features");

if let Some(multiplier_scalar) = layer.multiplier_scalar {
trace!(
"Layer {} has multiplier scalar {}",
layer_name, multiplier_scalar
);
// Apply the multiplier scalar to the value
cost *= multiplier_scalar;
// trace!( "Cost for chunk ({}, {}) in layer {}: {}", ci, cj, layer_name, cost);
}

if let Some(multiplier_layer) = &layer.multiplier_layer {
trace!(
"Layer {} has multiplier layer {}",
layer_name, multiplier_layer
);
let multiplier_value = features
.get(multiplier_layer)
.expect("Multiplier layer not found in features");

// Apply the multiplier layer to the value
cost = cost * multiplier_value;
// trace!( "Cost for chunk ({}, {}) in layer {}: {}", ci, cj, layer_name, cost);
}
cost
})
.filter(|layer| layer.is_invariant.unwrap_or(false) == is_invariant)
.collect();

if layers.is_empty() {
return empty_cost_array(features);
}

let cost = layers
.into_iter()
.map(|layer| build_single_layer(layer, features))
.collect::<Vec<_>>();

let views: Vec<_> = cost.iter().map(|a| a.view()).collect();
Expand All @@ -125,6 +114,56 @@ impl CostFunction {
}
}

fn empty_cost_array(
features: &LazySubset<f32>,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
let shape: Vec<usize> = features
.subset()
.shape()
.iter()
.map(|&dim| usize::try_from(dim).expect("subset dimension exceeds usize range"))
.collect();

ArrayD::<f32>::zeros(IxDyn(&shape))
}

fn build_single_layer(
layer: &CostLayer,
features: &mut LazySubset<f32>,
) -> ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
let layer_name = &layer.layer_name;
trace!("Layer name: {}", layer_name);

let mut cost = features
.get(layer_name)
.expect("Layer not found in features");

if let Some(multiplier_scalar) = layer.multiplier_scalar {
trace!(
"Layer {} has multiplier scalar {}",
layer_name, multiplier_scalar
);
// Apply the multiplier scalar to the value
cost *= multiplier_scalar;
// trace!( "Cost for chunk ({}, {}) in layer {}: {}", ci, cj, layer_name, cost);
}

if let Some(multiplier_layer) = &layer.multiplier_layer {
trace!(
"Layer {} has multiplier layer {}",
layer_name, multiplier_layer
);
let multiplier_value = features
.get(multiplier_layer)
.expect("Multiplier layer not found in features");

// Apply the multiplier layer to the value
cost = cost * multiplier_value;
// trace!( "Cost for chunk ({}, {}) in layer {}: {}", ci, cj, layer_name, cost);
}
cost
}

#[cfg(test)]
pub(crate) mod sample {
use super::*;
Expand All @@ -139,7 +178,9 @@ pub(crate) mod sample {
{"layer_name": "A",
"multiplier_layer": "B"},
{"layer_name": "C", "multiplier_scalar": 2,
"multiplier_layer": "A"}
"multiplier_layer": "A"},
{"layer_name": "C", "multiplier_scalar": 100,
"is_invariant": true}
]
}
"#
Expand All @@ -162,12 +203,14 @@ mod test_builder {
.layer_name("A".to_string())
.multiplier_scalar(2.0)
.multiplier_layer("B")
.is_invariant(false)
.build()
.unwrap();

assert_eq!(layer.layer_name, "A");
assert_eq!(layer.multiplier_scalar, Some(2.0));
assert_eq!(layer.multiplier_layer, Some("B".to_string()));
assert_eq!(layer.is_invariant, Some(false));
}

#[test]
Expand All @@ -180,6 +223,7 @@ mod test_builder {
assert_eq!(layer.layer_name, "A");
assert_eq!(layer.multiplier_scalar, None);
assert_eq!(layer.multiplier_layer, None);
assert_eq!(layer.is_invariant, None);
}
}

Expand All @@ -192,14 +236,22 @@ mod test {
let json = sample::as_text_v1();
let cost = CostFunction::from_json(&json).unwrap();

assert_eq!(cost.cost_layers.len(), 4);
assert_eq!(cost.cost_layers.len(), 5);
assert_eq!(cost.cost_layers[0].layer_name, "A");
assert_eq!(cost.cost_layers[0].is_invariant, None);
assert_eq!(cost.cost_layers[1].layer_name, "B");
assert_eq!(cost.cost_layers[1].multiplier_scalar, Some(100.0));
assert_eq!(cost.cost_layers[1].is_invariant, None);
assert_eq!(cost.cost_layers[2].layer_name, "A");
assert_eq!(cost.cost_layers[2].multiplier_layer, Some("B".to_string()));
assert_eq!(cost.cost_layers[2].is_invariant, None);
assert_eq!(cost.cost_layers[3].layer_name, "C");
assert_eq!(cost.cost_layers[3].multiplier_layer, Some("A".to_string()));
assert_eq!(cost.cost_layers[3].multiplier_scalar, Some(2.0));
assert_eq!(cost.cost_layers[3].is_invariant, None);
assert_eq!(cost.cost_layers[4].layer_name, "C");
assert_eq!(cost.cost_layers[4].multiplier_layer, None);
assert_eq!(cost.cost_layers[4].multiplier_scalar, Some(100.0));
assert_eq!(cost.cost_layers[4].is_invariant, Some(true));
}
}
Loading
Loading