Changes from all commits (38 commits):
86f6643  Do not build libraries when using runtime linking (therealfrauholle, Nov 20, 2025)
569739b  Emit correct flags when using dynamic linking (therealfrauholle, Nov 23, 2025)
53419b8  Idiomatic imports (therealfrauholle, Nov 23, 2025)
0a996bc  Add doctest example (therealfrauholle, Nov 20, 2025)
6a37d3b  Refactor example (therealfrauholle, Nov 21, 2025)
05774b6  Generalize imports (therealfrauholle, Nov 22, 2025)
d55c390  Fix lint "renamed_and_removed_lints" (therealfrauholle, Nov 21, 2025)
d9f75d1  Fix lint "missing_abi" (therealfrauholle, Nov 21, 2025)
9cfaf29  Fix lint clippy too many args (therealfrauholle, Nov 24, 2025)
af15bf3  Fix lint "clippy::expect_fn_call" (therealfrauholle, Nov 21, 2025)
8039c3d  Fix lint "clippy::double_ended_iterator_last" (therealfrauholle, Nov 21, 2025)
ee66f64  Fix lint "clippy::match_like_matches_macro" (therealfrauholle, Nov 21, 2025)
e89a1bd  Fix lint "clippy::redundant_pattern" (therealfrauholle, Nov 21, 2025)
6fce2ce  Fix lint "clippy::approx_constant" (therealfrauholle, Nov 21, 2025)
47284d5  Fix lint "clippy::single_component_path_imports" (therealfrauholle, Nov 21, 2025)
0840e2c  Fix lint "clippy::single_match" (therealfrauholle, Nov 21, 2025)
e877e06  Fix lint "clippy::needless_question_mark" (therealfrauholle, Nov 21, 2025)
fd31797  Fix lint "clippy::needless_borrow" (therealfrauholle, Nov 21, 2025)
0469736  Fix lint "clippy::too_many_arguments" (therealfrauholle, Nov 24, 2025)
22a96cf  Fix lint "clippy::needless_range_loop" (therealfrauholle, Nov 21, 2025)
6015bd1  Fix lint "clippy::needless_range_loop" (therealfrauholle, Nov 22, 2025)
d5539dc  Fix lint "clippy::map_clone" (therealfrauholle, Nov 24, 2025)
cf5026c  Fix lint "clippy::bool_assert_comparison" (therealfrauholle, Nov 22, 2025)
70b931a  Fix lint "clippy::len_zero" (therealfrauholle, Nov 22, 2025)
03d391b  Fix lint "clippy::manual_rotate" (therealfrauholle, Nov 22, 2025)
362003e  Fix lint "clippy::suspicious_open_options" (therealfrauholle, Nov 22, 2025)
3bf7b49  Fix lint "dead_code" (therealfrauholle, Nov 22, 2025)
9c041fb  Fix lint "elided_named_lifetimes" (therealfrauholle, Nov 24, 2025)
4c88de4  Fix lint "clippy::needless_borrows_for_generic_args" (therealfrauholle, Nov 24, 2025)
4946d5c  Fix lint "unnecessary_casts" (therealfrauholle, Nov 24, 2025)
f7ca75c  Fix lint "clippy::derivable_impls" (therealfrauholle, Nov 24, 2025)
11e686a  Fix lint "useless_vec" (therealfrauholle, Nov 24, 2025)
cc77341  Fix lint "clippy::needless_range_loop" (therealfrauholle, Nov 24, 2025)
965f9c9  Fix lint "clippy::redundant_closure" (therealfrauholle, Nov 24, 2025)
035b8b7  Fix lint "clippy::missing_safety_doc" (therealfrauholle, Nov 24, 2025)
4813260  Fix lint "clippy::redundant_static_lifetime" (therealfrauholle, Nov 24, 2025)
c1db4b7  Exclude invalid tests (therealfrauholle, Nov 24, 2025)
f8de427  Fix lints in op-codegen (therealfrauholle, Nov 26, 2025)
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -44,7 +44,7 @@ default = ["tensorflow-sys"]
 experimental = ["tensorflow-sys/experimental"]
 tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"]
 tensorflow_unstable = []
-tensorflow_runtime_linking = ["tensorflow-sys-runtime"]
+tensorflow_runtime_linking = ["tensorflow-sys-runtime", "tensorflow-sys/tensorflow_runtime_linking"]
 eager = ["tensorflow-sys/eager"]
 # This is for testing purposes; users should not use this.
 examples_system_alloc = ["tensorflow-sys/examples_system_alloc"]
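The umbrella feature now forwards to the sys crate as well, so enabling it in one place configures both layers. For context, a downstream user would presumably opt in like this (hypothetical manifest snippet; the version is a placeholder, not taken from this PR):

    [dependencies]
    # Loads the TensorFlow C library at runtime instead of linking it at build time.
    tensorflow = { version = "x.y", features = ["tensorflow_runtime_linking"] }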
3 changes: 1 addition & 2 deletions examples/mobilenetv3.rs
@@ -47,8 +47,7 @@ fn main() -> Result<(), Box<dyn Error>> {
 
     // Load the model.
     let mut graph = Graph::new();
-    let bundle =
-        SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
+    let bundle = SavedModelBundle::load(&SessionOptions::new(), ["serve"], &mut graph, export_dir)?;
     let session = &bundle.session;
 
     // get in/out operations
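Aside: the change from `&["serve"]` to `["serve"]` here (and again in regression_savedmodel.rs and xor.rs below) is the clippy::needless_borrows_for_generic_args fix from commit 4c88de4. A minimal sketch of why both forms compile, assuming the tags parameter is accepted as any `IntoIterator` of `AsRef<str>` items (the bound is paraphrased, not copied from the crate):

    // Hypothetical stand-in for a tags parameter like SavedModelBundle::load's.
    fn takes_tags<Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>(tags: Tags) {
        for tag in tags {
            println!("tag: {}", tag.as_ref());
        }
    }

    fn main() {
        takes_tags(["serve"]); // array by value: Item = &str
        takes_tags(&["serve"]); // borrowed array: Item = &&str, which also implements AsRef<str>
    }

Since the borrowed form adds nothing, clippy flags the `&` as redundant.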
1 change: 0 additions & 1 deletion examples/regression.rs
@@ -1,4 +1,3 @@
-use rand;
 use std::error::Error;
 use std::fs::File;
 use std::io::Read;
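The deleted `use rand;` in this and the next two examples is the clippy::single_component_path_imports fix (commit 47284d5): since the 2018 edition, a bare single-segment `use` of an external crate is redundant. A minimal illustration, assuming `rand` is declared as a dependency (as it already is for these examples):

    // No `use rand;` needed; the crate path resolves on its own.
    fn main() {
        let x: f64 = rand::random();
        println!("{x}");
    }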
1 change: 0 additions & 1 deletion examples/regression_checkpoint.rs
@@ -1,4 +1,3 @@
-use rand;
 use std::error::Error;
 use std::fs::File;
 use std::io::Read;
4 changes: 1 addition & 3 deletions examples/regression_savedmodel.rs
@@ -1,4 +1,3 @@
-use rand;
 use std::error::Error;
 use std::path::Path;
 use std::result::Result;
@@ -44,8 +43,7 @@ fn main() -> Result<(), Box<dyn Error>> {
 
     // Load the saved model exported by regression_savedmodel.py.
     let mut graph = Graph::new();
-    let bundle =
-        SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
+    let bundle = SavedModelBundle::load(&SessionOptions::new(), ["serve"], &mut graph, export_dir)?;
     let session = &bundle.session;
 
     // train
131 changes: 66 additions & 65 deletions examples/xor.rs
@@ -11,6 +11,7 @@ use tensorflow::train::Optimizer;
 use tensorflow::Code;
 use tensorflow::DataType;
 use tensorflow::Graph;
+use tensorflow::Operation;
 use tensorflow::Output;
 use tensorflow::OutputName;
 use tensorflow::SavedModelBundle;
@@ -34,7 +35,7 @@ use tensorflow::REGRESS_OUTPUTS;
 // function such as tanh.
 //
 // Returns variables created and the layer output.
-fn layer<O1: Into<Output>>(
+fn build_layer<O1: Into<Output>>(
     input: O1,
     input_size: u64,
     output_size: u64,
@@ -70,7 +71,33 @@
     ))
 }
 
-fn train<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
+/// Helper that generates a training sample from an integer, trains on that
+/// example, and returns the error.
+fn train(
+    session: &Session,
+    sample_seed: usize,
+    error: &Operation,
+    optimize: &Operation,
+    input: &Operation,
+    output: &Operation,
+) -> Result<f32, Box<dyn Error>> {
+    let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
+    let mut label_tensor = Tensor::<f32>::new(&[1]);
+    input_tensor[0] = (sample_seed & 1) as f32;
+    input_tensor[1] = ((sample_seed >> 1) & 1) as f32;
+    label_tensor[0] = ((sample_seed & 1) ^ ((sample_seed >> 1) & 1)) as f32;
+    let mut run_args = SessionRunArgs::new();
+    run_args.add_target(optimize);
+    let error_squared_fetch = run_args.request_fetch(error, 0);
+    run_args.add_feed(input, 0, &input_tensor);
+    run_args.add_feed(output, 0, &label_tensor);
+    session.run(&mut run_args)?;
+    Ok(run_args.fetch::<f32>(error_squared_fetch)?[0])
+}
+/// Train a model designed to solve the xor problem, saving it.
+///
+/// It uses one hidden layer.
+fn build_and_train_and_save<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
     // ================
     // Build the model.
     // ================
@@ -87,31 +114,30 @@ fn train<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
         .dtype(DataType::Float)
         .shape([1u64])
         .build(&mut scope.with_op_name("label"))?;
-    // Hidden layer.
-    let (vars1, layer1) = layer(
+    let mut custom_variables = Vec::new();
+    let (variables_hidden, hidden_layer) = build_layer(
         input.clone(),
         2,
         hidden_size,
         &|x, scope| Ok(ops::tanh(x, scope)?.into()),
         scope,
     )?;
-    // Output layer.
-    let (vars2, layer2) = layer(layer1.clone(), hidden_size, 1, &|x, _| Ok(x), scope)?;
-    let error = ops::sub(layer2.clone(), label.clone(), scope)?;
+    custom_variables.extend(variables_hidden);
+    let (variables_output, output_layer) =
+        build_layer(hidden_layer.clone(), hidden_size, 1, &|x, _| Ok(x), scope)?;
+    let error = ops::sub(output_layer.clone(), label.clone(), scope)?;
     let error_squared = ops::mul(error.clone(), error, scope)?;
     let mut optimizer = AdadeltaOptimizer::new();
     optimizer.set_learning_rate(ops::constant(1.0f32, scope)?);
-    let mut variables = Vec::new();
-    variables.extend(vars1);
-    variables.extend(vars2);
-    let (minimizer_vars, minimize) = optimizer.minimize(
+    custom_variables.extend(variables_output);
+    let (minimizer_variables, minimize) = optimizer.minimize(
         scope,
         error_squared.clone().into(),
-        MinimizeOptions::default().with_variables(&variables),
+        MinimizeOptions::default().with_variables(&custom_variables),
     )?;
 
-    let mut all_vars = variables.clone();
-    all_vars.extend_from_slice(&minimizer_vars);
+    let mut all_vars = custom_variables.clone();
+    all_vars.extend_from_slice(&minimizer_variables);
     let mut builder = tensorflow::SavedModelBuilder::new();
     builder
         .add_collection("train", &all_vars)
@@ -132,7 +158,7 @@ fn train<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
             );
             def.add_output_info(
                 REGRESS_OUTPUTS.to_string(),
-                TensorInfo::new(DataType::Float, Shape::from(None), layer2.name()?),
+                TensorInfo::new(DataType::Float, Shape::from(None), output_layer.name()?),
             );
             def
         });
@@ -142,53 +168,31 @@ fn train<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
     // Initialize the variables.
     // =========================
     let options = SessionOptions::new();
-    let g = scope.graph_mut();
-    let session = Session::new(&options, &g)?;
+    let graph = scope.graph_mut();
+    let session = Session::new(&options, &graph)?;
     let mut run_args = SessionRunArgs::new();
     // Initialize variables we defined.
-    for var in &variables {
-        run_args.add_target(&var.initializer());
+    for var in &custom_variables {
+        run_args.add_target(var.initializer());
     }
     // Initialize variables the optimizer defined.
-    for var in &minimizer_vars {
-        run_args.add_target(&var.initializer());
+    for var in &minimizer_variables {
+        run_args.add_target(var.initializer());
     }
     session.run(&mut run_args)?;
 
-    // ================
-    // Train the model.
-    // ================
-    let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
-    let mut label_tensor = Tensor::<f32>::new(&[1]);
-    // Helper that generates a training example from an integer, trains on that
-    // example, and returns the error.
-    let mut train = |i| -> Result<f32, Box<dyn Error>> {
-        input_tensor[0] = (i & 1) as f32;
-        input_tensor[1] = ((i >> 1) & 1) as f32;
-        label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32;
-        let mut run_args = SessionRunArgs::new();
-        run_args.add_target(&minimize);
-        let error_squared_fetch = run_args.request_fetch(&error_squared, 0);
-        run_args.add_feed(&input, 0, &input_tensor);
-        run_args.add_feed(&label, 0, &label_tensor);
-        session.run(&mut run_args)?;
-        Ok(run_args.fetch::<f32>(error_squared_fetch)?[0])
-    };
     for i in 0..10000 {
-        train(i)?;
+        train(&session, i, &error_squared, &minimize, &input, &label)?;
     }
 
     // ================
     // Save the model.
     // ================
-    saved_model_saver.save(&session, &g, &save_dir)?;
+    saved_model_saver.save(&session, &graph, &save_dir)?;
 
-    // ===================
-    // Evaluate the model.
-    // ===================
     for i in 0..4 {
-        let error = train(i)?;
-        println!("Error: {}", error);
+        let error = train(&session, i, &error_squared, &minimize, &input, &label)?;
+        println!("Error after training: {}", error);
         if error > 0.1 {
             return Err(Box::new(Status::new_set(
                 Code::Internal,
@@ -203,7 +207,7 @@ fn eval<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
     let mut graph = Graph::new();
     let bundle = SavedModelBundle::load(
         &SessionOptions::new(),
-        &["serve", "train"],
+        ["serve", "train"],
        &mut graph,
        save_dir,
     )?;
@@ -215,10 +219,10 @@ fn eval<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
     let output_op = graph.operation_by_name_required(&output_info.name().name)?;
 
     let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
-    for i in 0..4 {
-        input_tensor[0] = (i & 1) as f32;
-        input_tensor[1] = ((i >> 1) & 1) as f32;
-        let expected = ((i & 1) ^ ((i >> 1) & 1)) as f32;
+    for sample_seed in 0..4 {
+        input_tensor[0] = (sample_seed & 1) as f32;
+        input_tensor[1] = ((sample_seed >> 1) & 1) as f32;
+        let expected = ((sample_seed & 1) ^ ((sample_seed >> 1) & 1)) as f32;
         let mut run_args = SessionRunArgs::new();
         run_args.add_feed(&input_op, input_info.name().index, &input_tensor);
         let output_fetch = run_args.request_fetch(&output_op, output_info.name().index);
@@ -237,28 +241,25 @@ fn eval<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
     Ok(())
 }
 
+/// Train a model on the xor[^xor] problem and evaluate it after.
+///
+/// [^xor]: https://en.wikipedia.org/wiki/Perceptron#History
 fn main() -> Result<(), Box<dyn Error>> {
     let mut dir = env::temp_dir();
     dir.push("tf-rust-example-xor-saved-model");
     let mut dir2 = env::temp_dir();
     dir2.push("tf-rust-example-xor-saved-model2");
-    match fs::remove_dir_all(&dir) {
-        Err(e) => {
-            if e.kind() != ErrorKind::NotFound {
-                return Err(Box::new(e));
-            }
+    if let Err(e) = fs::remove_dir_all(&dir) {
+        if e.kind() != ErrorKind::NotFound {
+            return Err(Box::new(e));
         }
-        Ok(_) => (),
     }
-    match fs::remove_dir_all(&dir2) {
-        Err(e) => {
-            if e.kind() != ErrorKind::NotFound {
-                return Err(Box::new(e));
-            }
+    if let Err(e) = fs::remove_dir_all(&dir2) {
+        if e.kind() != ErrorKind::NotFound {
+            return Err(Box::new(e));
        }
-        Ok(_) => (),
     }
-    train(&dir)?;
+    build_and_train_and_save(&dir)?;
     // Ensure that the saved model works even when moved.
     // Users do not need to do this; this is purely for testing purposes.
     fs::rename(&dir, &dir2)?;
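One detail of the refactor worth unpacking: the new free function `train` derives a whole training sample from the two low bits of `sample_seed`. A standalone sketch of just that encoding (plain Rust, no TensorFlow required):

    // The two low bits of `sample_seed` select one row of the XOR truth table.
    fn main() {
        for sample_seed in 0..4usize {
            let a = (sample_seed & 1) as f32; // first input
            let b = ((sample_seed >> 1) & 1) as f32; // second input
            let label = ((sample_seed & 1) ^ ((sample_seed >> 1) & 1)) as f32; // a XOR b
            println!("inputs = ({}, {}), label = {}", a, b, label);
        }
    }

Run over 0..4 this prints the four rows of the XOR truth table, which is exactly the dataset the example trains on.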
6 changes: 1 addition & 5 deletions src/buffer.rs
@@ -1,4 +1,5 @@
 use super::TensorType;
+use crate::tf;
 use libc::size_t;
 use std::alloc;
 use std::borrow::Borrow;
@@ -16,11 +17,6 @@ use std::ops::RangeTo;
 use std::os::raw::c_void as std_c_void;
 use std::process;
 use std::slice;
-#[cfg(feature = "default")]
-use tensorflow_sys as tf;
-#[cfg(feature = "tensorflow_runtime_linking")]
-use tensorflow_sys_runtime as tf;
-
 /// Fixed-length heap-allocated vector.
 /// This is basically a `Box<[T]>`, except that that type can't actually be constructed.
 /// Furthermore, `[T; N]` can't be constructed if N is not a compile-time constant.
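The new `use crate::tf;` suggests the cfg-gated alias now lives once at the crate root instead of being repeated in every module. A sketch of what that crate-root alias would look like, inferred from the lines removed above (the PR's actual placement and visibility may differ):

    // Hypothetical: in src/lib.rs, one alias serves every module.
    #[cfg(feature = "default")]
    pub(crate) use tensorflow_sys as tf;
    #[cfg(feature = "tensorflow_runtime_linking")]
    pub(crate) use tensorflow_sys_runtime as tf;

With this in place, swapping the backing sys crate is a single-point change rather than a per-file cfg dance.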