diff --git a/Cargo.toml b/Cargo.toml index 9fd2f997de..c738570d28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ default = ["tensorflow-sys"] experimental = ["tensorflow-sys/experimental"] tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"] tensorflow_unstable = [] -tensorflow_runtime_linking = ["tensorflow-sys-runtime"] +tensorflow_runtime_linking = ["tensorflow-sys-runtime", "tensorflow-sys/tensorflow_runtime_linking"] eager = ["tensorflow-sys/eager"] # This is for testing purposes; users should not use this. examples_system_alloc = ["tensorflow-sys/examples_system_alloc"] diff --git a/examples/mobilenetv3.rs b/examples/mobilenetv3.rs index 7441216ca1..5b4b4a32b2 100644 --- a/examples/mobilenetv3.rs +++ b/examples/mobilenetv3.rs @@ -47,8 +47,7 @@ fn main() -> Result<(), Box> { // Load the model. let mut graph = Graph::new(); - let bundle = - SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?; + let bundle = SavedModelBundle::load(&SessionOptions::new(), ["serve"], &mut graph, export_dir)?; let session = &bundle.session; // get in/out operations diff --git a/examples/regression.rs b/examples/regression.rs index 5393b7ace0..773303d380 100644 --- a/examples/regression.rs +++ b/examples/regression.rs @@ -1,4 +1,3 @@ -use rand; use std::error::Error; use std::fs::File; use std::io::Read; diff --git a/examples/regression_checkpoint.rs b/examples/regression_checkpoint.rs index 0cbdb91cd4..8316781dd9 100644 --- a/examples/regression_checkpoint.rs +++ b/examples/regression_checkpoint.rs @@ -1,4 +1,3 @@ -use rand; use std::error::Error; use std::fs::File; use std::io::Read; diff --git a/examples/regression_savedmodel.rs b/examples/regression_savedmodel.rs index 22ff245254..7f76ebcc9e 100644 --- a/examples/regression_savedmodel.rs +++ b/examples/regression_savedmodel.rs @@ -1,4 +1,3 @@ -use rand; use std::error::Error; use std::path::Path; use std::result::Result; @@ -44,8 +43,7 @@ fn main() -> Result<(), Box> { // Load the saved model exported by regression_savedmodel.py. let mut graph = Graph::new(); - let bundle = - SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?; + let bundle = SavedModelBundle::load(&SessionOptions::new(), ["serve"], &mut graph, export_dir)?; let session = &bundle.session; // train diff --git a/examples/xor.rs b/examples/xor.rs index 7cf532e9fd..824067c27e 100644 --- a/examples/xor.rs +++ b/examples/xor.rs @@ -11,6 +11,7 @@ use tensorflow::train::Optimizer; use tensorflow::Code; use tensorflow::DataType; use tensorflow::Graph; +use tensorflow::Operation; use tensorflow::Output; use tensorflow::OutputName; use tensorflow::SavedModelBundle; @@ -34,7 +35,7 @@ use tensorflow::REGRESS_OUTPUTS; // function such as tanh. // // Returns variables created and the layer output. -fn layer>( +fn build_layer>( input: O1, input_size: u64, output_size: u64, @@ -70,7 +71,33 @@ fn layer>( )) } -fn train>(save_dir: P) -> Result<(), Box> { +/// Helper that generates a training sample from an integer, trains on that +/// example, and returns the error. +fn train( + session: &Session, + sample_seed: usize, + error: &Operation, + optimize: &Operation, + input: &Operation, + output: &Operation, +) -> Result> { + let mut input_tensor = Tensor::::new(&[1, 2]); + let mut label_tensor = Tensor::::new(&[1]); + input_tensor[0] = (sample_seed & 1) as f32; + input_tensor[1] = ((sample_seed >> 1) & 1) as f32; + label_tensor[0] = ((sample_seed & 1) ^ ((sample_seed >> 1) & 1)) as f32; + let mut run_args = SessionRunArgs::new(); + run_args.add_target(optimize); + let error_squared_fetch = run_args.request_fetch(error, 0); + run_args.add_feed(input, 0, &input_tensor); + run_args.add_feed(output, 0, &label_tensor); + session.run(&mut run_args)?; + Ok(run_args.fetch::(error_squared_fetch)?[0]) +} +/// Train a model designed to solve the xor problem, saving it. +/// +/// It uses one hidden layer. +fn build_and_train_and_save>(save_dir: P) -> Result<(), Box> { // ================ // Build the model. // ================ @@ -87,31 +114,30 @@ fn train>(save_dir: P) -> Result<(), Box> { .dtype(DataType::Float) .shape([1u64]) .build(&mut scope.with_op_name("label"))?; - // Hidden layer. - let (vars1, layer1) = layer( + let mut custom_variables = Vec::new(); + let (variables_hidden, hidden_layer) = build_layer( input.clone(), 2, hidden_size, &|x, scope| Ok(ops::tanh(x, scope)?.into()), scope, )?; - // Output layer. - let (vars2, layer2) = layer(layer1.clone(), hidden_size, 1, &|x, _| Ok(x), scope)?; - let error = ops::sub(layer2.clone(), label.clone(), scope)?; + custom_variables.extend(variables_hidden); + let (variables_output, output_layer) = + build_layer(hidden_layer.clone(), hidden_size, 1, &|x, _| Ok(x), scope)?; + let error = ops::sub(output_layer.clone(), label.clone(), scope)?; let error_squared = ops::mul(error.clone(), error, scope)?; let mut optimizer = AdadeltaOptimizer::new(); optimizer.set_learning_rate(ops::constant(1.0f32, scope)?); - let mut variables = Vec::new(); - variables.extend(vars1); - variables.extend(vars2); - let (minimizer_vars, minimize) = optimizer.minimize( + custom_variables.extend(variables_output); + let (minimizer_variables, minimize) = optimizer.minimize( scope, error_squared.clone().into(), - MinimizeOptions::default().with_variables(&variables), + MinimizeOptions::default().with_variables(&custom_variables), )?; - let mut all_vars = variables.clone(); - all_vars.extend_from_slice(&minimizer_vars); + let mut all_vars = custom_variables.clone(); + all_vars.extend_from_slice(&minimizer_variables); let mut builder = tensorflow::SavedModelBuilder::new(); builder .add_collection("train", &all_vars) @@ -132,7 +158,7 @@ fn train>(save_dir: P) -> Result<(), Box> { ); def.add_output_info( REGRESS_OUTPUTS.to_string(), - TensorInfo::new(DataType::Float, Shape::from(None), layer2.name()?), + TensorInfo::new(DataType::Float, Shape::from(None), output_layer.name()?), ); def }); @@ -142,53 +168,31 @@ fn train>(save_dir: P) -> Result<(), Box> { // Initialize the variables. // ========================= let options = SessionOptions::new(); - let g = scope.graph_mut(); - let session = Session::new(&options, &g)?; + let graph = scope.graph_mut(); + let session = Session::new(&options, &graph)?; let mut run_args = SessionRunArgs::new(); // Initialize variables we defined. - for var in &variables { - run_args.add_target(&var.initializer()); + for var in &custom_variables { + run_args.add_target(var.initializer()); } // Initialize variables the optimizer defined. - for var in &minimizer_vars { - run_args.add_target(&var.initializer()); + for var in &minimizer_variables { + run_args.add_target(var.initializer()); } session.run(&mut run_args)?; - // ================ - // Train the model. - // ================ - let mut input_tensor = Tensor::::new(&[1, 2]); - let mut label_tensor = Tensor::::new(&[1]); - // Helper that generates a training example from an integer, trains on that - // example, and returns the error. - let mut train = |i| -> Result> { - input_tensor[0] = (i & 1) as f32; - input_tensor[1] = ((i >> 1) & 1) as f32; - label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32; - let mut run_args = SessionRunArgs::new(); - run_args.add_target(&minimize); - let error_squared_fetch = run_args.request_fetch(&error_squared, 0); - run_args.add_feed(&input, 0, &input_tensor); - run_args.add_feed(&label, 0, &label_tensor); - session.run(&mut run_args)?; - Ok(run_args.fetch::(error_squared_fetch)?[0]) - }; for i in 0..10000 { - train(i)?; + train(&session, i, &error_squared, &minimize, &input, &label)?; } - // ================ - // Save the model. - // ================ - saved_model_saver.save(&session, &g, &save_dir)?; + saved_model_saver.save(&session, &graph, &save_dir)?; // =================== // Evaluate the model. // =================== for i in 0..4 { - let error = train(i)?; - println!("Error: {}", error); + let error = train(&session, i, &error_squared, &minimize, &input, &label)?; + println!("Error after training: {}", error); if error > 0.1 { return Err(Box::new(Status::new_set( Code::Internal, @@ -203,7 +207,7 @@ fn eval>(save_dir: P) -> Result<(), Box> { let mut graph = Graph::new(); let bundle = SavedModelBundle::load( &SessionOptions::new(), - &["serve", "train"], + ["serve", "train"], &mut graph, save_dir, )?; @@ -215,10 +219,10 @@ fn eval>(save_dir: P) -> Result<(), Box> { let output_op = graph.operation_by_name_required(&output_info.name().name)?; let mut input_tensor = Tensor::::new(&[1, 2]); - for i in 0..4 { - input_tensor[0] = (i & 1) as f32; - input_tensor[1] = ((i >> 1) & 1) as f32; - let expected = ((i & 1) ^ ((i >> 1) & 1)) as f32; + for sample_seed in 0..4 { + input_tensor[0] = (sample_seed & 1) as f32; + input_tensor[1] = ((sample_seed >> 1) & 1) as f32; + let expected = ((sample_seed & 1) ^ ((sample_seed >> 1) & 1)) as f32; let mut run_args = SessionRunArgs::new(); run_args.add_feed(&input_op, input_info.name().index, &input_tensor); let output_fetch = run_args.request_fetch(&output_op, output_info.name().index); @@ -237,28 +241,25 @@ fn eval>(save_dir: P) -> Result<(), Box> { Ok(()) } +/// Train a model on the xor[^xor] problem and evaluate it after. +/// +/// [^xor]: https://en.wikipedia.org/wiki/Perceptron#History fn main() -> Result<(), Box> { let mut dir = env::temp_dir(); dir.push("tf-rust-example-xor-saved-model"); let mut dir2 = env::temp_dir(); dir2.push("tf-rust-example-xor-saved-model2"); - match fs::remove_dir_all(&dir) { - Err(e) => { - if e.kind() != ErrorKind::NotFound { - return Err(Box::new(e)); - } + if let Err(e) = fs::remove_dir_all(&dir) { + if e.kind() != ErrorKind::NotFound { + return Err(Box::new(e)); } - Ok(_) => (), } - match fs::remove_dir_all(&dir2) { - Err(e) => { - if e.kind() != ErrorKind::NotFound { - return Err(Box::new(e)); - } + if let Err(e) = fs::remove_dir_all(&dir2) { + if e.kind() != ErrorKind::NotFound { + return Err(Box::new(e)); } - Ok(_) => (), } - train(&dir)?; + build_and_train_and_save(&dir)?; // Ensure that the saved model works even when moved. // Users do not need to do this; this is purely for testing purposes. fs::rename(&dir, &dir2)?; diff --git a/src/buffer.rs b/src/buffer.rs index 0db1db70f5..77944fbfd9 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,4 +1,5 @@ use super::TensorType; +use crate::tf; use libc::size_t; use std::alloc; use std::borrow::Borrow; @@ -16,11 +17,6 @@ use std::ops::RangeTo; use std::os::raw::c_void as std_c_void; use std::process; use std::slice; -#[cfg(feature = "default")] -use tensorflow_sys as tf; -#[cfg(feature = "tensorflow_runtime_linking")] -use tensorflow_sys_runtime as tf; - /// Fixed-length heap-allocated vector. /// This is basically a `Box<[T]>`, except that that type can't actually be constructed. /// Furthermore, `[T; N]` can't be constructed if N is not a compile-time constant. diff --git a/src/checkpoint.rs b/src/checkpoint.rs index 6d46a795a4..8fe7d14b1d 100644 --- a/src/checkpoint.rs +++ b/src/checkpoint.rs @@ -13,7 +13,9 @@ struct SaveRestoreOps { /// The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring. /// When one wants to save/restore from or into a session, one calls the save/restore methods /// # Example -/// ``` +/// ```no_run,ignore +/// # let _: &str = stringify!{ // workaround to not compile the code +/// # // FIXME make this test compile and run /// let mut scope = Scope::new_root_scope(); /// // add operations to define the graph /// // ... @@ -29,8 +31,8 @@ struct SaveRestoreOps { /// // then we restore in a different session to continue there /// let new_session = Session::new(&SessionOptions::new(), &scope.graph())?; /// checkpoint_maker.restore(&new_session, "data/checkpoint")?; +/// # }; /// ``` -/// #[derive(Debug)] pub struct CheckpointMaker { scope: Scope, @@ -226,10 +228,10 @@ mod tests { dims: &[u64], values: &[f32], ) -> Result { - Ok(Variable::builder() + Variable::builder() .const_initial_value(Tensor::new(dims).with_values(values)?) .data_type(DataType::Float) - .build(&mut scope.with_op_name(name))?) + .build(&mut scope.with_op_name(name)) } fn create_assignment( @@ -272,7 +274,7 @@ mod tests { let mut placeholders: Vec = Vec::new(); let mut no_op_bld = ops::NoOp::new(); for var in scope_data.variables.as_ref() { - let (placeholder, assign_op) = create_assignment(&var, &mut placeholder_scope)?; + let (placeholder, assign_op) = create_assignment(var, &mut placeholder_scope)?; placeholders.push(placeholder); no_op_bld = no_op_bld.add_control_input(assign_op); } @@ -289,31 +291,43 @@ mod tests { assign_data: &AssignData, values: &[&[f32]], ) -> Result<(), Status> { - let mut values_fed: Vec> = - Vec::with_capacity(assign_data.placeholder_ops.len()); + assert_eq!( + assign_data.placeholder_ops.len(), + scope_data.variables.len() + ); + assert_eq!(assign_data.placeholder_ops.len(), values.len()); + + let tensor: Vec> = scope_data + .variables + .iter() + .zip(values.iter()) + .map(|(variable, value)| { + Tensor::new( + variable + .shape() + .0 + .as_ref() + .ok_or(Status::new_set(Code::Internal, "Rank of shape not known")?)? + .iter() + .map(|o| { + o.map(|i| i as u64).ok_or(Status::new_set( + Code::Internal, + "Dimensiom in shape not known", + )?) + }) + .collect::, _>>()? + .as_ref(), + ) + .with_values(value) + }) + .collect::>()?; + let mut session_run = SessionRunArgs::new(); - for i_var in 0..assign_data.placeholder_ops.len() { - let value_fed_as_tensor = Tensor::new( - &scope_data.variables[i_var] - .shape() - .0 - .as_ref() - .ok_or(Status::new_set(Code::Internal, "Shape not present")?)? - .iter() - .map(|o| { - o.map(|i| i as u64) - .ok_or(Status::new_set(Code::Internal, "Shape item not present")?) - }) - .collect::, Status>>()? - .as_ref(), - ) - .with_values(&values[i_var])?; - values_fed.push(value_fed_as_tensor); - } - for i_var in 0..assign_data.placeholder_ops.len() { - session_run.add_feed(&assign_data.placeholder_ops[i_var], 0, &values_fed[i_var]); + for (tensor, placeholder) in tensor.iter().zip(assign_data.placeholder_ops.iter()) { + session_run.add_feed(placeholder, 0, tensor); } session_run.add_target(&assign_data.assign_op); + session.run(&mut session_run)?; Ok(()) } @@ -324,17 +338,16 @@ mod tests { values: &[&[f32]], ) -> Result<(), Status> { let mut session_run = SessionRunArgs::new(); - let mut tokens: Vec = Vec::with_capacity(variables.len()); - for i in 0..variables.len() { - tokens.push(session_run.request_fetch( - &variables[i].output().operation, - variables[i].output().index, - )); - } + let tokens: Vec = variables + .iter() + .map(|variable| { + session_run.request_fetch(&variable.output().operation, variable.output().index) + }) + .collect(); session.run(&mut session_run)?; - for i in 0..variables.len() { - let got_tensor: Tensor = session_run.fetch(tokens[i])?; - assert_eq!(values[i], got_tensor.as_ref()); + for (token, value) in tokens.into_iter().zip(values.iter()) { + let got_tensor: Tensor = session_run.fetch(token)?; + assert_eq!(value, &got_tensor.as_ref()); } Ok(()) } diff --git a/src/eager.rs b/src/eager.rs index 05f7894dd0..74a460f7d3 100644 --- a/src/eager.rs +++ b/src/eager.rs @@ -5,7 +5,7 @@ //! //! This API requires the `eager` feature to be enabled as follows: //! -//! ``` +//! ```toml //! [dependencies] //! tensorflow = { version = "0.18", features = ["eager"] } //! ``` diff --git a/src/eager/op.rs b/src/eager/op.rs index fe3b026901..5510e9bc27 100644 --- a/src/eager/op.rs +++ b/src/eager/op.rs @@ -7,7 +7,7 @@ use libc::c_void; use libc::size_t; use std::ffi::{CStr, CString}; use std::marker::PhantomData; -use std::mem::{self, ManuallyDrop}; +use std::mem::ManuallyDrop; use std::os::raw::c_void as std_c_void; use std::ptr; @@ -16,9 +16,6 @@ use crate::{AnyTensor, Code, DataType, Result, Shape, Status}; use tensorflow_sys as tf; -#[cfg(test)] -mod op_test_util; - #[allow( non_snake_case, clippy::too_many_arguments, @@ -160,18 +157,26 @@ impl<'a> Op<'a> { /// Sets the value of an attribute which holds a list of strings. fn set_attr_string_list>(&mut self, attr_name: &str, values: &[S]) -> Result<()> { let c_attr_name = CString::new(attr_name)?; - let bytes: Vec<&[u8]> = values.iter().map(|x| x.as_ref().as_bytes()).collect(); - let ptrs: Vec<*const c_void> = bytes.iter().map(|x| x.as_ptr() as *const c_void).collect(); - let lens: Vec = bytes.iter().map(|x| x.len() as size_t).collect(); + let (ptrs, lens): (Vec<*const c_void>, Vec) = values + .iter() + .map(|entry| { + ( + entry.as_ref().as_ptr() as *const std_c_void, + entry.as_ref().len() as size_t, + ) + }) + .unzip(); + unsafe { tf::TFE_OpSetAttrStringList( self.inner, c_attr_name.as_ptr(), - ptrs.as_ptr() as *const *const std_c_void, + ptrs.as_ptr(), lens.as_ptr(), ptrs.len() as c_int, ); } + Ok(()) } @@ -365,7 +370,7 @@ impl<'a> Op<'a> { /// For sync execution, if any of the inputs to `op` are not ready, this call /// will block till they become ready and then return when the kernel execution /// is done. - fn execute(self, ctx: &'a Context) -> Result<[TensorHandle; N]> { + fn execute(self, ctx: &'a Context) -> Result<[TensorHandle<'a>; N]> { let status = Status::new(); let mut num_retvals = N as i32; @@ -389,11 +394,9 @@ impl<'a> Op<'a> { // If the 'num_retvals' was updated, we treat that as an error. See comment above. if num_retvals != N as i32 { - for i in 0..num_retvals as usize { - unsafe { - tf::TFE_DeleteTensorHandle(retvals[i]); - } - } + retvals + .iter() + .for_each(|retval| unsafe { tf::TFE_DeleteTensorHandle(*retval) }); let status = Status::new_set_lossy( Code::InvalidArgument, &format!("Expected {} outputs, got {}", N, num_retvals), @@ -401,23 +404,8 @@ impl<'a> Op<'a> { return Err(status); } - let mut handles_uninit: [mem::MaybeUninit; N] = - unsafe { mem::MaybeUninit::uninit().assume_init() }; - - for i in 0..N { - let t = unsafe { TensorHandle::from_tensor_handle(ctx, retvals[i]) }; - handles_uninit[i].write(t); - } - - // Transmute uninitialized handles to initialized handles. Ideally, we would use - // `mem::transmute` here, but it is not stable yet for generic sized arrays. - // ref : https://github.com/rust-lang/rust/issues/61956 - // - // Following is a workaround for this issue: - // Using &mut as an assertion of unique "ownership" - let ptr = &mut handles_uninit as *mut _ as *mut [TensorHandle; N]; - let handles: [TensorHandle; N] = unsafe { ptr.read() }; - mem::forget(handles_uninit); + let handles: [TensorHandle; N] = + std::array::from_fn(|i| unsafe { TensorHandle::from_tensor_handle(ctx, retvals[i]) }); Ok(handles) } @@ -426,9 +414,8 @@ impl<'a> Op<'a> { #[cfg(test)] mod tests { use super::*; - use crate::eager::{Context, ContextOptions, TensorHandle}; + use crate::eager::{Context, ContextOptions, TensorHandle, ToTensorHandle}; use crate::Tensor; - use op_test_util::add as add_ut; use raw_ops::{add, concat_v2}; #[cfg(feature = "ndarray")] @@ -491,6 +478,26 @@ mod tests { let h_y = h_x.copy_sharing_tensor().unwrap(); let expected = Tensor::new(&[2, 2]).with_values(&[2i32, 4, 6, 8]).unwrap(); + fn add_ut<'a, T0, T1>( + ctx: &'a crate::eager::Context, + x: &T0, + y: &T1, + ) -> Result> + where + T0: ToTensorHandle<'a>, + T1: ToTensorHandle<'a>, + { + let op_name = "Add"; + let mut op = Op::new(ctx, op_name)?; + + // Required input arguments + op.add_input(&x.to_handle(ctx)?)?; + op.add_input(&y.to_handle(ctx)?)?; + + let [h] = op.execute::<1>(ctx)?; + Ok(h) + } + // tensor and tensor let h_z = add_ut(&ctx, &x, &x).unwrap(); let z = h_z.resolve::().unwrap(); diff --git a/src/eager/op/op_test_util.rs b/src/eager/op/op_test_util.rs index 3f83ab4683..e69de29bb2 100644 --- a/src/eager/op/op_test_util.rs +++ b/src/eager/op/op_test_util.rs @@ -1,86 +0,0 @@ -#![allow(non_snake_case)] -/// Code for Op's ut that mimics raw_opw. -use crate::eager::{TensorHandle, ToTensorHandle}; -use crate::Result; - -use super::Op; - -/// Add -#[derive(::std::fmt::Debug)] -pub struct Add { - T: ::std::option::Option, - device_name: ::std::option::Option, -} - -impl ::std::default::Default for Add { - fn default() -> Self { - Self { - T: None, - device_name: None, - } - } -} - -impl Add { - /// Creates a new `Add`. - pub fn new() -> Self { - Self::default() - } - - /// Sets the `T` attribute. - pub fn T>(mut self, value: ArgType) -> Self { - self.T = ::std::option::Option::Some(value.into()); - self - } - - /// Set the `device_name` where in the Op is executed. - pub fn set_device(&mut self, device_name: &str) { - self.device_name = ::std::option::Option::Some(device_name.to_string()); - } - - /// Execute add. - pub fn call<'a, T0, T1>( - &self, - ctx: &'a crate::eager::Context, - x: &T0, - y: &T1, - ) -> Result> - where - T0: ToTensorHandle<'a>, - T1: ToTensorHandle<'a>, - { - // Define Op - - let op_name = "Add"; - let mut op = Op::new(ctx, op_name)?; - - // Required input arguments - op.add_input(&x.to_handle(ctx)?)?; - op.add_input(&y.to_handle(ctx)?)?; - - // Attributes - if let ::std::option::Option::Some(value) = &self.T { - let attr_name = "T"; - op.set_attr_type(attr_name, *value)?; - } - - // Device - if let ::std::option::Option::Some(device_name) = &self.device_name { - op.set_device(device_name)?; - } - - // Execute Op - let [h] = op.execute::<1>(ctx)?; - Ok(h) - } -} - -/// add with default options. -pub fn add<'a, T0, T1>(ctx: &'a crate::eager::Context, x: &T0, y: &T1) -> Result> -where - T0: ToTensorHandle<'a>, - T1: ToTensorHandle<'a>, -{ - let op = Add::new(); - op.call(ctx, x, y) -} diff --git a/src/eager/readonly_tensor.rs b/src/eager/readonly_tensor.rs index a27cd3893a..215a8985c9 100644 --- a/src/eager/readonly_tensor.rs +++ b/src/eager/readonly_tensor.rs @@ -1,3 +1,4 @@ +use crate::tf; use crate::{ write_tensor_recursive, AnyTensor, DataType, Result, Shape, Tensor, TensorInner, TensorType, }; @@ -5,7 +6,6 @@ use core::fmt; use fmt::{Debug, Formatter}; use libc::c_int; use std::{fmt::Display, ops::Deref}; -use tensorflow_sys as tf; /// A read-only tensor. /// diff --git a/src/eager/tensor_handle.rs b/src/eager/tensor_handle.rs index 3ddce2a142..569ff3e568 100644 --- a/src/eager/tensor_handle.rs +++ b/src/eager/tensor_handle.rs @@ -64,7 +64,7 @@ use crate::{AnyTensor, DataType, Result, Status, TensorType}; #[derive(Debug)] pub struct TensorHandle<'a> { pub(super) inner: *mut tf::TFE_TensorHandle, - // TensorHandle should not live longer than a given context. + // TensorHandle canjot outlive its associated context ctx: PhantomData<&'a Context>, } diff --git a/src/graph.rs b/src/graph.rs index c3943f905a..aa1c1ecfbb 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -7,6 +7,7 @@ use super::Shape; use super::Status; use super::Tensor; use super::TensorType; +use crate::tf; use libc::c_char; use libc::c_float; use libc::c_int; @@ -27,10 +28,6 @@ use std::slice; use std::str::FromStr; use std::str::Utf8Error; use std::sync::Arc; -#[cfg(feature = "default")] -use tensorflow_sys as tf; -#[cfg(feature = "tensorflow_runtime_linking")] -use tensorflow_sys_runtime as tf; #[derive(Debug)] struct GraphImpl { @@ -575,20 +572,6 @@ impl Graph { /// `append_hash_to_fn_name` is false, `fn_name` must be distinct from /// other function and operation names (at least those registered in /// graphs where this function will be used). - /// * `append_hash_to_fn_name` - If true, the actual name of the function - /// will be `fn_name` appended with - /// '_<hash_of_this_function's_definition>'. If false, the - /// function's name will be `fn_name`. - /// * `opers` - Array of operations to become the body of the function or - /// null. - /// * If `None`, all the operations in the graph will become part of the - /// function except operations referenced in `inputs`. These operations - /// must have a single output (these operations are typically - /// placeholders created for the sole purpose of representing an input. - /// We can relax this constraint if there are compelling use cases). - /// * If `Some`, all operations in it will become part of the function. In - /// particular, no automatic skipping of dummy input operations is - /// performed. /// * `inputs` - array of `Output`s that specify the inputs to the function. /// The names used for function inputs are normalized names of the /// operations (usually placeholders) pointed to by `inputs`. These @@ -598,13 +581,6 @@ impl Graph { /// argument names. `inputs` cannot contain the same tensor twice. /// * `outputs` - array of `Output`s that specify the outputs of the /// function. `outputs` can contain the same tensor more than once. - /// * `output_names` - The names of the function's outputs. `output_names` - /// array must either have the same length as `outputs` or be None. In the - /// former case, the names should match the regular expression for ArgDef - /// names - "[a-z][a-z0-9_]*". In the latter case, names for outputs will - /// be generated automatically. - /// * `opts` - various options for the function, e.g. XLA's inlining control. - /// * `description` - optional human-readable description of this function. /// /// Note that when the same `Output` is listed as both an input and an /// output, the corresponding function's output will equal to this input, @@ -631,16 +607,29 @@ impl Graph { /// # Returns /// /// A newly created `Function` instance. - pub fn to_function>( + pub fn to_function<'a>( + &'a self, + fn_name: &'a str, + inputs: &'a [Output], + outputs: &'a [Output], + options: &'a FunctionOptions, + ) -> FunctionBuilder<'a> { + FunctionBuilder::new(self, fn_name, inputs, outputs, options) + } + + fn inner_to_function( &self, - fn_name: &str, - append_hash_to_fn_name: bool, - opers: Option<&[&Operation]>, - inputs: &[Output], - outputs: &[Output], - output_names: Option<&[S]>, - opts: &FunctionOptions, - description: Option<&str>, + FunctionBuilder { + graph: _, + fn_name, + append_hash_to_fn_name, + opers, + inputs, + outputs, + output_names, + opts, + description, + }: FunctionBuilder, ) -> Result { let fn_name_cstr = CString::new(fn_name)?; let num_opers: c_int = if let Some(ops) = &opers { @@ -648,39 +637,41 @@ impl Graph { } else { -1 }; + #[allow(trivial_casts)] let c_opers: Option> = opers.map(|s| s.iter().map(|op| op.inner as *const _).collect()); - let c_opers_ptr: *const *const tf::TF_Operation = if let Some(ref ops) = &c_opers { - ops.as_ptr() - } else { - ptr::null() - }; + let c_opers_ptr: *const *const tf::TF_Operation = c_opers + .as_ref() + .map(|opers| opers.as_ptr()) + .unwrap_or_else(ptr::null); + let c_inputs: Vec<_> = inputs.iter().map(|x| x.to_c()).collect(); let c_outputs: Vec<_> = outputs.iter().map(|x| x.to_c()).collect(); - let output_names_cstrs: Option<::std::result::Result, NulError>> = - output_names - .map(|slice: &[S]| slice.iter().map(|s: &S| CString::new(s.as_ref())).collect()); - let output_names_cstrs: Option> = match output_names_cstrs { - None => None, - Some(r) => Some(r?), - }; + + let output_names_cstrs: Option> = output_names + .map(|output_names: &[&str]| { + output_names + .iter() + .cloned() + .map(CString::new) + .collect::<::std::result::Result, NulError>>() + }) + .transpose()?; let output_names_ptrs: Option> = output_names_cstrs .as_ref() .map(|slice| slice.iter().map(|s| s.as_ptr()).collect()); - let output_names_ptrs_ptr = match &output_names_ptrs { - None => ptr::null(), - Some(ref v) => v.as_ptr(), - }; - let description_cstr = match description { - None => None, - Some(d) => Some(CString::new(d)?), - }; - let description_ptr: *const c_char = if let Some(ref cstr) = &description_cstr { - cstr.as_ptr() - } else { - ptr::null() - }; + let output_names_ptrs_ptr = output_names_ptrs + .as_ref() + .map(|names| names.as_ptr()) + .unwrap_or_else(ptr::null); + + let description_cstr = description.map(CString::new).transpose()?; + let description_ptr: *const c_char = description_cstr + .as_ref() + .map(|cstr| cstr.as_ptr()) + .unwrap_or_else(ptr::null); + let status = Status::new(); let f = unsafe { tf::TF_GraphToFunction( @@ -699,6 +690,7 @@ impl Graph { status.inner, ) }; + status.into_result()?; Ok(Function { inner: f }) } @@ -2385,6 +2377,95 @@ impl Function { } } +/// Builder pattern to build a function from a graph. +/// +/// See [Graph::to_function]. +#[derive(Debug)] +pub struct FunctionBuilder<'a> { + graph: &'a Graph, + fn_name: &'a str, + append_hash_to_fn_name: bool, + opers: Option<&'a [&'a Operation]>, + inputs: &'a [Output], + outputs: &'a [Output], + output_names: Option<&'a [&'a str]>, + opts: &'a FunctionOptions, + description: Option<&'a str>, +} + +impl<'a> FunctionBuilder<'a> { + fn new( + graph: &'a Graph, + fn_name: &'a str, + inputs: &'a [Output], + outputs: &'a [Output], + options: &'a FunctionOptions, + ) -> Self { + FunctionBuilder { + fn_name, + append_hash_to_fn_name: false, + opers: None, + inputs, + outputs, + output_names: None, + opts: options, + description: None, + graph, + } + } + + /// When set, the final name of the function will be `fn_name` appended with + /// '_<hash_of_this_function's_definition>'. + pub fn append_hash(mut self, do_it: bool) -> Self { + self.append_hash_to_fn_name = do_it; + self + } + + /// Names of the function's outputs. `output_names`. + /// + /// The names should match the regular expression for ArgDef names - "[a-z][a-z0-9_]*". + /// + /// By default names for outputs will be generated automatically. + /// + /// # Panics + /// The number of names must match the number of outputs. + pub fn output_names(mut self, names: &'a [&'a str]) -> Self { + assert_eq!(names.len(), self.outputs.len(), "one name for one output"); + self.output_names = Some(names); + self + } + + /// Operations to become part of the body of the function. + /// + /// In particular, no automatic skipping of dummy input operations will be performed. + /// + /// By default all the operations in the graph will become part of the function + /// except operations referenced in its inputs. These operations must have a + /// single output (these operations are typically placeholder created for the + /// sole purpose of representing an input). + pub fn opers(mut self, opers: &'a [&'a Operation]) -> Self { + self.opers = Some(opers); + self + } + + /// Options for the function. + pub fn options(mut self, opts: &'a FunctionOptions) -> Self { + self.opts = opts; + self + } + + /// Human readable description for the function. + pub fn description(mut self, desc: &'a str) -> Self { + self.description = Some(desc); + self + } + + /// Build the function. + pub fn finalize(self) -> Result { + self.graph.inner_to_function(self) + } +} + //////////////////////// #[cfg(test)] @@ -2483,16 +2564,11 @@ mod tests { let description = "Multiplies by 2"; let opts = FunctionOptions::new(); let f = g - .to_function( - "times_two", - false, - Some(&opers), - &inputs, - &outputs, - Some(&output_names), - &opts, - Some(description), - ) + .to_function("times_two", &inputs, &outputs, &opts) + .opers(&opers) + .output_names(&output_names) + .description(description) + .finalize() .unwrap(); assert_eq!("times_two", f.get_name().unwrap()); let mut g2 = Graph::new(); @@ -2553,8 +2629,8 @@ mod tests { nd.set_attr_bool("use_locking", false).unwrap(); nd.finish().unwrap() }; - assert_eq!(true, op.get_attr_bool("validate_shape").unwrap()); - assert_eq!(false, op.get_attr_bool("use_locking").unwrap()); + assert!(op.get_attr_bool("validate_shape").unwrap()); + assert!(!op.get_attr_bool("use_locking").unwrap()); let op = { let variable_op = { @@ -2596,10 +2672,10 @@ mod tests { .unwrap(); nd.add_input(variable_op.clone()); nd.add_input(variable_op.clone()); - nd.set_attr_float("tolerance", 3.14).unwrap(); + nd.set_attr_float("tolerance", 42.42).unwrap(); nd.finish().unwrap() }; - assert_eq!(3.14, op.get_attr_float("tolerance").unwrap()); + assert_eq!(42.42, op.get_attr_float("tolerance").unwrap()); let op = { let mut nd = g.new_operation("Bucketize", "Bucketize").unwrap(); @@ -2773,14 +2849,14 @@ mod tests { fn graph_get_op_def() { let g = Graph::new(); // We don't want to compare the actual proto because it may change across releases. - assert!(g.get_op_def("Const").unwrap().len() > 0); + assert!(!g.get_op_def("Const").unwrap().is_empty()); } #[test] fn graph_versions() { let g = Graph::new(); // We don't want to compare the actual proto because it may change across releases. - assert!(g.versions().unwrap().len() > 0); + assert!(!g.versions().unwrap().is_empty()); } #[test] diff --git a/src/io.rs b/src/io.rs index bd11cf599b..0b160c8398 100644 --- a/src/io.rs +++ b/src/io.rs @@ -12,7 +12,7 @@ use std::{ }; fn mask_crc(crc: u32) -> u32 { - ((crc >> 15) | (crc << 17)).wrapping_add(0xa282_ead8u32) + crc.rotate_right(15).wrapping_add(0xa282_ead8u32) } const CASTAGNOLI: Crc = Crc::::new(&CRC_32_ISCSI); @@ -372,6 +372,7 @@ mod tests { let f = ::std::fs::OpenOptions::new() .write(true) .create(true) + .truncate(true) .open(actual_filename) .unwrap(); @@ -398,7 +399,7 @@ mod tests { use std::{io::Cursor, rc::Rc}; let mut buf = Vec::new(); let mut rc = Rc::new(&mut buf); - let records = vec!["foo", "barr", "baz"]; + let records = ["foo", "barr", "baz"]; { let mut writer = RecordWriter::new(Rc::get_mut(&mut rc).unwrap()); for rec in records.iter() { @@ -426,7 +427,7 @@ mod tests { Some(len) => assert_eq!(&ary[0..len], records[i].as_bytes()), None => break, }, - Err(e @ _) => { + Err(e) => { panic!("Received an unexpected error: {:?}", e); } } @@ -439,7 +440,7 @@ mod tests { use std::{io::Cursor, rc::Rc}; let mut buf = Vec::new(); let mut rc = Rc::new(&mut buf); - let records = vec!["foo", "barr", "baz"]; + let records = ["foo", "barr", "baz"]; { let mut writer = RecordWriter::new(Rc::get_mut(&mut rc).unwrap()); for rec in records.iter() { diff --git a/src/lib.rs b/src/lib.rs index de8ba4e0e5..1fc62389f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ use std::process; use std::ptr; use std::slice; use std::str::Utf8Error; -#[cfg(feature = "default")] +#[cfg(not(feature = "tensorflow_runtime_linking"))] use tensorflow_sys as tf; #[cfg(feature = "tensorflow_runtime_linking")] use tensorflow_sys_runtime as tf; @@ -203,8 +203,6 @@ pub use saved_model::*; mod checkpoint; pub use checkpoint::*; -mod option_insert_result; - #[cfg(feature = "eager")] pub mod eager; @@ -1651,7 +1649,7 @@ where fn from(value: Tensor) -> Self { let dims: Vec = value.dims.iter().map(|x| *x as usize).collect(); let dim = Dim(dims); - let data: Vec = value.iter().map(|x| x.clone()).collect(); + let data: Vec = value.iter().cloned().collect(); // We can safely unwrap this because we know that `data` will have the // correct number of elements to conform to `dim`. Array::from_shape_vec(dim, data).unwrap() @@ -2369,7 +2367,7 @@ mod tests { fn test_set_config() { let mut options = SessionOptions::new(); // An empty array is a valid proto, since all fields are optional. - options.set_config(&vec![]).unwrap(); + options.set_config(&[]).unwrap(); } #[test] @@ -2407,19 +2405,18 @@ mod tests { #[test] fn test_bfloat16() { - let data = [-1.0f32, 0.0, 1.0, 2.5]; - for i in 0..data.len() { - let x = data[i]; - let bfx = BFloat16::from(x); - assert_eq!(>::into(bfx), x); - assert_eq!(bfx.partial_cmp(&bfx), Some(Ordering::Equal)); - assert!(bfx.eq(&bfx)); - for j in 0..i { - let y = data[j]; - let bfy = BFloat16::from(y); - assert_eq!(bfx.partial_cmp(&bfy), Some(Ordering::Greater)); - assert_eq!(bfy.partial_cmp(&bfx), Some(Ordering::Less)); - assert!(!bfx.eq(&bfy)); + let sorted = [-1.0f32, 0.0, 1.0, 2.5]; + for greater_index in 0..sorted.len() { + let greater = sorted[greater_index]; + let bf_greater = BFloat16::from(greater); + assert_eq!(>::into(bf_greater), greater); + assert_eq!(bf_greater.partial_cmp(&bf_greater), Some(Ordering::Equal)); + assert!(bf_greater.eq(&bf_greater)); + for lesser in sorted[..greater_index].iter() { + let bf_lesser = BFloat16::from(*lesser); + assert_eq!(bf_greater.partial_cmp(&bf_lesser), Some(Ordering::Greater)); + assert_eq!(bf_lesser.partial_cmp(&bf_greater), Some(Ordering::Less)); + assert!(!bf_greater.eq(&bf_lesser)); } } assert_eq!(>::into(BFloat16::default()), 0.0f32); @@ -2430,7 +2427,7 @@ mod tests { fn test_f16() { let data: Vec = vec![-1.0f32, 0.0, 1.0, 2.5] .into_iter() - .map(|x| f16::from_f32(x)) + .map(f16::from_f32) .collect(); let tensor = >::new(&[2, 2]).with_values(&data).unwrap(); assert_eq!(&tensor[..], &data[..]); @@ -2694,12 +2691,12 @@ mod tests { #[test] fn test_get_all_registered_kernels() { - assert!(get_all_registered_kernels().unwrap().len() > 0); + assert!(!get_all_registered_kernels().unwrap().is_empty()); } #[test] fn test_get_registered_kernels_for_op() { - assert!(get_registered_kernels_for_op("Add").unwrap().len() > 0); + assert!(!get_registered_kernels_for_op("Add").unwrap().is_empty()); } #[cfg(target_os = "linux")] diff --git a/src/option_insert_result.rs b/src/option_insert_result.rs deleted file mode 100644 index 1d41ba50fe..0000000000 --- a/src/option_insert_result.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Similar to Option.get_or_insert_with, for a function that returns a result. -pub trait OptionInsertWithResult { - fn get_or_insert_with_result(&mut self, f: F) -> Result<&mut T, E> - where - F: FnOnce() -> Result; -} - -impl OptionInsertWithResult for Option { - fn get_or_insert_with_result(&mut self, f: F) -> Result<&mut T, E> - where - F: FnOnce() -> Result, - { - if self.is_none() { - *self = Some(f()?); - } - Ok(self.as_mut().unwrap()) - } -} diff --git a/src/protos/allocation_description.rs b/src/protos/allocation_description.rs index 8af606bce0..3a30c7f37f 100644 --- a/src/protos/allocation_description.rs +++ b/src/protos/allocation_description.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/attr_value.rs b/src/protos/attr_value.rs index 7fd8052a47..a8900fca1d 100644 --- a/src/protos/attr_value.rs +++ b/src/protos/attr_value.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/cluster.rs b/src/protos/cluster.rs index 394499ab43..ae7ff736a6 100644 --- a/src/protos/cluster.rs +++ b/src/protos/cluster.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/config.rs b/src/protos/config.rs index ab0688998f..afe9b0fb57 100644 --- a/src/protos/config.rs +++ b/src/protos/config.rs @@ -8,8 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] -#![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] diff --git a/src/protos/coordination_config.rs b/src/protos/coordination_config.rs index 753dfe5217..05fc0890d2 100644 --- a/src/protos/coordination_config.rs +++ b/src/protos/coordination_config.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/cost_graph.rs b/src/protos/cost_graph.rs index e30fadd530..7e26accbf6 100644 --- a/src/protos/cost_graph.rs +++ b/src/protos/cost_graph.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/debug.rs b/src/protos/debug.rs index 6688679e3f..e6192fae2d 100644 --- a/src/protos/debug.rs +++ b/src/protos/debug.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/full_type.rs b/src/protos/full_type.rs index ece0ce955b..29aee1100f 100644 --- a/src/protos/full_type.rs +++ b/src/protos/full_type.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/function.rs b/src/protos/function.rs index 2162bc2cea..48d54594a4 100644 --- a/src/protos/function.rs +++ b/src/protos/function.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/graph.rs b/src/protos/graph.rs index f3b084b494..1b82c9636b 100644 --- a/src/protos/graph.rs +++ b/src/protos/graph.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/graph_debug_info.rs b/src/protos/graph_debug_info.rs index 71b3124d89..220a60bbac 100644 --- a/src/protos/graph_debug_info.rs +++ b/src/protos/graph_debug_info.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/meta_graph.rs b/src/protos/meta_graph.rs index 4a3ede5a12..93ba5ba39b 100644 --- a/src/protos/meta_graph.rs +++ b/src/protos/meta_graph.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/node_def.rs b/src/protos/node_def.rs index fd61372c37..4886b6fd2d 100644 --- a/src/protos/node_def.rs +++ b/src/protos/node_def.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/op_def.rs b/src/protos/op_def.rs index 54dbbe601b..5bbb150963 100644 --- a/src/protos/op_def.rs +++ b/src/protos/op_def.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/resource_handle.rs b/src/protos/resource_handle.rs index c2446c6d61..3793185851 100644 --- a/src/protos/resource_handle.rs +++ b/src/protos/resource_handle.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/rewriter_config.rs b/src/protos/rewriter_config.rs index 25cb6228af..728136a88e 100644 --- a/src/protos/rewriter_config.rs +++ b/src/protos/rewriter_config.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/rpc_options.rs b/src/protos/rpc_options.rs index 9ad0d63c99..23f3cd5004 100644 --- a/src/protos/rpc_options.rs +++ b/src/protos/rpc_options.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/saved_model.rs b/src/protos/saved_model.rs index c0eb5afbc9..8f14aa41af 100644 --- a/src/protos/saved_model.rs +++ b/src/protos/saved_model.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/saved_object_graph.rs b/src/protos/saved_object_graph.rs index e33eeba3bf..b5835c5ed4 100644 --- a/src/protos/saved_object_graph.rs +++ b/src/protos/saved_object_graph.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/saver.rs b/src/protos/saver.rs index 6223992040..e6745265a0 100644 --- a/src/protos/saver.rs +++ b/src/protos/saver.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/step_stats.rs b/src/protos/step_stats.rs index a472ec65eb..782a8c729b 100644 --- a/src/protos/step_stats.rs +++ b/src/protos/step_stats.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/struct_pb.rs b/src/protos/struct_pb.rs index efe7c65cd0..791566caac 100644 --- a/src/protos/struct_pb.rs +++ b/src/protos/struct_pb.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/tensor.rs b/src/protos/tensor.rs index 4e92188fc4..f5741b8cc8 100644 --- a/src/protos/tensor.rs +++ b/src/protos/tensor.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/tensor_description.rs b/src/protos/tensor_description.rs index 15669779d7..605f2164dd 100644 --- a/src/protos/tensor_description.rs +++ b/src/protos/tensor_description.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/tensor_shape.rs b/src/protos/tensor_shape.rs index a493eddc39..d002560a63 100644 --- a/src/protos/tensor_shape.rs +++ b/src/protos/tensor_shape.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/trackable_object_graph.rs b/src/protos/trackable_object_graph.rs index 06e3a23ecd..4339c20301 100644 --- a/src/protos/trackable_object_graph.rs +++ b/src/protos/trackable_object_graph.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/types.rs b/src/protos/types.rs index 19282f1a37..f93d311cfe 100644 --- a/src/protos/types.rs +++ b/src/protos/types.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/variable.rs b/src/protos/variable.rs index 63050ef55b..3e85616e28 100644 --- a/src/protos/variable.rs +++ b/src/protos/variable.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/verifier_config.rs b/src/protos/verifier_config.rs index 9d151e5b22..d6eddce9c6 100644 --- a/src/protos/verifier_config.rs +++ b/src/protos/verifier_config.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/protos/versions.rs b/src/protos/versions.rs index 616b7043cf..9c491b7005 100644 --- a/src/protos/versions.rs +++ b/src/protos/versions.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/src/scope.rs b/src/scope.rs index 03be69338d..c26da833a4 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -144,7 +144,20 @@ impl Scope { /// Return a new scope. Ops created with this scope will have /// `name/child_scope_name` as the prefix. The actual name will be unique /// in the current scope. All other properties are inherited from the current - /// scope. If `child_scope_name` is empty, the `/` is elided. + /// scope. + /// + /// ``` + /// # use tensorflow::Scope; + /// let scope = Scope::new_root_scope(); + /// assert_eq!(scope.get_unique_name_for_op("MyOp"), "MyOp"); + /// + /// let subscope = scope.new_sub_scope("subscope"); + /// assert_eq!(subscope.get_unique_name_for_op("MyOp"), "subscope/MyOp"); + /// + /// // If `child_scope_name` is empty, the previous scope is elided. + /// let empty_subscope = scope.new_sub_scope(""); + /// assert_eq!(empty_subscope.get_unique_name_for_op("MyOp"), "MyOp_1"); + /// ``` pub fn new_sub_scope(&self, name: &str) -> Scope { let self_name: &str = &self.name; let (new_name, copy_names) = match (self_name, name) { @@ -169,8 +182,15 @@ impl Scope { } } - /// Return a new scope. All ops created within the returned scope will have - /// names of the form `scope_name/name[_suffix]` + /// Return a new scope. All ops created will have a modified name. + /// + /// ``` + /// # use tensorflow::Scope; + /// let scope = Scope::new_root_scope(); + /// let new_scope = scope.with_op_name("the_name"); + /// assert_eq!(new_scope.get_unique_name_for_op("MyOp"), "the_name"); + /// assert_eq!(new_scope.get_unique_name_for_op("MyOp"), "the_name_1"); + /// ``` pub fn with_op_name(&self, name: &str) -> Scope { Scope { graph: self.graph.clone(), diff --git a/src/session.rs b/src/session.rs index ed842517fe..830db6ffe9 100644 --- a/src/session.rs +++ b/src/session.rs @@ -641,7 +641,7 @@ mod tests { let mut graph = Graph::new(); let bundle = SavedModelBundle::load( &SessionOptions::new(), - &["train", "serve"], + ["train", "serve"], &mut graph, "test_resources/regression-model", ) diff --git a/src/while_loop.rs b/src/while_loop.rs index 688f66dda1..feafa5a749 100644 --- a/src/while_loop.rs +++ b/src/while_loop.rs @@ -2,6 +2,7 @@ use super::Graph; use super::Output; use super::Result; use super::Status; +use crate::tf; use std::ffi::CString; use std::ffi::NulError; use std::mem; @@ -9,11 +10,6 @@ use std::os::raw::c_int; use std::ptr; use std::result; use std::slice; -#[cfg(feature = "default")] -use tensorflow_sys as tf; -#[cfg(feature = "tensorflow_runtime_linking")] -use tensorflow_sys_runtime as tf; - // This exists purely to ensure TF_AbortWhile gets called properly, even on panic. #[derive(Debug)] struct CWhileParams { diff --git a/tensorflow-op-codegen/src/bin/eager.rs b/tensorflow-op-codegen/src/bin/eager.rs index 9b639e8f98..d70269a817 100644 --- a/tensorflow-op-codegen/src/bin/eager.rs +++ b/tensorflow-op-codegen/src/bin/eager.rs @@ -10,11 +10,11 @@ use std::io::Write; use std::path::Path; use std::result::Result; use tensorflow_op_codegen::parser; -use tensorflow_op_codegen::protos::OpDef; +use tensorflow_op_codegen::protos::op_def::OpDef; use ::protobuf::ProtobufEnum; -use tensorflow_op_codegen::protos::AttrValue_oneof_value; -use tensorflow_op_codegen::protos::OpDef_ArgDef; +use tensorflow_op_codegen::protos::attr_value::AttrValue_oneof_value; +use tensorflow_op_codegen::protos::op_def::OpDef_ArgDef; #[derive(Clone)] struct Attr { @@ -337,7 +337,7 @@ fn define_op( let mut attr_escaper = Escaper::new(keywords); for attr in op.attr.iter() { // skip if the attr is for type annotation - if skip_attrs.contains(&attr.get_name().to_string()) { + if skip_attrs.contains(attr.get_name()) { continue; } @@ -637,14 +637,13 @@ fn main() -> Result<(), Box> { .to_str() .ok_or("Unable to format path for tensorflow folder")?, )?; - let ops = parser::parse(&ops_bytes).map_err(|e| { + let ops = parser::parse(&ops_bytes).inspect_err(|e| { println!("Parse error at {:?}", e.pos); if let Some(p) = &e.pos { let input = String::from_utf8_lossy(&ops_bytes); println!("Previous: {}", &input[0..*p]); println!("Next: {}", &input[*p..]); } - e })?; let keywords: HashSet = [ "abstract", diff --git a/tensorflow-op-codegen/src/main.rs b/tensorflow-op-codegen/src/main.rs index c2b8f11bc1..bf9b802798 100644 --- a/tensorflow-op-codegen/src/main.rs +++ b/tensorflow-op-codegen/src/main.rs @@ -12,7 +12,7 @@ use std::io::Write; use std::path::Path; use std::result::Result; use tensorflow_op_codegen::parser; -use tensorflow_op_codegen::protos::OpDef; +use tensorflow_op_codegen::protos::op_def::OpDef; #[derive(Clone)] struct Attr { diff --git a/tensorflow-op-codegen/src/parser.rs b/tensorflow-op-codegen/src/parser.rs index 7df0701579..d839fa3b9c 100644 --- a/tensorflow-op-codegen/src/parser.rs +++ b/tensorflow-op-codegen/src/parser.rs @@ -2,18 +2,18 @@ // currently implemented directly with nom to parse the text proto, but ideally this would use a // proto library with support for parsing text protos. -use crate::protos::AttrValue; -use crate::protos::AttrValue_ListValue; -use crate::protos::DataType; -use crate::protos::FullTypeDef; -use crate::protos::FullTypeId; -use crate::protos::OpDef; -use crate::protos::OpDef_ArgDef; -use crate::protos::OpDef_AttrDef; -use crate::protos::OpDeprecation; -use crate::protos::TensorProto; -use crate::protos::TensorShapeProto; -use crate::protos::TensorShapeProto_Dim; +use crate::protos::attr_value::AttrValue; +use crate::protos::attr_value::AttrValue_ListValue; +use crate::protos::full_type::FullTypeDef; +use crate::protos::full_type::FullTypeId; +use crate::protos::op_def::OpDef; +use crate::protos::op_def::OpDef_ArgDef; +use crate::protos::op_def::OpDef_AttrDef; +use crate::protos::op_def::OpDeprecation; +use crate::protos::tensor::TensorProto; +use crate::protos::tensor_shape::TensorShapeProto; +use crate::protos::tensor_shape::TensorShapeProto_Dim; +use crate::protos::types::DataType; use nom::branch::alt; use nom::bytes::complete::tag; use nom::character::complete::anychar; diff --git a/tensorflow-op-codegen/src/protos/attr_value.rs b/tensorflow-op-codegen/src/protos/attr_value.rs index 7fd8052a47..a8900fca1d 100644 --- a/tensorflow-op-codegen/src/protos/attr_value.rs +++ b/tensorflow-op-codegen/src/protos/attr_value.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/tensorflow-op-codegen/src/protos/full_type.rs b/tensorflow-op-codegen/src/protos/full_type.rs index ece0ce955b..29aee1100f 100644 --- a/tensorflow-op-codegen/src/protos/full_type.rs +++ b/tensorflow-op-codegen/src/protos/full_type.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/tensorflow-op-codegen/src/protos/mod.rs b/tensorflow-op-codegen/src/protos/mod.rs index 1f86e7f2c0..4b7a005c90 100644 --- a/tensorflow-op-codegen/src/protos/mod.rs +++ b/tensorflow-op-codegen/src/protos/mod.rs @@ -1,20 +1,13 @@ -mod attr_value; -pub use attr_value::*; +pub mod attr_value; -mod full_type; -pub use full_type::*; +pub mod full_type; -mod op_def; -pub use op_def::*; +pub mod op_def; -mod resource_handle; -pub use resource_handle::*; +pub mod resource_handle; -mod tensor; -pub use tensor::*; +pub mod tensor; -mod tensor_shape; -pub use tensor_shape::*; +pub mod tensor_shape; -mod types; -pub use types::*; +pub mod types; diff --git a/tensorflow-op-codegen/src/protos/op_def.rs b/tensorflow-op-codegen/src/protos/op_def.rs index 54dbbe601b..5bbb150963 100644 --- a/tensorflow-op-codegen/src/protos/op_def.rs +++ b/tensorflow-op-codegen/src/protos/op_def.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/tensorflow-op-codegen/src/protos/resource_handle.rs b/tensorflow-op-codegen/src/protos/resource_handle.rs index c2446c6d61..3793185851 100644 --- a/tensorflow-op-codegen/src/protos/resource_handle.rs +++ b/tensorflow-op-codegen/src/protos/resource_handle.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/tensorflow-op-codegen/src/protos/tensor.rs b/tensorflow-op-codegen/src/protos/tensor.rs index 4e92188fc4..f5741b8cc8 100644 --- a/tensorflow-op-codegen/src/protos/tensor.rs +++ b/tensorflow-op-codegen/src/protos/tensor.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/tensorflow-op-codegen/src/protos/tensor_shape.rs b/tensorflow-op-codegen/src/protos/tensor_shape.rs index a493eddc39..d002560a63 100644 --- a/tensorflow-op-codegen/src/protos/tensor_shape.rs +++ b/tensorflow-op-codegen/src/protos/tensor_shape.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/tensorflow-op-codegen/src/protos/types.rs b/tensorflow-op-codegen/src/protos/types.rs index 19282f1a37..f93d311cfe 100644 --- a/tensorflow-op-codegen/src/protos/types.rs +++ b/tensorflow-op-codegen/src/protos/types.rs @@ -8,7 +8,6 @@ #![allow(unused_attributes)] #![cfg_attr(rustfmt, rustfmt::skip)] -#![allow(box_pointers)] #![allow(dead_code)] #![allow(missing_docs)] #![allow(non_camel_case_types)] diff --git a/tensorflow-sys-runtime/src/finder.rs b/tensorflow-sys-runtime/src/finder.rs index cbf6b0d22b..7455f00d16 100644 --- a/tensorflow-sys-runtime/src/finder.rs +++ b/tensorflow-sys-runtime/src/finder.rs @@ -51,32 +51,31 @@ pub fn find(library_name: &str) -> Option { None } -const ENV_TENSORFLOW_DIR: &'static str = "TENSORFLOW_DIR"; +const ENV_TENSORFLOW_DIR: &str = "TENSORFLOW_DIR"; cfg_if! { if #[cfg(any(target_os = "linux"))] { - const ENV_LIBRARY_PATH: &'static str = "LD_LIBRARY_PATH"; + const ENV_LIBRARY_PATH: &str = "LD_LIBRARY_PATH"; } else if #[cfg(target_os = "macos")] { - const ENV_LIBRARY_PATH: &'static str = "DYLD_LIBRARY_PATH"; + const ENV_LIBRARY_PATH: &str = "DYLD_LIBRARY_PATH"; } else if #[cfg(target_os = "windows")] { - const ENV_LIBRARY_PATH: &'static str = "PATH"; + const ENV_LIBRARY_PATH: &str = "PATH"; } else { // This may not work but seems like a sane default for target OS' not listed above. - const ENV_LIBRARY_PATH: &'static str = "LD_LIBRARY_PATH"; + const ENV_LIBRARY_PATH: &str = "LD_LIBRARY_PATH"; } } cfg_if! { if #[cfg(any(target_os = "linux", target_os = "macos"))] { - const DEFAULT_INSTALLATION_DIRECTORIES: &'static [&'static str] = + const DEFAULT_INSTALLATION_DIRECTORIES: &[&str] = &["/usr/local/lib", "/usr/local/lib/libtensorflow"]; } else if #[cfg(target_os = "windows")] { - const DEFAULT_INSTALLATION_DIRECTORIES: &'static [&'static str] = &[ + const DEFAULT_INSTALLATION_DIRECTORIES: &[&str] = &[ "C:\\Program Files (x86)\\Tensorflow", "C:\\Program Files (x86)\\tensorflow", ]; } else { - const DEFAULT_INSTALLATION_DIRECTORIES: &'static [&'static str] = &[]; + const DEFAULT_INSTALLATION_DIRECTORIES: &[&str] = &[]; } } - diff --git a/tensorflow-sys-runtime/src/lib.rs b/tensorflow-sys-runtime/src/lib.rs index c1fe93d7a1..77f8049d53 100644 --- a/tensorflow-sys-runtime/src/lib.rs +++ b/tensorflow-sys-runtime/src/lib.rs @@ -1,6 +1,9 @@ #![allow(non_camel_case_types)] #![allow(non_snake_case)] #![allow(non_upper_case_globals)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::missing_safety_doc)] +#![allow(deref_nullptr)] // FIXME old bindgen code has undefined behaviour in tests, see https://github.com/rust-lang/rust-bindgen/pull/2055 include!("c_api.rs"); include!("types.rs"); include!("finder.rs"); diff --git a/tensorflow-sys-runtime/src/runtime.rs b/tensorflow-sys-runtime/src/runtime.rs index dd0b47b51e..46706a0fa3 100644 --- a/tensorflow-sys-runtime/src/runtime.rs +++ b/tensorflow-sys-runtime/src/runtime.rs @@ -47,7 +47,7 @@ macro_rules! link { #[derive(Default)] pub(crate) struct Functions { $( - pub $name: Option $ret)*>, + pub $name: Option $ret)*>, )+ } @@ -108,7 +108,7 @@ macro_rules! link { "`libtensorflow` function not loaded: `", stringify!($name) )) - }).expect("an `libtensorflow` shared library is not loaded on this thread"); + }).expect("A `libtensorflow` shared library is not loaded on this thread. Did you call `load()`?"); f($($pname), *) } )+ diff --git a/tensorflow-sys/Cargo.toml b/tensorflow-sys/Cargo.toml index dca3c51894..f61e9ad265 100644 --- a/tensorflow-sys/Cargo.toml +++ b/tensorflow-sys/Cargo.toml @@ -32,6 +32,7 @@ pkg-config = "0.3.25" semver = "1.0.13" tar = "0.4.38" zip = "0.6.4" +anyhow = "1.0.100" [features] tensorflow_gpu = [] @@ -40,3 +41,4 @@ experimental = [] # This is for testing purposes; users should not use this. examples_system_alloc = [] private-docs-rs = [] # DO NOT RELY ON THIS +tensorflow_runtime_linking= [] diff --git a/tensorflow-sys/build.rs b/tensorflow-sys/build.rs index 4f97851f9f..f4d77048d0 100644 --- a/tensorflow-sys/build.rs +++ b/tensorflow-sys/build.rs @@ -12,6 +12,7 @@ use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; use std::process::{self, Command}; +use anyhow::Context; use curl::easy::Easy; use flate2::read::GzDecoder; use semver::Version; @@ -39,7 +40,7 @@ macro_rules! log_var(($var:ident) => (log!(concat!(stringify!($var), " = {:?}"), fn main() { // If we are doing runtime linking, just return. - #[cfg(feature = "runtime_linking")] + #[cfg(feature = "tensorflow_runtime_linking")] return; // DO NOT RELY ON THIS @@ -53,10 +54,15 @@ fn main() { return; } - // Note that pkg_config will print cargo:rustc-link-lib and cargo:rustc-link-search as - // appropriate if the library is found. - if pkg_config::probe_library(LIBRARY).is_ok() { - log!("Returning early because {} was already found", LIBRARY); + if let Ok(library) = pkg_config::probe_library(LIBRARY) { + for library_directory in library.link_paths { + println!("cargo:rustc-link-search={}", library_directory.display()); + } + if target_os() != "windows" { + // There is no tensorflow_framework.dll + println!("cargo:rustc-link-lib=dylib={}", FRAMEWORK_LIBRARY); + } + println!("cargo:rustc-link-lib=dylib={}", LIBRARY); return; } @@ -66,12 +72,10 @@ fn main() { }; log_var!(force_src); - let prebuilt_supported = match (&target_arch() as &str, &target_os() as &str) { - ("x86_64", "linux") => true, - ("x86_64", "windows") => true, - ("aarch64", "macos") => true, - _ => false, - }; + let prebuilt_supported = matches!( + (target_arch().as_str(), target_os().as_str()), + ("x86_64", "linux") | ("x86_64", "windows") | ("aarch64", "macos") + ); if !force_src && prebuilt_supported { install_prebuilt(); } else { @@ -211,7 +215,7 @@ fn install_prebuilt() { VERSION, proc_type, os, arch, ext ); log_var!(binary_url); - let short_file_name = binary_url.split('/').last().unwrap(); + let short_file_name = binary_url.split('/').next_back().unwrap(); let mut base_name = short_file_name.to_string(); remove_suffix(&mut base_name, ext); log_var!(base_name); @@ -405,22 +409,32 @@ fn build_from_src() { ); if framework_library_path.exists() { fs::remove_file(&framework_library_path) - .expect(&format!("{:?} should be removable", framework_library_path)); + .with_context(|| format!("{:?} should be removable", framework_library_path)) + .unwrap(); } - fs::copy(&framework_target_bazel_bin, &framework_library_path).expect(&format!( - "{:?} should be copyable to {:?}", - framework_target_bazel_bin, framework_library_path - )); + fs::copy(&framework_target_bazel_bin, &framework_library_path) + .with_context(|| { + format!( + "{:?} should be copyable to {:?}", + framework_target_bazel_bin, framework_library_path + ) + }) + .unwrap(); let target_bazel_bin = source.join("bazel-bin").join(target_path); log!("Copying {:?} to {:?}", target_bazel_bin, library_path); if library_path.exists() { fs::remove_file(&library_path) - .expect(&format!("{:?} should be removable", library_path)); + .with_context(|| format!("{:?} should be removable", library_path)) + .unwrap() } - fs::copy(&target_bazel_bin, &library_path).expect(&format!( - "{:?} should be copyable to {:?}", - target_bazel_bin, library_path - )); + fs::copy(&target_bazel_bin, &library_path) + .with_context(|| { + format!( + "{:?} should be copyable to {:?}", + target_bazel_bin, library_path + ) + }) + .unwrap(); } symlink( framework_library_path.file_name().unwrap(), diff --git a/tensorflow-sys/examples/multiplication.rs b/tensorflow-sys/examples/multiplication.rs index 85aee8b40b..7c1e155b60 100644 --- a/tensorflow-sys/examples/multiplication.rs +++ b/tensorflow-sys/examples/multiplication.rs @@ -52,7 +52,7 @@ fn main() { let name = CString::new("a").unwrap(); let mut data = vec![1f32, 2.0, 3.0]; - let dims = vec![data.len() as i64]; + let dims = [data.len() as i64]; let input_tensor1 = nonnull!(ffi::TF_NewTensor( ffi::TF_FLOAT, dims.as_ptr(), @@ -72,7 +72,7 @@ fn main() { let name = CString::new("b").unwrap(); let mut data = vec![4f32, 5.0, 6.0]; - let dims = vec![data.len() as i64]; + let dims = [data.len() as i64]; let input_tensor2 = nonnull!(ffi::TF_NewTensor( ffi::TF_FLOAT, dims.as_ptr(), diff --git a/tensorflow-sys/src/c_api_experimental.rs b/tensorflow-sys/src/c_api_experimental.rs index 983e285bea..306f1ee3c1 100644 --- a/tensorflow-sys/src/c_api_experimental.rs +++ b/tensorflow-sys/src/c_api_experimental.rs @@ -1,5 +1,7 @@ /* automatically generated by rust-bindgen 0.59.1 */ +use crate::{TF_Library, TF_Status}; + extern "C" { pub fn TF_LoadPluggableDeviceLibrary( library_filename: *const ::std::os::raw::c_char, diff --git a/tensorflow-sys/src/lib.rs b/tensorflow-sys/src/lib.rs index 8b55aa6d11..680bf9b900 100644 --- a/tensorflow-sys/src/lib.rs +++ b/tensorflow-sys/src/lib.rs @@ -2,13 +2,20 @@ #![allow(non_snake_case)] #![allow(non_upper_case_globals)] +#[allow(deref_nullptr)] +// FIXME old bindgen code has undefined behaviour in tests, see https://github.com/rust-lang/rust-bindgen/pull/2055 +mod c_api; +pub use c_api::*; + #[cfg(feature = "eager")] mod eager; #[cfg(feature = "eager")] pub use eager::*; -include!("c_api.rs"); + +#[cfg(feature = "experimental")] +mod c_api_experimental; #[cfg(feature = "experimental")] -include!("c_api_experimental.rs"); +pub use c_api_experimental::*; pub use crate::TF_AttrType::*; pub use crate::TF_Code::*;