diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index f28d4b8..c22ec09 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -16,4 +16,4 @@ safetensors = { workspace = true } [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = "0.2.92" getrandom = { version = "0.2", features = ["js"] } -js-sys = "0.3.69" +js-sys = "0.3.69" \ No newline at end of file diff --git a/crates/core/src/cpu/backend.rs b/crates/core/src/cpu/backend.rs index 91051fe..b74dc4b 100644 --- a/crates/core/src/cpu/backend.rs +++ b/crates/core/src/cpu/backend.rs @@ -1,15 +1,14 @@ use std::collections::HashMap; -use std::time::Instant; use ndarray::{ArrayD, ArrayViewD, IxDyn}; use safetensors::{serialize, SafeTensors}; use crate::{ to_arr, ActivationCPULayer, BackendConfig, BatchNorm1DCPULayer, BatchNorm2DCPULayer, - BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUScheduler, Conv2DCPULayer, ConvTensors, - ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors, Dropout1DCPULayer, - Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger, Pool2DCPULayer, SoftmaxCPULayer, - Tensor, Tensors, + BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUPostProcessor, CPUScheduler, + Conv2DCPULayer, ConvTensors, ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors, + Dropout1DCPULayer, Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger, + Pool2DCPULayer, PostProcessor, SoftmaxCPULayer, Tensor, Tensors, Timer, }; pub struct Backend { @@ -23,10 +22,16 @@ pub struct Backend { pub optimizer: CPUOptimizer, pub scheduler: CPUScheduler, pub logger: Logger, + pub timer: Timer, } impl Backend { - pub fn new(config: BackendConfig, logger: Logger, mut tensors: Option>) -> Self { + pub fn new( + config: BackendConfig, + logger: Logger, + timer: Timer, + mut tensors: Option>, + ) -> Self { let mut layers = Vec::new(); let mut size = config.size.clone(); for layer in config.layers.iter() { @@ -99,6 +104,7 @@ impl Backend { optimizer, scheduler, size, + timer, } } @@ -147,7 +153,7 @@ impl Backend { let mut cost = 0f32; let mut time: u128; let mut total_time = 0u128; - let start = Instant::now(); + let start = (self.timer.now)(); let total_iter = epochs * datasets.len(); while epoch < epochs { let mut total = 0.0; @@ -160,11 +166,11 @@ impl Backend { let minibatch = outputs.dim()[0]; if !self.silent && ((i + 1) * minibatch) % batches == 0 { cost = total / (batches) as f32; - time = start.elapsed().as_millis() - total_time; + time = ((self.timer.now)() - start) - total_time; total_time += time; let current_iter = epoch * datasets.len() + i; let msg = format!( - "Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s", + "Epoch={}, Dataset={}, Cost={}, Time={:.3}s, ETA={:.3}s", epoch, i * minibatch, cost, @@ -188,25 +194,20 @@ impl Backend { } else { disappointments += 1; if !self.silent { - println!( + (self.logger.log)(format!( "Patience counter: {} disappointing epochs out of {}.", disappointments, self.patience - ); + )); } } if disappointments >= self.patience { if !self.silent { - println!( + (self.logger.log)(format!( "No improvement for {} epochs. Stopping early at cost={}", disappointments, best_cost - ); + )); } - let net = Self::load( - &best_net, - Logger { - log: |x| println!("{}", x), - }, - ); + let net = Self::load(&best_net, self.logger.clone(), self.timer.clone()); self.layers = net.layers; break; } @@ -215,11 +216,18 @@ impl Backend { } } - pub fn predict(&mut self, data: ArrayD, layers: Option>) -> ArrayD { + pub fn predict( + &mut self, + data: ArrayD, + postprocess: PostProcessor, + layers: Option>, + ) -> ArrayD { + let processor = CPUPostProcessor::from(&postprocess); for layer in &mut self.layers { layer.reset(1); } - self.forward_propagate(data, false, layers) + let res = self.forward_propagate(data, false, layers); + processor.process(res) } pub fn save(&self) -> Vec { @@ -272,7 +280,7 @@ impl Backend { serialize(tensors, &Some(metadata)).unwrap() } - pub fn load(buffer: &[u8], logger: Logger) -> Self { + pub fn load(buffer: &[u8], logger: Logger, timer: Timer) -> Self { let tensors = SafeTensors::deserialize(buffer).unwrap(); let (_, metadata) = SafeTensors::read_metadata(buffer).unwrap(); let data = metadata.metadata().as_ref().unwrap(); @@ -304,6 +312,6 @@ impl Backend { }; } - Backend::new(config, logger, Some(layers)) + Backend::new(config, logger, timer, Some(layers)) } } diff --git a/crates/core/src/cpu/mod.rs b/crates/core/src/cpu/mod.rs index 462bcb8..c815ede 100644 --- a/crates/core/src/cpu/mod.rs +++ b/crates/core/src/cpu/mod.rs @@ -6,6 +6,7 @@ mod layers; mod optimizers; mod schedulers; mod regularizer; +mod postprocessing; pub use activation::*; pub use backend::*; @@ -14,4 +15,5 @@ pub use init::*; pub use layers::*; pub use optimizers::*; pub use schedulers::*; -pub use regularizer::*; \ No newline at end of file +pub use regularizer::*; +pub use postprocessing::*; \ No newline at end of file diff --git a/crates/core/src/cpu/postprocessing/mod.rs b/crates/core/src/cpu/postprocessing/mod.rs new file mode 100644 index 0000000..1504212 --- /dev/null +++ b/crates/core/src/cpu/postprocessing/mod.rs @@ -0,0 +1,28 @@ +use ndarray::ArrayD; +use crate::PostProcessor; + +mod step; +use step::CPUStepFunction; + +pub enum CPUPostProcessor { + None, + Sign, + Step(CPUStepFunction), +} + +impl CPUPostProcessor { + pub fn from(processor: &PostProcessor) -> Self { + match processor { + PostProcessor::None => CPUPostProcessor::None, + PostProcessor::Sign => CPUPostProcessor::Sign, + PostProcessor::Step(config) => CPUPostProcessor::Step(CPUStepFunction::new(config)), + } + } + pub fn process(&self, x: ArrayD) -> ArrayD { + match self { + CPUPostProcessor::None => x, + CPUPostProcessor::Sign => x.map(|y| y.signum()), + CPUPostProcessor::Step(processor) => x.map(|y| processor.step(*y)), + } + } +} \ No newline at end of file diff --git a/crates/core/src/cpu/postprocessing/step.rs b/crates/core/src/cpu/postprocessing/step.rs new file mode 100644 index 0000000..a624e30 --- /dev/null +++ b/crates/core/src/cpu/postprocessing/step.rs @@ -0,0 +1,22 @@ +use crate::StepFunctionConfig; + +pub struct CPUStepFunction { + thresholds: Vec, + values: Vec +} +impl CPUStepFunction { + pub fn new(config: &StepFunctionConfig) -> Self { + return Self { + thresholds: config.thresholds.clone(), + values: config.values.clone() + } + } + pub fn step(&self, x: f32) -> f32 { + for (i, &threshold) in self.thresholds.iter().enumerate() { + if x < threshold { + return self.values[i]; + } + } + return self.values.last().unwrap().clone() + } +} \ No newline at end of file diff --git a/crates/core/src/ffi.rs b/crates/core/src/ffi.rs index 60dbf81..26b02ee 100644 --- a/crates/core/src/ffi.rs +++ b/crates/core/src/ffi.rs @@ -1,8 +1,9 @@ use std::slice::{from_raw_parts, from_raw_parts_mut}; +use std::time::{SystemTime, UNIX_EPOCH}; use crate::{ - decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, TrainOptions, - RESOURCES, + decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, Timer, + TrainOptions, RESOURCES, }; type AllocBufferFn = extern "C" fn(usize) -> *mut u8; @@ -11,10 +12,17 @@ fn log(string: String) { println!("{}", string) } +fn now() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Your system is behind the Unix Epoch") + .as_millis() +} + #[no_mangle] pub extern "C" fn ffi_backend_create(ptr: *const u8, len: usize, alloc: AllocBufferFn) -> usize { let config = decode_json(ptr, len); - let net_backend = Backend::new(config, Logger { log }, None); + let net_backend = Backend::new(config, Logger { log }, Timer { now }, None); let buf: Vec = net_backend .size .iter() @@ -75,7 +83,7 @@ pub extern "C" fn ffi_backend_predict( RESOURCES.with(|cell| { let mut backend = cell.backend.borrow_mut(); - let res = backend[id].predict(inputs, options.layers); + let res = backend[id].predict(inputs, options.post_process, options.layers); outputs.copy_from_slice(res.as_slice().unwrap()); }); } @@ -98,7 +106,7 @@ pub extern "C" fn ffi_backend_load( alloc: AllocBufferFn, ) -> usize { let buffer = unsafe { from_raw_parts(file_ptr, file_len) }; - let net_backend = Backend::load(buffer, Logger { log }); + let net_backend = Backend::load(buffer, Logger { log }, Timer { now }); let buf: Vec = net_backend.size.iter().map(|x| *x as u8).collect(); let size_ptr = alloc(buf.len()); let output_shape = unsafe { from_raw_parts_mut(size_ptr, buf.len()) }; diff --git a/crates/core/src/types.rs b/crates/core/src/types.rs index d54f4e8..014cd15 100644 --- a/crates/core/src/types.rs +++ b/crates/core/src/types.rs @@ -195,6 +195,21 @@ pub enum Scheduler { OneCycle(OneCycleScheduler), } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct StepFunctionConfig { + pub thresholds: Vec, + pub values: Vec, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type", content = "config")] +#[serde(rename_all = "lowercase")] +pub enum PostProcessor { + None, + Sign, + Step(StepFunctionConfig), +} + #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct TrainOptions { @@ -212,6 +227,7 @@ pub struct PredictOptions { pub input_shape: Vec, pub output_shape: Vec, pub layers: Option>, + pub post_process: PostProcessor, } #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/crates/core/src/util.rs b/crates/core/src/util.rs index edee1a0..31e6432 100644 --- a/crates/core/src/util.rs +++ b/crates/core/src/util.rs @@ -4,10 +4,16 @@ use ndarray::ArrayD; use safetensors::tensor::TensorView; use serde::Deserialize; +#[derive(Clone)] pub struct Logger { pub log: fn(string: String) -> (), } +#[derive(Clone)] +pub struct Timer { + pub now: fn() -> u128, +} + pub fn length(shape: Vec) -> usize { return shape.iter().fold(1, |i, x| i * x); } diff --git a/crates/core/src/wasm.rs b/crates/core/src/wasm.rs index d3709b4..4f2e335 100644 --- a/crates/core/src/wasm.rs +++ b/crates/core/src/wasm.rs @@ -1,26 +1,39 @@ use js_sys::{Array, Float32Array, Uint8Array}; use ndarray::ArrayD; - use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; -use crate::{Backend, Dataset, Logger, PredictOptions, TrainOptions, RESOURCES}; +use crate::{Backend, Dataset, Logger, PredictOptions, Timer, TrainOptions, RESOURCES}; #[wasm_bindgen] extern "C" { #[wasm_bindgen(js_namespace = console)] fn log(s: &str); + #[wasm_bindgen(js_namespace = Date)] + fn now() -> f64; + } fn console_log(string: String) { log(string.as_str()) } +fn performance_now() -> u128 { + now() as u128 +} + #[wasm_bindgen] pub fn wasm_backend_create(config: String, shape: Array) -> usize { let config = serde_json::from_str(&config).unwrap(); let mut len = 0; let logger = Logger { log: console_log }; - let net_backend = Backend::new(config, logger, None); + let net_backend = Backend::new( + config, + logger, + Timer { + now: performance_now, + }, + None, + ); shape.set_length(net_backend.size.len() as u32); for (i, s) in net_backend.size.iter().enumerate() { shape.set(i as u32, JsValue::from(*s)) @@ -37,7 +50,6 @@ pub fn wasm_backend_create(config: String, shape: Array) -> usize { #[wasm_bindgen] pub fn wasm_backend_train(id: usize, buffers: Vec, options: String) { let options: TrainOptions = serde_json::from_str(&options).unwrap(); - let mut datasets = Vec::new(); for i in 0..options.datasets { let input = buffers[i * 2].to_vec(); @@ -47,7 +59,6 @@ pub fn wasm_backend_train(id: usize, buffers: Vec, options: String outputs: ArrayD::from_shape_vec(options.output_shape.clone(), output).unwrap(), }); } - RESOURCES.with(|cell| { let mut backend = cell.backend.borrow_mut(); backend[id].train(datasets, options.epochs, options.batches, options.rate) @@ -59,11 +70,12 @@ pub fn wasm_backend_predict(id: usize, buffer: Float32Array, options: String) -> let options: PredictOptions = serde_json::from_str(&options).unwrap(); let inputs = ArrayD::from_shape_vec(options.input_shape, buffer.to_vec()).unwrap(); - let res = ArrayD::zeros(options.output_shape); + let mut res = ArrayD::zeros(options.output_shape.clone()); RESOURCES.with(|cell| { let mut backend = cell.backend.borrow_mut(); - let _res = backend[id].predict(inputs, options.layers); + let _res = backend[id].predict(inputs, options.post_process, options.layers); + res.assign(&ArrayD::from_shape_vec(options.output_shape, _res.as_slice().unwrap().to_vec()).unwrap()); }); Float32Array::from(res.as_slice().unwrap()) } @@ -82,7 +94,10 @@ pub fn wasm_backend_save(id: usize) -> Uint8Array { pub fn wasm_backend_load(buffer: Uint8Array, shape: Array) -> usize { let mut len = 0; let logger = Logger { log: console_log }; - let net_backend = Backend::load(buffer.to_vec().as_slice(), logger); + let timer = Timer { + now: performance_now, + }; + let net_backend = Backend::load(buffer.to_vec().as_slice(), logger, timer); shape.set_length(net_backend.size.len() as u32); for (i, s) in net_backend.size.iter().enumerate() { shape.set(i as u32, JsValue::from(*s)) diff --git a/examples/autoencoders/decoded.html b/examples/autoencoders/decoded.html index a25a3d1..f1d1a26 100644 --- a/examples/autoencoders/decoded.html +++ b/examples/autoencoders/decoded.html @@ -1,2 +1,2 @@ -
idx012345678910
07.558476924896240.60374689102172850.117951095104217531.96667981147766110.106578290462493910.06482601165771534.210468292236330.99668729305267333.35533070564270.68901741504669199.519981384277344
17.8810958862304690.57679510116577150.187343478202819822.68102669715881350.1204662919044494620.41938209533691468.40732574462890.99712616205215453.30222821235656740.74343585968017589.61740493774414
27.7428851127624510.58834218978881840.157615661621093752.37499642372131350.114517211914062515.98343753814697353.757221221923830.99693781137466433.3249773979187010.72012269496917729.575666427612305
37.7996149063110350.58360201120376590.169818013906478882.50061178207397460.1169580817222595217.8042335510253959.7705650329589840.99701493978500373.31563949584960940.72969222068786629.592796325683594
47.558476924896240.60374689102172850.117951095104217531.96667981147766110.106578290462493910.06482601165771534.210468292236330.99668729305267333.35533070564270.68901741504669199.519981384277344
57.6156573295593260.59897023439407350.130249917507171632.0932872295379640.1090410351753234911.90002346038818440.271377563476560.99676507711410523.3459186553955080.69866240024566659.537246704101562
67.78566694259643550.58476763963699340.166817426681518552.4697246551513670.116358935832977317.35653114318847758.291984558105470.99699592590332033.31793546676635740.72733914852142339.588583946228027
77.45732116699218750.61219817399978640.096193492412567141.74269998073577880.102224290370941166.81821060180664123.4882297515869140.99654948711395263.3719806671142580.6719548702239999.489435195922852
87.4161453247070310.61563813686370850.087336897850036621.6515282392501830.100451290607452395.49666118621826219.123691558837890.99649333953857423.3787584304809570.66500949859619149.477001190185547
98.1629238128662110.55324959754943850.24796092510223393.30504274368286130.132598161697387729.464593887329198.279960632324220.99750858545303343.25584197044372560.7909728884696969.702510833740234
107.8374028205871580.58044439554214480.177945256233215332.58427858352661130.1185846924781799319.01701927185058663.775897979736330.99706673622131353.3094198703765870.7360656857490549.604207992553711
118.1629238128662110.55324959754943850.24796092510223393.30504274368286130.132598161697387729.464593887329198.279960632324220.99750858545303343.25584197044372560.7909728884696969.702510833740234
127.7878942489624020.58458095788955690.167296111583709722.4746561050415040.1164526343345642117.42801666259765658.5280723571777340.99699860811233523.31756949424743650.72771453857421889.589258193969727
137.5106606483459470.60774213075637820.107665598392486571.860800027847290.104520440101623548.530093193054229.1418781280517580.99662232398986823.363201618194580.68095135688781749.505542755126953
148.62542438507080.51460629701614380.347441047430038454.3291277885437010.1525041460990905844.3088264465332147.30436706542970.99813938140869143.17971253395080570.86898672580718999.842171669006348
158.6488246917724610.51265501976013180.35247492790222174.3809347152709960.1535125374794006345.05979919433594149.784530639648440.99817049503326423.1758611202239990.87293350696563729.849235534667969
168.218014717102050.54864370822906490.25981095433235173.42703628540039060.1349675059318542531.232906341552734104.119964599609380.99758362770080573.2467720508575440.80026543140411389.719144821166992
177.76234912872314450.58671480417251590.16180223226547242.4180936813354490.1153549551963806216.60813522338867255.8203430175781250.9969641566276553.32177352905273440.72340559959411629.581544876098633
187.5032954216003420.60835766792297360.106082081794738771.8444962501525880.104203343391418468.29376316070556628.361375808715820.99661195278167723.36441349983215330.67970961332321179.503316879272461
197.76466369628906250.58652126789093020.162300288677215582.42321872711181640.1154540777206420916.68242263793945356.065681457519530.99696743488311773.3213942050933840.72379672527313239.582244873046875
\ No newline at end of file +idx01234567891007.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728518.513451576232910.52541649341583250.27618211507797242.1018886566162110.159335896372795123.79069328308105568.288558959960941.00238001346588133.21208381652832030.8989087939262399.70976352691650427.3257198333740230.62166380882263180.083864927291870122.86311388015747070.065306633710861213.20198535919189554.199741363525391.00073778629302983.4179754257202150.62651437520980839.6567621231079138.4413099288940430.52269363403320310.278882861137390141.8793088197708130.1600578576326370219.86166381835937556.7057800292968751.0022829771041873.21761989593505860.90526127815246589.74307727813720747.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728557.53944635391235350.69273591041564940.0081554651260375981.83751261234283450.04153010249137878412.68921661376953140.7506141662597661.01184713840484623.51825976371765140.82678222656259.44038963317871167.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728578.2445058822631840.51526510715484620.28625023365020751.2721085548400880.162028402090072639.14320087432861325.1077384948730471.00201857089996343.23272585868835450.92258942127227789.83396720886230587.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728597.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285107.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285117.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285127.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285137.6858797073364260.57098019123077390.229856371879577641.3650441169738770.18276323378086099.85977745056152329.229558944702150.96460103988647463.19617557525634770.81461906433105478.944458961486816149.0359354019165040.54513794183731080.25661963224411013.7138931751251220.1541044414043426552.24626922607422152.175598144531251.00308275222778323.17198204994201660.85290682315826429.468478202819824158.9632139205932620.54239344596862790.25934273004531863.4895493984222410.1548317670822143648.28606033325195140.500915527343751.00298535823822023.17756247520446780.85930877923965459.502055168151855168.7347488403320310.53376978635787960.26789784431457522.784646749496460.1571203768253326435.842952728271484103.81862640380861.00267744064331053.1950986385345460.87942385673522959.607568740844727178.3979330062866210.52105611562728880.28050637245178221.74548399448394780.1604927480220794717.49934387207031249.741657257080081.00222456455230713.22094893455505370.90907967090606699.763108253479004187.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285198.457516670227050.52330523729324340.27827548980712891.92930841445922850.1598957628011703520.74426460266113359.307685852050781.00230491161346443.2163770198822020.90383404493331919.735596656799316 \ No newline at end of file diff --git a/examples/autoencoders/encoded.html b/examples/autoencoders/encoded.html index ca7b752..2951915 100644 --- a/examples/autoencoders/encoded.html +++ b/examples/autoencoders/encoded.html @@ -1,2 +1,2 @@ -
idx01
02.5389764308929443-8.392889976501465
14.668254852294922-17.085620880126953
23.75606107711792-13.361610412597656
34.130483627319336-14.890182495117188
42.5389764308929443-8.392889976501465
52.91636061668396-9.933552742004395
64.038419723510742-14.514333724975586
71.8713524341583252-5.667331218719482
81.599593162536621-4.55787992477417
96.528283596038818-24.67914581298828
104.379877090454102-15.908326148986816
116.528283596038818-24.67914581298828
124.053119659423828-14.574345588684082
132.2233786582946777-7.104470252990723
149.580804824829102-37.1409912109375
159.735233306884766-37.77143859863281
166.891912937164307-26.163654327392578
173.8845224380493164-13.8860502243042
182.1747803688049316-6.906069278717041
193.8997983932495117-13.948413848876953
\ No newline at end of file +idx010-0.78356546163558960.6262750625610352118.180845260620117-15.3507776260375982-0.74291682243347170.5920295715332031313.401049613952637-11.32391643524174-0.78356546163558960.62627506256103525-0.38311046361923220.288901507854461676-0.78356546163558960.626275062561035270.36166173219680786-0.33855092525482188-0.78356546163558960.62627506256103529-0.78356546163558960.626275062561035210-0.78356546163558960.626275062561035211-0.78356546163558960.626275062561035212-0.78356546163558960.6262750625610352130.11124086380004883-0.127577424049377441452.79803466796875-44.5149154663085941547.98031234741211-40.456100463867191632.84283447265625-27.7031440734863281710.527204513549805-8.90277194976806618-0.78356546163558960.62627506256103521914.474761009216309-12.22849178314209 \ No newline at end of file diff --git a/examples/autoencoders/example.ts b/examples/autoencoders/example.ts index 352551f..7cb24f2 100644 --- a/examples/autoencoders/example.ts +++ b/examples/autoencoders/example.ts @@ -14,7 +14,7 @@ import { import { parse } from "jsr:@std/csv@1.0.3/parse"; const data = parse( - Deno.readTextFileSync("examples/autoencoders/winequality-red.csv"), + Deno.readTextFileSync("examples/autoencoders/winequality-red.csv") ); data.shift(); @@ -49,7 +49,7 @@ const net = new Sequential({ const input = tensor(X); const timeStart = performance.now(); -net.train([{ inputs: input, outputs: input }], 10000, 1, 0.01); +net.train([{ inputs: input, outputs: tensor(Float32Array.from(input.data), input.shape) }], 10000, 1, 0.001); console.log(`Trained in ${performance.now() - timeStart}ms`); function saveTable(name: string, data: Matrix<"f32">) { @@ -66,17 +66,14 @@ const output_mat = new Matrix<"f32">(output.data, output.shape as Shape2D); saveTable("output", output_mat); console.log("Running Encoder"); -const encoded = await net.predict(input, [0, 5]); +const encoded = await net.predict(input, { layers: [0, 5] }); const encoded_mat = new Matrix<"f32">(encoded.data, encoded.shape as Shape2D); saveTable("encoded", encoded_mat); console.log("Running Decoder"); -const decoded = await net.predict(tensor(encoded_mat), [ - 5, - 10, -]); +const decoded = await net.predict(tensor(encoded_mat), { layers: [5, 10] }); const decoded_mat = new Matrix<"f32">(decoded.data, decoded.shape as Shape2D); diff --git a/examples/autoencoders/input.html b/examples/autoencoders/input.html index 391ffb1..11e0793 100644 --- a/examples/autoencoders/input.html +++ b/examples/autoencoders/input.html @@ -1,2 +1,2 @@ -
idx012345678910
07.4000000953674320.69999998807907101.8999999761581420.0759999975562095611340.99779999256134033.5099999904632570.56000000238418589.399999618530273
17.8000001907348630.879999995231628402.59999990463256840.0979999974370002725670.99680000543594363.2000000476837160.68000000715255749.800000190734863
27.8000001907348630.75999999046325680.039999999105930332.2999999523162840.0920000001788139315540.9969999790191653.2599999904632570.64999997615814219.800000190734863
311.1999998092651370.28000000119209290.56000000238418581.8999999761581420.0750000029802322417600.99800002574920653.16000008583068850.57999998331069959.800000190734863
47.4000000953674320.69999998807907101.8999999761581420.0759999975562095611340.99779999256134033.5099999904632570.56000000238418589.399999618530273
57.4000000953674320.660000026226043701.79999995231628420.0750000029802322413400.99779999256134033.5099999904632570.56000000238418589.399999618530273
67.9000000953674320.60000002384185790.059999998658895491.6000000238418580.068999998271465315590.9963999986648563.2999999523162840.460000008344650279.399999618530273
77.3000001907348630.649999976158142101.20000004768371580.0649999976158142115210.99459999799728393.3900001049041750.469999998807907110
87.8000001907348630.57999998331069950.01999999955296516420.07299999892711649180.99680000543594363.3599998950958250.56999999284744269.5
97.50.50.360000014305114756.0999999046325680.07100000232458115171020.99779999256134033.34999990463256840.80000001192092910.5
106.6999998092651370.57999998331069950.079999998211860661.79999995231628420.0970000028610229515650.99589997529983523.27999997138977050.54000002145767219.199999809265137
117.50.50.360000014305114756.0999999046325680.07100000232458115171020.99779999256134033.34999990463256840.80000001192092910.5
125.5999999046325680.615000009536743201.6000000238418580.0890000015497207616590.99430000782012943.57999992370605470.51999998092651379.899999618530273
137.8000001907348630.61000001430511470.289999991655349731.6000000238418580.114000000059604649290.99739998579025273.2599999904632571.5599999427795419.100000381469727
148.8999996185302730.62000000476837160.180000007152557373.7999999523162840.17599999904632568521450.99860000610351563.16000008583068850.87999999523162849.199999809265137
158.8999996185302730.62000000476837160.18999999761581423.90000009536743160.17000000178813934511480.99860000610351563.17000007629394530.93000000715255749.199999809265137
168.50.28000000119209290.56000000238418581.79999995231628420.09200000017881393351030.99690002202987673.2999999523162840.7510.5
178.1000003814697270.56000000238418580.28000000119209291.70000004768371580.3680000007152557416560.99680000543594363.1099998950958251.27999997138977059.300000190734863
187.4000000953674320.58999997377395630.079999998211860664.4000000953674320.08600000292062766290.99739998579025273.3800001144409180.59
197.9000000953674320.31999999284744260.50999999046325681.79999995231628420.340999990701675417560.99690002202987673.03999996185302731.08000004291534429.199999809265137
\ No newline at end of file +idx01234567891007.4000000953674320.69999998807907101.8999999761581420.0759999975562095611340.99779999256134033.5099999904632570.56000000238418589.39999961853027317.8000001907348630.879999995231628402.59999990463256840.0979999974370002725670.99680000543594363.2000000476837160.68000000715255749.80000019073486327.8000001907348630.75999999046325680.039999999105930332.2999999523162840.0920000001788139315540.9969999790191653.2599999904632570.64999997615814219.800000190734863311.1999998092651370.28000000119209290.56000000238418581.8999999761581420.0750000029802322417600.99800002574920653.16000008583068850.57999998331069959.80000019073486347.4000000953674320.69999998807907101.8999999761581420.0759999975562095611340.99779999256134033.5099999904632570.56000000238418589.39999961853027357.4000000953674320.660000026226043701.79999995231628420.0750000029802322413400.99779999256134033.5099999904632570.56000000238418589.39999961853027367.9000000953674320.60000002384185790.059999998658895491.6000000238418580.068999998271465315590.9963999986648563.2999999523162840.460000008344650279.39999961853027377.3000001907348630.649999976158142101.20000004768371580.0649999976158142115210.99459999799728393.3900001049041750.46999999880790711087.8000001907348630.57999998331069950.01999999955296516420.07299999892711649180.99680000543594363.3599998950958250.56999999284744269.597.50.50.360000014305114756.0999999046325680.07100000232458115171020.99779999256134033.34999990463256840.80000001192092910.5106.6999998092651370.57999998331069950.079999998211860661.79999995231628420.0970000028610229515650.99589997529983523.27999997138977050.54000002145767219.199999809265137117.50.50.360000014305114756.0999999046325680.07100000232458115171020.99779999256134033.34999990463256840.80000001192092910.5125.5999999046325680.615000009536743201.6000000238418580.0890000015497207616590.99430000782012943.57999992370605470.51999998092651379.899999618530273137.8000001907348630.61000001430511470.289999991655349731.6000000238418580.114000000059604649290.99739998579025273.2599999904632571.5599999427795419.100000381469727148.8999996185302730.62000000476837160.180000007152557373.7999999523162840.17599999904632568521450.99860000610351563.16000008583068850.87999999523162849.199999809265137158.8999996185302730.62000000476837160.18999999761581423.90000009536743160.17000000178813934511480.99860000610351563.17000007629394530.93000000715255749.199999809265137168.50.28000000119209290.56000000238418581.79999995231628420.09200000017881393351030.99690002202987673.2999999523162840.7510.5178.1000003814697270.56000000238418580.28000000119209291.70000004768371580.3680000007152557416560.99680000543594363.1099998950958251.27999997138977059.300000190734863187.4000000953674320.58999997377395630.079999998211860664.4000000953674320.08600000292062766290.99739998579025273.3800001144409180.59197.9000000953674320.31999999284744260.50999999046325681.79999995231628420.340999990701675417560.99690002202987673.03999996185302731.08000004291534429.199999809265137 \ No newline at end of file diff --git a/examples/autoencoders/output.html b/examples/autoencoders/output.html index a25a3d1..f1d1a26 100644 --- a/examples/autoencoders/output.html +++ b/examples/autoencoders/output.html @@ -1,2 +1,2 @@ -
idx012345678910
07.558476924896240.60374689102172850.117951095104217531.96667981147766110.106578290462493910.06482601165771534.210468292236330.99668729305267333.35533070564270.68901741504669199.519981384277344
17.8810958862304690.57679510116577150.187343478202819822.68102669715881350.1204662919044494620.41938209533691468.40732574462890.99712616205215453.30222821235656740.74343585968017589.61740493774414
27.7428851127624510.58834218978881840.157615661621093752.37499642372131350.114517211914062515.98343753814697353.757221221923830.99693781137466433.3249773979187010.72012269496917729.575666427612305
37.7996149063110350.58360201120376590.169818013906478882.50061178207397460.1169580817222595217.8042335510253959.7705650329589840.99701493978500373.31563949584960940.72969222068786629.592796325683594
47.558476924896240.60374689102172850.117951095104217531.96667981147766110.106578290462493910.06482601165771534.210468292236330.99668729305267333.35533070564270.68901741504669199.519981384277344
57.6156573295593260.59897023439407350.130249917507171632.0932872295379640.1090410351753234911.90002346038818440.271377563476560.99676507711410523.3459186553955080.69866240024566659.537246704101562
67.78566694259643550.58476763963699340.166817426681518552.4697246551513670.116358935832977317.35653114318847758.291984558105470.99699592590332033.31793546676635740.72733914852142339.588583946228027
77.45732116699218750.61219817399978640.096193492412567141.74269998073577880.102224290370941166.81821060180664123.4882297515869140.99654948711395263.3719806671142580.6719548702239999.489435195922852
87.4161453247070310.61563813686370850.087336897850036621.6515282392501830.100451290607452395.49666118621826219.123691558837890.99649333953857423.3787584304809570.66500949859619149.477001190185547
98.1629238128662110.55324959754943850.24796092510223393.30504274368286130.132598161697387729.464593887329198.279960632324220.99750858545303343.25584197044372560.7909728884696969.702510833740234
107.8374028205871580.58044439554214480.177945256233215332.58427858352661130.1185846924781799319.01701927185058663.775897979736330.99706673622131353.3094198703765870.7360656857490549.604207992553711
118.1629238128662110.55324959754943850.24796092510223393.30504274368286130.132598161697387729.464593887329198.279960632324220.99750858545303343.25584197044372560.7909728884696969.702510833740234
127.7878942489624020.58458095788955690.167296111583709722.4746561050415040.1164526343345642117.42801666259765658.5280723571777340.99699860811233523.31756949424743650.72771453857421889.589258193969727
137.5106606483459470.60774213075637820.107665598392486571.860800027847290.104520440101623548.530093193054229.1418781280517580.99662232398986823.363201618194580.68095135688781749.505542755126953
148.62542438507080.51460629701614380.347441047430038454.3291277885437010.1525041460990905844.3088264465332147.30436706542970.99813938140869143.17971253395080570.86898672580718999.842171669006348
158.6488246917724610.51265501976013180.35247492790222174.3809347152709960.1535125374794006345.05979919433594149.784530639648440.99817049503326423.1758611202239990.87293350696563729.849235534667969
168.218014717102050.54864370822906490.25981095433235173.42703628540039060.1349675059318542531.232906341552734104.119964599609380.99758362770080573.2467720508575440.80026543140411389.719144821166992
177.76234912872314450.58671480417251590.16180223226547242.4180936813354490.1153549551963806216.60813522338867255.8203430175781250.9969641566276553.32177352905273440.72340559959411629.581544876098633
187.5032954216003420.60835766792297360.106082081794738771.8444962501525880.104203343391418468.29376316070556628.361375808715820.99661195278167723.36441349983215330.67970961332321179.503316879272461
197.76466369628906250.58652126789093020.162300288677215582.42321872711181640.1154540777206420916.68242263793945356.065681457519530.99696743488311773.3213942050933840.72379672527313239.582244873046875
\ No newline at end of file +idx01234567891007.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728518.513451576232910.52541649341583250.27618211507797242.1018886566162110.159335896372795123.79069328308105568.288558959960941.00238001346588133.21208381652832030.8989087939262399.70976352691650427.3257198333740230.62166380882263180.083864927291870122.86311388015747070.065306633710861213.20198535919189554.199741363525391.00073778629302983.4179754257202150.62651437520980839.6567621231079138.4413099288940430.52269363403320310.278882861137390141.8793088197708130.1600578576326370219.86166381835937556.7057800292968751.0022829771041873.21761989593505860.90526127815246589.74307727813720747.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728557.53944635391235350.69273591041564940.0081554651260375981.83751261234283450.04153010249137878412.68921661376953140.7506141662597661.01184713840484623.51825976371765140.82678222656259.44038963317871167.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728578.2445058822631840.51526510715484620.28625023365020751.2721085548400880.162028402090072639.14320087432861325.1077384948730471.00201857089996343.23272585868835450.92258942127227789.83396720886230587.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.64928531646728597.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285107.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285117.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285127.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285137.6858797073364260.57098019123077390.229856371879577641.3650441169738770.18276323378086099.85977745056152329.229558944702150.96460103988647463.19617557525634770.81461906433105478.944458961486816149.0359354019165040.54513794183731080.25661963224411013.7138931751251220.1541044414043426552.24626922607422152.175598144531251.00308275222778323.17198204994201660.85290682315826429.468478202819824158.9632139205932620.54239344596862790.25934273004531863.4895493984222410.1548317670822143648.28606033325195140.500915527343751.00298535823822023.17756247520446780.85930877923965459.502055168151855168.7347488403320310.53376978635787960.26789784431457522.784646749496460.1571203768253326435.842952728271484103.81862640380861.00267744064331053.1950986385345460.87942385673522959.607568740844727178.3979330062866210.52105611562728880.28050637245178221.74548399448394780.1604927480220794717.49934387207031249.741657257080081.00222456455230713.22094893455505370.90907967090606699.763108253479004187.2839708328247070.60809707641601560.098782420158386232.9995722770690920.0732277631759643613.20593738555908255.855880737304690.9967240095138553.39345788955688480.59344971179962169.649285316467285198.457516670227050.52330523729324340.27827548980712891.92930841445922850.1598957628011703520.74426460266113359.307685852050781.00230491161346443.2163770198822020.90383404493331919.735596656799316 \ No newline at end of file diff --git a/examples/classification/binary_iris.ts b/examples/classification/binary_iris.ts index 7378241..1a51dc8 100644 --- a/examples/classification/binary_iris.ts +++ b/examples/classification/binary_iris.ts @@ -16,6 +16,8 @@ import { // Split the dataset useSplit, } from "../../packages/utilities/mod.ts"; +import { PostProcess } from "../../packages/core/src/core/api/postprocess.ts"; +import { AdamOptimizer, WASM } from "../../mod.ts"; // Define classes const classes = ["Setosa", "Versicolor"]; @@ -29,10 +31,10 @@ const x = data.map((fl) => fl.slice(0, 4).map(Number)); const y = data.map((fl) => classes.indexOf(fl[4])); // Split the dataset for training and testing -const [train, test] = useSplit({ ratio: [7, 3], shuffle: true }, x, y) +const [train, test] = useSplit({ ratio: [7, 3], shuffle: true }, x, y); // Setup the CPU backend for Netsaur -await setupBackend(CPU); +await setupBackend(WASM); // Create a sequential neural network const net = new Sequential({ @@ -56,6 +58,7 @@ const net = new Sequential({ ], // We are using Log Loss for finding cost cost: Cost.BinCrossEntropy, + optimizer: AdamOptimizer() }); const time = performance.now(); @@ -69,16 +72,17 @@ net.train( }, ], // Train for 150 epochs - 150, + 100, 1, // Use a smaller learning rate - 0.02, + 0.02 ); console.log(`training time: ${performance.now() - time}ms`); -const res = await net.predict(tensor2D(test[0])); +const res = await net.predict(tensor2D(test[0]), { + postProcess: PostProcess("step", { thresholds: [0.5], values: [0, 1] }), +}); -const y1 = res.data.map((x) => x < 0.5 ? 0 : 1); -const cMatrix = new ClassificationReport(test[1], y1); +const cMatrix = new ClassificationReport(test[1], res.data); console.log("Confusion Matrix: ", cMatrix); diff --git a/packages/core/src/backends/cpu/backend.ts b/packages/core/src/backends/cpu/backend.ts index 040fa32..b2e7717 100644 --- a/packages/core/src/backends/cpu/backend.ts +++ b/packages/core/src/backends/cpu/backend.ts @@ -10,6 +10,7 @@ import { type PredictOptions, type TrainOptions, } from "./util.ts"; +import type { PostProcessor } from "../../core/api/postprocess.ts"; /** * CPU Backend. @@ -68,25 +69,26 @@ export class CPUBackend implements Backend { ); } - async predict(input: Tensor): Promise>; + async predict(input: Tensor, config: {postProcess: PostProcessor, outputShape?: Shape}): Promise>; async predict( input: Tensor, + config: {postProcess: PostProcessor, outputShape?: Shape}, layers: number[], - outputShape: Shape, ): Promise>; //deno-lint-ignore require-await async predict( input: Tensor, - layers?: number[], - outputShape?: Shape, + config: {postProcess: PostProcessor, outputShape?: Shape}, + layers?: number[], ): Promise> { const options = encodeJSON({ inputShape: input.shape, - outputShape: [input.shape[0], ...(outputShape ?? this.outputShape)], + outputShape: [input.shape[0], ...(config.outputShape ?? this.outputShape)], + postProcess: config.postProcess, layers, } as PredictOptions); const output = new Float32Array( - input.shape[0] * length(outputShape ?? this.outputShape), + input.shape[0] * length(config.outputShape ?? this.outputShape), ); this.library.symbols.ffi_backend_predict( this.#id, @@ -99,7 +101,7 @@ export class CPUBackend implements Backend { output, [ input.shape[0], - ...(outputShape ?? this.outputShape), + ...(config.outputShape ?? this.outputShape), ] as Shape, ); } diff --git a/packages/core/src/backends/wasm/backend.ts b/packages/core/src/backends/wasm/backend.ts index 921861d..2f5fa3d 100644 --- a/packages/core/src/backends/wasm/backend.ts +++ b/packages/core/src/backends/wasm/backend.ts @@ -9,6 +9,7 @@ import { wasm_backend_save, wasm_backend_train, } from "./lib/netsaur.generated.js"; +import type { PostProcessor } from "../../core/api/postprocess.ts"; /** * Web Assembly Backend. @@ -32,7 +33,7 @@ export class WASMBackend implements Backend { datasets: DataSet[], epochs: number, batches: number, - rate: number, + rate: number ): void { this.outputShape = datasets[0].outputs.shape.slice(1) as Shape; const buffer = []; @@ -52,18 +53,39 @@ export class WASMBackend implements Backend { wasm_backend_train(this.#id, buffer, options); } + async predict( + input: Tensor, + config: { postProcess: PostProcessor; outputShape?: Shape } + ): Promise>; + async predict( + input: Tensor, + config: { postProcess: PostProcessor; outputShape?: Shape }, + layers: number[] + ): Promise>; //deno-lint-ignore require-await - async predict(input: Tensor): Promise> { + async predict( + input: Tensor, + config: { postProcess: PostProcessor; outputShape?: Shape }, + layers?: number[] + ): Promise> { const options = JSON.stringify({ - inputShape: [1, ...input.shape], - outputShape: this.outputShape, + inputShape: input.shape, + outputShape: [input.shape[0], ...(config.outputShape ?? this.outputShape)], + postProcess: config.postProcess, + layers, } as PredictOptions); const output = wasm_backend_predict( this.#id, input.data as Float32Array, - options, + options + ); + return new Tensor( + output, + [ + input.shape[0], + ...(config.outputShape ?? this.outputShape), + ] as Shape, ); - return new Tensor(output, this.outputShape!); } save(): Uint8Array { diff --git a/packages/core/src/backends/wasm/lib/netsaur.generated.js b/packages/core/src/backends/wasm/lib/netsaur.generated.js index f3523ed..061600c 100644 --- a/packages/core/src/backends/wasm/lib/netsaur.generated.js +++ b/packages/core/src/backends/wasm/lib/netsaur.generated.js @@ -4,7 +4,7 @@ // deno-fmt-ignore-file /// -// source-hash: c1eff57085f8488444a8499d3d2fcad1650a7099 +// source-hash: f1db375a60100b82256510390e36e2ea79669a26 let wasm; let cachedInt32Memory0; @@ -63,6 +63,12 @@ function getStringFromWasm0(ptr, len) { return cachedTextDecoder.decode(getUint8Memory0().subarray(ptr, ptr + len)); } +function notDefined(what) { + return () => { + throw new Error(`${what} is not defined`); + }; +} + let WASM_VECTOR_LEN = 0; const cachedTextEncoder = typeof TextEncoder !== "undefined" @@ -220,6 +226,9 @@ const imports = { __wbg_log_6f7dfa87fad40a57: function (arg0, arg1) { console.log(getStringFromWasm0(arg0, arg1)); }, + __wbg_now_de5fe0de473bcd7d: typeof Date.now == "function" + ? Date.now + : notDefined("Date.now"), __wbindgen_number_new: function (arg0) { const ret = arg0; return addHeapObject(ret); @@ -249,6 +258,10 @@ const imports = { const ret = typeof (getObject(arg0)) === "string"; return ret; }, + __wbg_msCrypto_bcb970640f50a1e8: function (arg0) { + const ret = getObject(arg0).msCrypto; + return addHeapObject(ret); + }, __wbg_require_8f08ceecec0f4fee: function () { return handleError(function () { const ret = module.require; @@ -263,10 +276,6 @@ const imports = { const ret = getStringFromWasm0(arg0, arg1); return addHeapObject(ret); }, - __wbg_msCrypto_bcb970640f50a1e8: function (arg0) { - const ret = getObject(arg0).msCrypto; - return addHeapObject(ret); - }, __wbg_randomFillSync_dc1e9a60c158336d: function () { return handleError(function (arg0, arg1) { getObject(arg0).randomFillSync(takeObject(arg1)); diff --git a/packages/core/src/backends/wasm/lib/netsaur_bg.wasm b/packages/core/src/backends/wasm/lib/netsaur_bg.wasm index 511a32b..bb49444 100644 Binary files a/packages/core/src/backends/wasm/lib/netsaur_bg.wasm and b/packages/core/src/backends/wasm/lib/netsaur_bg.wasm differ diff --git a/packages/core/src/core/api/postprocess.ts b/packages/core/src/core/api/postprocess.ts new file mode 100644 index 0000000..6279510 --- /dev/null +++ b/packages/core/src/core/api/postprocess.ts @@ -0,0 +1,22 @@ +/** Post-processing step only occuring during prediction routine */ +export type PostProcessor = + | { type: "none" } + | { type: "sign" } + | { type: "step"; config: StepFunctionConfig }; + +type StepFunctionConfig = { thresholds: number[]; values: number[] }; + +export function PostProcess(pType: "none" | "sign"): PostProcessor; +export function PostProcess( + pType: "step", + config: StepFunctionConfig +): PostProcessor; +export function PostProcess( + pType: "none" | "sign" | "step", + config?: StepFunctionConfig +) { + if (pType === "none" || pType === "sign") { + return { type: pType }; + } + return { type: pType, config }; +} diff --git a/packages/core/src/core/mod.ts b/packages/core/src/core/mod.ts index 454fec1..b34a55e 100644 --- a/packages/core/src/core/mod.ts +++ b/packages/core/src/core/mod.ts @@ -11,6 +11,7 @@ import type { Rank } from "./api/shape.ts"; import type { Tensor } from "./tensor/tensor.ts"; import type { NeuralNetwork } from "./api/network.ts"; import { SGDOptimizer } from "./api/optimizer.ts"; +import { PostProcess, type PostProcessor } from "./api/postprocess.ts"; /** * Sequential Neural Network @@ -46,21 +47,27 @@ export class Sequential implements NeuralNetwork { */ async predict( data: Tensor, - layers?: [number, number], + config?: { postProcess?: PostProcessor; layers?: [number, number] } ): Promise> { - if (layers) { - if (layers[0] < 0 || layers[1] > this.config.layers.length) { + if (!config) + config = { + postProcess: PostProcess("none"), + }; + if (config.layers) { + if ( + config.layers[0] < 0 || + config.layers[1] > this.config.layers.length + ) { throw new RangeError( - `Execution range should be within (0, ${this.config.layers.length}). Received (${(layers[ - 0 - ], - layers[1])})`, + `Execution range should be within (0, ${ + this.config.layers.length + }). Received (${(config.layers[0], config.layers[1])})` ); } - const lastLayer = this.config.layers[layers[1] - 1]; - const layerList = new Array(layers[1] - layers[0]); + const lastLayer = this.config.layers[config.layers[1] - 1]; + const layerList = new Array(config.layers[1] - config.layers[0]); for (let i = 0; i < layerList.length; i += 1) { - layerList[i] = layers[0] + i; + layerList[i] = config.layers[0] + i; } if ( lastLayer.type === LayerType.Dense || @@ -68,32 +75,43 @@ export class Sequential implements NeuralNetwork { ) { return await this.backend.predict( data, - layerList, - lastLayer.config.size, + { + postProcess: config.postProcess || PostProcess("none"), + outputShape: lastLayer.config.size, + }, + layerList ); } else if (lastLayer.type === LayerType.Activation) { - const penultimate = this.config.layers[layers[1] - 2]; + const penultimate = this.config.layers[config.layers[1] - 2]; if ( penultimate.type === LayerType.Dense || penultimate.type === LayerType.Flatten ) { return await this.backend.predict( data, - layerList, - penultimate.config.size, + { + postProcess: config.postProcess || PostProcess("none"), + outputShape: penultimate.config.size, + }, + layerList ); } else { throw new Error( - `The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.`, + `The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.` ); } } else { throw new Error( - `The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.`, + `The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.` ); } } - return await this.backend.predict(data); + return await this.backend.predict( + data, + config.postProcess + ? (config as { postProcess: PostProcessor; layers?: [number, number] }) + : { ...config, postProcess: PostProcess("none") } + ); } /** diff --git a/packages/core/src/core/types.ts b/packages/core/src/core/types.ts index 149c3a3..7583c9f 100644 --- a/packages/core/src/core/types.ts +++ b/packages/core/src/core/types.ts @@ -3,6 +3,7 @@ import type { Rank, Shape } from "./api/shape.ts"; import type { Layer } from "./api/layer.ts"; import type { Optimizer } from "./api/optimizer.ts"; import type { Scheduler } from "./api/scheduler.ts"; +import type { PostProcessor } from "./api/postprocess.ts"; /** * The Backend is responsible for eveything related to the neural network. @@ -36,8 +37,8 @@ export interface Backend { */ predict( input: Tensor, + config: {postProcess: PostProcessor, outputShape?: Shape}, layers?: number[], - outputShape?: Shape, ): Promise>; /**