diff --git a/export/orbax/export/data_processors/tf_data_processor_test.py b/export/orbax/export/data_processors/tf_data_processor_test.py index 3d274ad04..dd35d9b55 100644 --- a/export/orbax/export/data_processors/tf_data_processor_test.py +++ b/export/orbax/export/data_processors/tf_data_processor_test.py @@ -150,11 +150,14 @@ def test_suppress_x64_output(self): ) def test_convert_to_bfloat16(self): - processor = tf_data_processor.TfDataProcessor( - lambda x: 0.5 + x, name='preprocessor' - ) + v = tf.Variable(0.5, dtype=tf.float32) + + def func(x): + return v + x + + processor = tf_data_processor.TfDataProcessor(func, name='preprocessor') processor.prepare( - (tf.TensorSpec((), tf.float32)), + input_signature=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32)), bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2( bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions( scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL, @@ -163,8 +166,12 @@ def test_convert_to_bfloat16(self): ), ) self.assertEqual( - processor.output_signature[0], - obm.ShloTensorSpec(shape=(), dtype=obm.ShloDType.bf16), + processor.output_signature, + obm.ShloTensorSpec(shape=(2, 3), dtype=obm.ShloDType.bf16), + ) + self.assertLen(processor.concrete_function.variables, 1) + self.assertEqual( + processor.concrete_function.variables[0].dtype, tf.bfloat16 ) def test_bfloat16_convert_error(self):