-
-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Return Result from wasm_fanova_calculate
- Loading branch information
Showing
2 changed files
with
14 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,31 @@ | ||
use fanova::{FanovaOptions, RandomForestOptions}; | ||
use js_sys::Array; | ||
use serde_wasm_bindgen::from_value; | ||
use fanova::{FanovaOptions, RandomForestOptions}; | ||
use wasm_bindgen::prelude::*; | ||
|
||
#[wasm_bindgen] | ||
pub fn wasm_fanova_calculate(features: Array, targets: Array) -> Vec<f64> { | ||
// TODO(c-bata): Fix error handling | ||
pub fn wasm_fanova_calculate(features: Array, targets: Array) -> Result<Vec<f64>, JsError> { | ||
let features_vec: Vec<Vec<f64>> = features | ||
.iter() | ||
.map(|x| from_value::<Vec<f64>>(x).unwrap()) | ||
.collect(); | ||
let targets_vec: Vec<f64> = targets.iter().map(|x| x.as_f64().unwrap()).collect(); | ||
.map(|x| from_value::<Vec<f64>>(x)) | ||
.collect::<Result<_, _>>() | ||
.map_err(|_| JsError::new("features must be of type number[][]"))?; | ||
let targets_vec: Vec<f64> = targets | ||
.iter() | ||
.map(|x| x.as_f64()) | ||
.collect::<Option<_>>() | ||
.ok_or(JsError::new("targets must be of type number[]"))?; | ||
|
||
let mut fanova = FanovaOptions::new() | ||
.random_forest(RandomForestOptions::new().seed(0)) | ||
.fit( | ||
features_vec.iter().map(|x| x.as_slice()).collect(), | ||
&targets_vec, | ||
) | ||
.unwrap(); | ||
.map_err(|e| JsError::new(&format!("failed to build fANOVA model: {}", e)))?; | ||
let importances = (0..features_vec.len()) | ||
.map(|i| fanova.quantify_importance(&[i]).mean) | ||
.collect::<Vec<_>>(); | ||
return importances; | ||
|
||
Ok(importances) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters