From de1d6f0fb3643ce7a4afb952bcf2ccda49f8f5c0 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 23 Feb 2024 20:31:08 -0500 Subject: [PATCH] Spotless fixes. --- pom.xml | 4 +- .../tensorflow/AbstractGradientAdapter.java | 29 ++-- .../src/main/java/org/tensorflow/Graph.java | 7 +- .../org/tensorflow/GraphOperationBuilder.java | 2 +- .../main/java/org/tensorflow/Signature.java | 1 + .../main/java/org/tensorflow/TensorFlow.java | 22 +-- .../internal/types/TBfloat16Mapper.java | 4 +- .../internal/types/TBoolMapper.java | 4 +- .../internal/types/TFloat16Mapper.java | 4 +- .../internal/types/TFloat32Mapper.java | 4 +- .../internal/types/TFloat64Mapper.java | 4 +- .../internal/types/TInt32Mapper.java | 4 +- .../internal/types/TInt64Mapper.java | 4 +- .../internal/types/TStringMapper.java | 4 +- .../internal/types/TUint16Mapper.java | 4 +- .../internal/types/TUint8Mapper.java | 4 +- .../org/tensorflow/op/CustomGradient.java | 8 +- .../java/org/tensorflow/op/NativeScope.java | 5 +- .../org/tensorflow/op/RawCustomGradient.java | 2 +- .../org/tensorflow/op/RawGradientAdapter.java | 8 +- .../tensorflow/op/TypedGradientAdapter.java | 15 ++- .../org/tensorflow/CustomGradientTest.java | 18 +-- .../org/tensorflow/SavedModelBundleTest.java | 2 +- .../java/org/tensorflow/TensorFlowTest.java | 13 +- .../generator/op/ClassGenerator.java | 1 - .../tensorflow/generator/op/OpGenerator.java | 106 +++++++++------ .../tensorflow/generator/op/TypeResolver.java | 5 +- .../src/main/java/module-info.java | 4 +- .../internal/c_api/AbstractTFE_Context.java | 6 +- .../c_api/AbstractTFE_ContextOptions.java | 6 +- .../internal/c_api/AbstractTFE_Op.java | 6 +- .../c_api/AbstractTFE_TensorHandle.java | 4 +- .../internal/c_api/AbstractTF_Buffer.java | 7 +- .../internal/c_api/AbstractTF_Function.java | 4 +- .../internal/c_api/AbstractTF_Graph.java | 6 +- .../AbstractTF_ImportGraphDefOptions.java | 6 +- .../internal/c_api/AbstractTF_Session.java | 4 +- .../c_api/AbstractTF_SessionOptions.java | 6 +- .../internal/c_api/AbstractTF_Status.java | 4 +- .../internal/c_api/AbstractTF_Tensor.java | 4 +- .../internal/c_api/TFJ_RuntimeLibrary.java | 6 +- .../internal/c_api/presets/tensorflow.java | 126 ++++++++++-------- .../internal/c_api/GradientTest.java | 8 +- .../internal/c_api/HelloWorldTest.java | 4 +- .../tensorflow/framework/data/Dataset.java | 6 +- .../framework/metrics/impl/MetricsHelper.java | 2 +- .../framework/optimizers/AdaGrad.java | 6 +- .../framework/optimizers/Nadam.java | 1 + .../framework/optimizers/Optimizer.java | 12 +- .../optimizers/GradientDescentTest.java | 1 - 50 files changed, 299 insertions(+), 228 deletions(-) diff --git a/pom.xml b/pom.xml index 9ba7407c652..f86e92bc69a 100644 --- a/pom.xml +++ b/pom.xml @@ -47,7 +47,7 @@ true true true - 2.38.0 + 2.43.0 @@ -564,7 +564,7 @@ - 1.17.0 + 1.20.0 diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java index 1305359ba38..2119cddaa67 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java @@ -18,15 +18,14 @@ import java.util.ArrayList; import java.util.List; - import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; -import org.tensorflow.internal.c_api.TF_Operation; -import org.tensorflow.internal.c_api.TF_Output; -import org.tensorflow.internal.c_api.TFJ_Scope; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; import org.tensorflow.internal.c_api.TFJ_GraphId; +import org.tensorflow.internal.c_api.TFJ_Scope; +import org.tensorflow.internal.c_api.TF_Operation; +import org.tensorflow.internal.c_api.TF_Output; /** Helper base class for custom gradient adapters INTERNAL USE ONLY */ public abstract class AbstractGradientAdapter extends TFJ_GradFuncAdapter { @@ -35,10 +34,17 @@ protected AbstractGradientAdapter() { super(); } - protected abstract List> apply(Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs); + protected abstract List> apply( + Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs); @Override - public int call(TFJ_GraphId nativeGraphId, TFJ_Scope nativeScope, TF_Operation nativeOperation, TF_Output nativeGradInputs, int nativeGradInputsLength, PointerPointer nativeGradOutputsPtr) { + public int call( + TFJ_GraphId nativeGraphId, + TFJ_Scope nativeScope, + TF_Operation nativeOperation, + TF_Output nativeGradInputs, + int nativeGradInputsLength, + PointerPointer nativeGradOutputsPtr) { try (PointerScope callScope = new PointerScope()) { var graph = Graph.findGraph(nativeGraphId); var operation = new GraphOperation(graph, nativeOperation); @@ -67,7 +73,8 @@ private static List> fromNativeOutputs(Graph g, TF_Output nativeOutput List> outputs = new ArrayList<>(length); for (int i = 0; i < length; ++i) { var nativeOutput = nativeOutputs.position(i); - outputs.add(i, new Output<>(new GraphOperation(g, nativeOutput.oper()), nativeOutput.index())); + outputs.add( + i, new Output<>(new GraphOperation(g, nativeOutput.oper()), nativeOutput.index())); } return outputs; } @@ -79,13 +86,15 @@ private static List> fromNativeOutputs(Graph g, TF_Output nativeOutput * @return pointer to the native array of outputs */ private static TF_Output toNativeOutputs(List> outputs) { - // Use malloc to allocate native outputs, as they will be freed by the native layer and we do not want JavaCPP to deallocate them - var nativeOutputs = new TF_Output(Pointer.malloc((long)outputs.size() * Pointer.sizeof(TF_Output.class))); + // Use malloc to allocate native outputs, as they will be freed by the native layer and we do + // not want JavaCPP to deallocate them + var nativeOutputs = + new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class))); for (int i = 0; i < outputs.size(); ++i) { var output = outputs.get(i).asOutput(); var nativeOutput = nativeOutputs.getPointer(i); - nativeOutput.oper(((GraphOperation)output.op()).getUnsafeNativeHandle()); + nativeOutput.oper(((GraphOperation) output.op()).getUnsafeNativeHandle()); nativeOutput.index(output.index()); } return nativeOutputs; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 1a6d79486e3..488434c56f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -15,6 +15,7 @@ */ package org.tensorflow; +import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_GetGraphId; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteGraph; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FinishWhile; @@ -28,7 +29,6 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_ImportGraphDefOptionsSetPrefix; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewWhile; -import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_GetGraphId; import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayDeque; @@ -50,6 +50,7 @@ import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.SizeTPointer; import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.internal.c_api.TFJ_GraphId; import org.tensorflow.internal.c_api.TF_Buffer; import org.tensorflow.internal.c_api.TF_Function; import org.tensorflow.internal.c_api.TF_Graph; @@ -58,7 +59,6 @@ import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_WhileParams; -import org.tensorflow.internal.c_api.TFJ_GraphId; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Op; import org.tensorflow.op.OpScope; @@ -1319,7 +1319,8 @@ private static SaverDef addVariableSaver(Graph graph) { .build(); } - private static final Map ALL_GRAPHS = Collections.synchronizedMap(new WeakHashMap<>()); + private static final Map ALL_GRAPHS = + Collections.synchronizedMap(new WeakHashMap<>()); /** * Find the graph with the matching ID. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 1103bb008e0..d68232b2598 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -15,6 +15,7 @@ */ package org.tensorflow; +import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_UnmapOperationName; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddControlInput; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddInput; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddInputList; @@ -39,7 +40,6 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrTypeList; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrValueProto; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetDevice; -import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_UnmapOperationName; import java.nio.charset.Charset; import java.util.Arrays; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 39b422585a3..9f524ef2544 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -39,6 +39,7 @@ public static class TensorDescription { /** The name of the tensor's operand in the graph */ public final String name; + /** The data type of the tensor */ public final DataType dataType; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 77424316b44..7eba6d7ce30 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -15,6 +15,8 @@ */ package org.tensorflow; +import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_HasGradient; +import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_RegisterCustomGradient; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteBuffer; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteLibraryHandle; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetAllOpList; @@ -22,17 +24,20 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadLibrary; import static org.tensorflow.internal.c_api.global.tensorflow.TF_RegisterFilesystemPlugin; import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version; -import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_HasGradient; -import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_RegisterCustomGradient; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; import org.bytedeco.javacpp.PointerScope; import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; +import org.tensorflow.internal.c_api.TFJ_RuntimeLibrary; import org.tensorflow.internal.c_api.TF_Buffer; import org.tensorflow.internal.c_api.TF_Library; import org.tensorflow.internal.c_api.TF_Status; -import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; -import org.tensorflow.internal.c_api.TFJ_RuntimeLibrary; import org.tensorflow.op.CustomGradient; import org.tensorflow.op.RawCustomGradient; import org.tensorflow.op.RawOpInputs; @@ -40,12 +45,6 @@ import org.tensorflow.op.annotation.OpMetadata; import org.tensorflow.proto.OpList; -import java.util.Collections; -import java.util.IdentityHashMap; -import java.util.Locale; -import java.util.Set; -import java.util.stream.Collectors; - /** Static utility methods describing the TensorFlow runtime. */ public final class TensorFlow { @@ -199,7 +198,8 @@ static synchronized boolean hasGradient(String opType) { * @return {@code true} if the gradient was registered, {@code false} if there was already a * gradient registered for this op */ - public static synchronized boolean registerCustomGradient(String opType, RawCustomGradient gradient) { + public static synchronized boolean registerCustomGradient( + String opType, RawCustomGradient gradient) { if (isWindowsOs()) { throw new UnsupportedOperationException( "Custom gradient registration is not supported on Windows systems."); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java index 8aad57e54d2..72fe6dfe745 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java @@ -30,8 +30,8 @@ import org.tensorflow.types.TInt64; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_BFLOAT16} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_BFLOAT16} tensors to a n-dimensional data + * space. */ public final class TBfloat16Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java index ab82369c435..becea1bd410 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java @@ -29,8 +29,8 @@ import org.tensorflow.types.TInt64; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_BOOL} tensors to a n-dimensional - * data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_BOOL} tensors to a n-dimensional data + * space. */ public final class TBoolMapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java index f08c0c08348..e49b7df2574 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java @@ -30,8 +30,8 @@ import org.tensorflow.types.TInt64; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_HALF} tensors to a n-dimensional - * data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_HALF} tensors to a n-dimensional data + * space. */ public final class TFloat16Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java index 4d71872bb57..dfd7d05bcea 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java @@ -29,8 +29,8 @@ import org.tensorflow.types.TInt64; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_FLOAT} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_FLOAT} tensors to a n-dimensional data + * space. */ public final class TFloat32Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java index f866e1a4321..e5524348629 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java @@ -29,8 +29,8 @@ import org.tensorflow.types.TInt64; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_DOUBLE} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_DOUBLE} tensors to a n-dimensional data + * space. */ public final class TFloat64Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java index c34bade622c..12802b264b3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java @@ -29,8 +29,8 @@ import org.tensorflow.types.TInt64; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_INT32} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_INT32} tensors to a n-dimensional data + * space. */ public final class TInt32Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java index c2f5cd97f3a..a85cc40faff 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java @@ -28,8 +28,8 @@ import org.tensorflow.types.TInt64; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_INT64} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_INT64} tensors to a n-dimensional data + * space. */ public final class TInt64Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java index f98d203f58b..3406e2165a3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java @@ -37,8 +37,8 @@ import org.tensorflow.types.TString; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_STRING} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_STRING} tensors to a n-dimensional data + * space. */ public final class TStringMapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint16Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint16Mapper.java index 3ff2be5f520..d563302319a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint16Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint16Mapper.java @@ -29,8 +29,8 @@ import org.tensorflow.types.TUint16; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_Uint16} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_Uint16} tensors to a n-dimensional data + * space. */ public final class TUint16Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java index dc6f8a7593b..71c2652a7a3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java @@ -29,8 +29,8 @@ import org.tensorflow.types.TUint8; /** - * Maps memory of {@link org.tensorflow.proto.DataType#DT_UINT8} tensors to a - * n-dimensional data space. + * Maps memory of {@link org.tensorflow.proto.DataType#DT_UINT8} tensors to a n-dimensional data + * space. */ public final class TUint8Mapper extends TensorMapper { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java index 7731837060f..02acce1cb37 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java @@ -16,13 +16,12 @@ */ package org.tensorflow.op; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.TensorFlow; import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter; -import java.util.List; - /** * A custom gradient for ops of type {@link T}. Should be registered using {@link * TensorFlow#registerCustomGradient(Class, CustomGradient)}. @@ -56,7 +55,8 @@ public interface CustomGradient { *

You should not be calling this yourself, use {@link TensorFlow#registerCustomGradient(Class, * CustomGradient)}. */ - static > TFJ_GradFuncAdapter adapter(CustomGradient gradient, Class opClass) { + static > TFJ_GradFuncAdapter adapter( + CustomGradient gradient, Class opClass) { return new TypedGradientAdapter(gradient, opClass); } -} \ No newline at end of file +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NativeScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NativeScope.java index e5804d6ea63..f7685bbad6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NativeScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NativeScope.java @@ -30,8 +30,8 @@ import org.tensorflow.GraphOperation; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; -import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TFJ_Scope; +import org.tensorflow.internal.c_api.TF_Operation; /** A {@link Scope} implementation backed by a native scope. */ public final class NativeScope implements Scope { @@ -87,7 +87,8 @@ public void refreshNames() {} @Override public Scope withControlDependencies(Iterable controls) { - return withControlDependencyOps(StreamSupport.stream(controls.spliterator(), false) + return withControlDependencyOps( + StreamSupport.stream(controls.spliterator(), false) .map(Op::op) .collect(Collectors.toList())); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java index 53b483fbef0..c2d5496de2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java @@ -56,4 +56,4 @@ public interface RawCustomGradient { static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) { return new RawGradientAdapter(gradient); } -} \ No newline at end of file +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java index 85fe9b4ad32..2324fb75f32 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java @@ -17,6 +17,7 @@ */ package org.tensorflow.op; +import java.util.List; import org.tensorflow.AbstractGradientAdapter; import org.tensorflow.Graph; import org.tensorflow.GraphOperation; @@ -24,8 +25,6 @@ import org.tensorflow.Output; import org.tensorflow.internal.c_api.TFJ_Scope; -import java.util.List; - /** A native adapter for {@link RawCustomGradient}. */ final class RawGradientAdapter extends AbstractGradientAdapter { @@ -37,8 +36,9 @@ final class RawGradientAdapter extends AbstractGradientAdapter { } @Override - protected List> apply(Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs) { + protected List> apply( + Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs) { Scope nativeScope = new NativeScope(scope, graph, null).withSubScope(operation.name()); return gradient.call(new Ops(nativeScope), operation, gradInputs); } -} \ No newline at end of file +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java index 6515186fef0..33d71679fae 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java @@ -17,6 +17,9 @@ */ package org.tensorflow.op; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.List; import org.tensorflow.AbstractGradientAdapter; import org.tensorflow.Graph; import org.tensorflow.GraphOperation; @@ -24,10 +27,6 @@ import org.tensorflow.Output; import org.tensorflow.internal.c_api.TFJ_Scope; -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; -import java.util.List; - /** A native adapter for {@link CustomGradient}. */ final class TypedGradientAdapter> extends AbstractGradientAdapter { @@ -44,14 +43,16 @@ final class TypedGradientAdapter> extends AbstractGradi } @Override - protected List> apply(Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs) { + protected List> apply( + Graph graph, TFJ_Scope scope, GraphOperation operation, List> gradInputs) { try { T rawOp = ctor.newInstance(operation); - Scope nativeScope = new NativeScope(scope, graph, null).withSubScope(rawOp.getOutputs().op().name()); + Scope nativeScope = + new NativeScope(scope, graph, null).withSubScope(rawOp.getOutputs().op().name()); return gradient.call(new Ops(nativeScope), rawOp, gradInputs); } catch (InvocationTargetException | InstantiationException | IllegalAccessException e) { throw new RuntimeException("Could not instantiate Op class " + opInputClass, e); } } -} \ No newline at end of file +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java index 003d4ec38b6..81e401dbccc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java @@ -16,6 +16,11 @@ */ package org.tensorflow; +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledOnOs; import org.junit.jupiter.api.condition.EnabledOnOs; @@ -28,12 +33,6 @@ import org.tensorflow.proto.DataType; import org.tensorflow.types.TFloat32; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import static org.junit.jupiter.api.Assertions.*; - // FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see // https://github.com/tensorflow/java/issues/486 public class CustomGradientTest { @@ -107,10 +106,11 @@ public void testCustomGradient() { @Test public void applyGradientOnMultipleNodesOfSameOpType() { try (Graph g = new Graph()) { - assertTrue(TensorFlow.registerCustomGradient( + assertTrue( + TensorFlow.registerCustomGradient( Merge.Inputs.class, - (tf, op, gradInputs) -> gradInputs.stream().map(i -> tf.constant(-10)).collect(Collectors.toList()) - )); + (tf, op, gradInputs) -> + gradInputs.stream().map(i -> tf.constant(-10)).collect(Collectors.toList()))); var tf = Ops.create(g); var initialValue = tf.constant(10); var merge1 = tf.merge(List.of(initialValue, tf.constant(20))); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 410abb15484..4b452984574 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -41,9 +41,9 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.proto.ConfigProto; import org.tensorflow.proto.RunOptions; +import org.tensorflow.proto.SaverDef; import org.tensorflow.proto.SignatureDef; import org.tensorflow.proto.TensorInfo; -import org.tensorflow.proto.SaverDef; import org.tensorflow.types.TFloat32; /** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java index 77452b3b6f7..edf7bcd7190 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java @@ -40,10 +40,15 @@ public void registeredOpList() { @Test public void loadTFTextLibrary() { - String libname = System.mapLibraryName("_sentence_breaking_ops").substring(3); // strips off the lib on macOS & Linux, don't care about Windows. - File customOpLibrary = Paths.get("", "target","tf-text-download","tensorflow_text","python","ops",libname).toFile(); - - // Disable this test if the tf-text library is not available. This may happen on some platforms (e.g. Windows) + String libname = + System.mapLibraryName("_sentence_breaking_ops") + .substring(3); // strips off the lib on macOS & Linux, don't care about Windows. + File customOpLibrary = + Paths.get("", "target", "tf-text-download", "tensorflow_text", "python", "ops", libname) + .toFile(); + + // Disable this test if the tf-text library is not available. This may happen on some platforms + // (e.g. Windows) assumeTrue(customOpLibrary.exists()); OpList opList = TensorFlow.loadLibrary(customOpLibrary.getAbsolutePath()); diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/ClassGenerator.java index 158ede0cbb8..51151992194 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/ClassGenerator.java @@ -45,7 +45,6 @@ import java.util.Set; import java.util.StringJoiner; import javax.lang.model.element.Modifier; - import org.tensorflow.proto.ApiDef; import org.tensorflow.proto.ApiDef.Endpoint; import org.tensorflow.proto.ApiDef.Visibility; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java index e4187d94916..17607b2e937 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java @@ -18,17 +18,6 @@ import com.google.protobuf.TextFormat; import com.squareup.javapoet.JavaFile; import com.squareup.javapoet.TypeSpec; -import org.bytedeco.javacpp.BytePointer; -import org.springframework.core.io.support.PathMatchingResourcePatternResolver; -import org.tensorflow.internal.c_api.TF_ApiDefMap; -import org.tensorflow.internal.c_api.TF_Buffer; -import org.tensorflow.internal.c_api.TF_Status; -import org.tensorflow.internal.c_api.global.tensorflow; -import org.tensorflow.proto.ApiDef; -import org.tensorflow.proto.ApiDefs; -import org.tensorflow.proto.OpDef; -import org.tensorflow.proto.OpList; - import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; @@ -48,6 +37,16 @@ import java.util.Map; import java.util.Scanner; import java.util.stream.Collectors; +import org.bytedeco.javacpp.BytePointer; +import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.tensorflow.internal.c_api.TF_ApiDefMap; +import org.tensorflow.internal.c_api.TF_Buffer; +import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.internal.c_api.global.tensorflow; +import org.tensorflow.proto.ApiDef; +import org.tensorflow.proto.ApiDefs; +import org.tensorflow.proto.OpDef; +import org.tensorflow.proto.OpList; public final class OpGenerator { @@ -68,7 +67,8 @@ public final class OpGenerator { + "=======================================================================*/" + "\n"; - private static final String HELP_TEXT = "Args should be: [--help] [-p ] [-a ] [-o ] [-c] []"; + private static final String HELP_TEXT = + "Args should be: [--help] [-p ] [-a ] [-o ] [-c] []"; private static final String DEFAULT_OP_DEF_FILE = "org/tensorflow/ops.pbtxt"; @@ -148,7 +148,8 @@ public static void main(String[] args) throws IOException, URISyntaxException { var opDefsFile = OpGenerator.class.getClassLoader().getResource(DEFAULT_OP_DEF_FILE); if (opDefsFile == null) { - throw new FileNotFoundException("\"" + DEFAULT_OP_DEF_FILE + "\" cannot be found in native artifact"); + throw new FileNotFoundException( + "\"" + DEFAULT_OP_DEF_FILE + "\" cannot be found in native artifact"); } try (var opDefsInput = opDefsFile.openStream()) { opList = readOpList(opDefsFile.getFile(), opDefsInput); @@ -208,7 +209,8 @@ private static OpList readOpList(String filename, InputStream protoInput) { private final File outputDir; private final boolean createMissingApiDefs; - private OpGenerator(String basePackage, String apiDefsPath, File outputDir, boolean createMissingApiDefs) { + private OpGenerator( + String basePackage, String apiDefsPath, File outputDir, boolean createMissingApiDefs) { this.basePackage = basePackage; this.apiDefsPath = Path.of(apiDefsPath); this.outputDir = outputDir; @@ -221,7 +223,8 @@ private Map buildDefMap(OpList opList) { apiDefMap = tensorflow.TF_NewApiDefMap(TF_Buffer.newBufferFromString(opList), status); status.throwExceptionIfNotOK(); - // Check if there is any missing APIs in the provided path, if so give a chance to the invoker of this generator + // Check if there is any missing APIs in the provided path, if so give a chance to the invoker + // of this generator // to create one before continuing for (OpDef opDef : opList.getOpList()) { var apiDefFile = apiDefsPath.resolve("api_def_" + opDef.getName() + ".pbtxt").toFile(); @@ -241,13 +244,15 @@ private Map buildDefMap(OpList opList) { Map defs = new LinkedHashMap<>(); for (OpDef opDef : opList.getOpList()) { - var apiDef = tensorflow.TF_ApiDefMapGet(apiDefMap, opDef.getName(), opDef.getName().length(), status); + var apiDef = + tensorflow.TF_ApiDefMapGet( + apiDefMap, opDef.getName(), opDef.getName().length(), status); defs.put(opDef, ApiDef.parseFrom(apiDef.copyData())); } return defs; } catch (Exception e) { - throw e instanceof RuntimeException ? (RuntimeException)e : new RuntimeException(e); + throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e); } finally { if (apiDefMap != null) { @@ -258,34 +263,48 @@ private Map buildDefMap(OpList opList) { private void mergeBaseApiDefs(TF_ApiDefMap apiDefMap, TF_Status status) { try { - var resourceResolver = new PathMatchingResourcePatternResolver(OpGenerator.class.getClassLoader()); + var resourceResolver = + new PathMatchingResourcePatternResolver(OpGenerator.class.getClassLoader()); var apiDefs = resourceResolver.getResources("org/tensorflow/base_api/api_def_*.pbtxt"); for (var apiDef : apiDefs) { try (var apiDefInput = apiDef.getInputStream()) { - tensorflow.TF_ApiDefMapPut(apiDefMap, new BytePointer(apiDefInput.readAllBytes()), apiDef.contentLength(), status); + tensorflow.TF_ApiDefMapPut( + apiDefMap, + new BytePointer(apiDefInput.readAllBytes()), + apiDef.contentLength(), + status); status.throwExceptionIfNotOK(); } catch (IOException e) { - throw new RuntimeException("Failed to parse API definition in resource \"" + apiDef.getURI() + "\"", e); + throw new RuntimeException( + "Failed to parse API definition in resource \"" + apiDef.getURI() + "\"", e); } } } catch (IOException e) { - throw new RuntimeException("Failed to browse API definitions in resource folder \"" + apiDefsPath + "\"", e); + throw new RuntimeException( + "Failed to browse API definitions in resource folder \"" + apiDefsPath + "\"", e); } } private void mergeApiDefs(TF_ApiDefMap apiDefMap, TF_Status status) { try { - Files.walk(apiDefsPath).filter(p -> p.toString().endsWith(".pbtxt")).forEach(p -> { - try { - byte[] content = Files.readAllBytes(p); - tensorflow.TF_ApiDefMapPut(apiDefMap, new BytePointer(content), content.length, status); - status.throwExceptionIfNotOK(); - } catch (IOException e) { - throw new RuntimeException("Failed to parse API definition in resource file \"" + p.toString() + "\"", e); - } - }); + Files.walk(apiDefsPath) + .filter(p -> p.toString().endsWith(".pbtxt")) + .forEach( + p -> { + try { + byte[] content = Files.readAllBytes(p); + tensorflow.TF_ApiDefMapPut( + apiDefMap, new BytePointer(content), content.length, status); + status.throwExceptionIfNotOK(); + } catch (IOException e) { + throw new RuntimeException( + "Failed to parse API definition in resource file \"" + p.toString() + "\"", + e); + } + }); } catch (IOException e) { - throw new RuntimeException("Failed to browse API definitions in resource folder \"" + apiDefsPath + "\"", e); + throw new RuntimeException( + "Failed to browse API definitions in resource folder \"" + apiDefsPath + "\"", e); } } @@ -296,23 +315,28 @@ private void createApiDef(OpDef opDef, File apiDefFile) throws IOException { ApiDef.Visibility visibility = null; do { - System.out.print(" Choose visibility of this op [v]isible/[h]idden/[s]kip/[d]efault (default=d): "); + System.out.print( + " Choose visibility of this op [v]isible/[h]idden/[s]kip/[d]efault (default=d): "); var value = USER_PROMPT.nextLine().trim(); if (!value.isEmpty()) { switch (value) { - case "V": case "v": + case "V": + case "v": visibility = ApiDef.Visibility.VISIBLE; apiDef.setVisibility(visibility); break; - case "H": case "h": + case "H": + case "h": visibility = ApiDef.Visibility.HIDDEN; apiDef.setVisibility(visibility); break; - case "S": case "s": + case "S": + case "s": visibility = ApiDef.Visibility.SKIP; apiDef.setVisibility(visibility); break; - case "D": case "d": + case "D": + case "d": visibility = ApiDef.Visibility.DEFAULT_VISIBILITY; break; default: @@ -350,7 +374,10 @@ private void createApiDef(OpDef opDef, File apiDefFile) throws IOException { } catch (Exception e) { // If something goes wrong, erase the file we've just created if (!apiDefFile.delete()) { - System.err.println("Cannot delete invalid API definition file \"" + apiDefFile.getPath() + "\", please clean up manually"); + System.err.println( + "Cannot delete invalid API definition file \"" + + apiDefFile.getPath() + + "\", please clean up manually"); } throw e; } @@ -407,7 +434,7 @@ private void generate(OpList opList) { return new FullOpDef( entry.getKey(), entry.getValue(), - basePackage, + basePackage, basePackage + "." + pack, pack, name, @@ -425,8 +452,7 @@ private void generate(OpList opList) { statefulPairs.forEach( (pair) -> { - pair.buildOpClasses() - .forEach((spec) -> writeToFile(spec, pair.getPackageName())); + pair.buildOpClasses().forEach((spec) -> writeToFile(spec, pair.getPackageName())); }); } } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/TypeResolver.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/TypeResolver.java index 54e7ca25000..92dd7c951c7 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/TypeResolver.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/TypeResolver.java @@ -26,7 +26,6 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; - import org.tensorflow.proto.AttrValue; import org.tensorflow.proto.DataType; import org.tensorflow.proto.OpDef; @@ -41,6 +40,7 @@ final class TypeResolver { static TypeName WILDCARD = WildcardTypeName.subtypeOf(TypeName.OBJECT); static TypeName STRING = TypeName.get(java.lang.String.class); + /** Data types that are real numbers. */ private static final Set realNumberTypes = new HashSet<>(); @@ -66,10 +66,13 @@ final class TypeResolver { /** The op def to get types for. */ private final OpDef op; + /** The processed argument types. */ private final Map argTypes = new HashMap<>(); + /** Known types. Not simply a cache. */ private final Map known = new HashMap<>(); + /** * Attributes that were reached while getting the types of inputs. * diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/module-info.java b/tensorflow-core/tensorflow-core-native/src/main/java/module-info.java index 53f219abd81..2f7628187e3 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/module-info.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/module-info.java @@ -23,7 +23,9 @@ exports org.tensorflow.internal; exports org.tensorflow.internal.c_api; exports org.tensorflow.internal.c_api.global; - exports org.tensorflow.internal.c_api.presets to org.bytedeco.javacpp,tensorflow; + exports org.tensorflow.internal.c_api.presets to + org.bytedeco.javacpp, + tensorflow; exports org.tensorflow.proto; exports org.tensorflow.proto.data; exports org.tensorflow.proto.data.model; diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Context.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Context.java index ff84bfe3b37..c96e38cdac5 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Context.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Context.java @@ -17,12 +17,12 @@ package org.tensorflow.internal.c_api; -import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.annotation.Properties; - import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContext; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContext; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTFE_Context extends Pointer { protected static class DeleteDeallocator extends TFE_Context implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_ContextOptions.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_ContextOptions.java index a0177b27079..d40c379485f 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_ContextOptions.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_ContextOptions.java @@ -17,12 +17,12 @@ package org.tensorflow.internal.c_api; -import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.annotation.Properties; - import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContextOptions; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContextOptions; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTFE_ContextOptions extends Pointer { protected static class DeleteDeallocator extends TFE_ContextOptions diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Op.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Op.java index fc3e1d36800..107dbb0c207 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Op.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_Op.java @@ -17,12 +17,12 @@ package org.tensorflow.internal.c_api; -import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.annotation.Properties; - import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteOp; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewOp; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTFE_Op extends Pointer { protected static class DeleteDeallocator extends TFE_Op implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_TensorHandle.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_TensorHandle.java index 0fa1d086be8..943f67e54e4 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_TensorHandle.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTFE_TensorHandle.java @@ -17,11 +17,11 @@ package org.tensorflow.internal.c_api; +import static org.tensorflow.internal.c_api.global.tensorflow.*; + import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTFE_TensorHandle extends Pointer { protected static class DeleteDeallocator extends TFE_TensorHandle implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java index a3c1402cccc..108549a875b 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Buffer.java @@ -17,15 +17,14 @@ package org.tensorflow.internal.c_api; +import static org.tensorflow.internal.c_api.global.tensorflow.*; + import com.google.protobuf.Message; +import java.nio.ByteBuffer; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; -import java.nio.ByteBuffer; - -import static org.tensorflow.internal.c_api.global.tensorflow.*; - @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Buffer extends Pointer { protected static class DeleteDeallocator extends TF_Buffer implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java index a91614c4a94..622cc5bb356 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java @@ -15,11 +15,11 @@ */ package org.tensorflow.internal.c_api; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction; + import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction; - @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Function extends Pointer { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Graph.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Graph.java index ab68230679c..3cc7624ab71 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Graph.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Graph.java @@ -17,12 +17,12 @@ package org.tensorflow.internal.c_api; -import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.annotation.Properties; - import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteGraph; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Graph extends Pointer { protected static class DeleteDeallocator extends TF_Graph implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_ImportGraphDefOptions.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_ImportGraphDefOptions.java index 0bb0ac8c140..aaefa5d540f 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_ImportGraphDefOptions.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_ImportGraphDefOptions.java @@ -17,12 +17,12 @@ package org.tensorflow.internal.c_api; -import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.annotation.Properties; - import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteImportGraphDefOptions; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewImportGraphDefOptions; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_ImportGraphDefOptions extends Pointer { protected static class DeleteDeallocator extends TF_ImportGraphDefOptions diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java index 027301f3e08..ef20d5a09a8 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java @@ -17,14 +17,14 @@ package org.tensorflow.internal.c_api; +import static org.tensorflow.internal.c_api.global.tensorflow.*; + import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.annotation.Properties; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Session extends Pointer { protected static class DeleteDeallocator extends TF_Session implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_SessionOptions.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_SessionOptions.java index 30e2ea55934..07692027919 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_SessionOptions.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_SessionOptions.java @@ -17,12 +17,12 @@ package org.tensorflow.internal.c_api; -import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.annotation.Properties; - import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSessionOptions; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSessionOptions; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_SessionOptions extends Pointer { protected static class DeleteDeallocator extends TF_SessionOptions diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java index 8bc2d220229..24d6f6b9f35 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Status.java @@ -17,6 +17,8 @@ package org.tensorflow.internal.c_api; +import static org.tensorflow.internal.c_api.global.tensorflow.*; + import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; import org.tensorflow.exceptions.TFFailedPreconditionException; @@ -28,8 +30,6 @@ import org.tensorflow.exceptions.TFUnimplementedException; import org.tensorflow.exceptions.TensorFlowException; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Status extends Pointer { protected static class DeleteDeallocator extends TF_Status implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java index dd3e43f6913..d3e2ddee335 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java @@ -17,11 +17,11 @@ package org.tensorflow.internal.c_api; +import static org.tensorflow.internal.c_api.global.tensorflow.*; + import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Tensor extends Pointer { protected static class DeleteDeallocator extends TF_Tensor implements Pointer.Deallocator { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/TFJ_RuntimeLibrary.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/TFJ_RuntimeLibrary.java index 75df30123f1..f1c0927db02 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/TFJ_RuntimeLibrary.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/TFJ_RuntimeLibrary.java @@ -15,12 +15,11 @@ package org.tensorflow.internal.c_api; -import org.tensorflow.internal.c_api.global.tensorflow; - import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; +import org.tensorflow.internal.c_api.global.tensorflow; /** * Helper class for loading the TensorFlow Java native library. @@ -130,7 +129,8 @@ private static boolean isLoaded() { } private static boolean resourceExists(String baseName) { - return TFJ_RuntimeLibrary.class.getClassLoader().getResource(makeResourceName(baseName)) != null; + return TFJ_RuntimeLibrary.class.getClassLoader().getResource(makeResourceName(baseName)) + != null; } private static String getVersionedLibraryName(String libFilename) { diff --git a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index e4c9b198b86..730286c5282 100644 --- a/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-native/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -16,11 +16,10 @@ */ package org.tensorflow.internal.c_api.presets; +import java.util.List; import org.bytedeco.javacpp.ClassProperties; import org.bytedeco.javacpp.LoadEnabled; import org.bytedeco.javacpp.Loader; -import org.bytedeco.javacpp.annotation.Adapter; -import org.bytedeco.javacpp.annotation.Cast; import org.bytedeco.javacpp.annotation.NoException; import org.bytedeco.javacpp.annotation.Platform; import org.bytedeco.javacpp.annotation.Properties; @@ -28,13 +27,6 @@ import org.bytedeco.javacpp.tools.InfoMap; import org.bytedeco.javacpp.tools.InfoMapper; -import java.lang.annotation.Documented; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; -import java.util.List; - /** * @author Samuel Audet */ @@ -59,7 +51,8 @@ "tensorflow/c/eager/c_api_experimental.h", // Following are C API extensions maintained within TF Java, see src/main/native. - // Binding directly the C++ API with JavaCPP turned out to be too precarious between different releases, + // Binding directly the C++ API with JavaCPP turned out to be too precarious between + // different releases, // so it is simpler to write our own C API only exposing what we need from it. "tfj_graph.h", "tfj_scope.h", @@ -137,58 +130,73 @@ public class tensorflow implements LoadEnabled, InfoMapper { @Override public void map(InfoMap infoMap) { infoMap - .put(new Info("TF_CAPI_EXPORT", "TF_Bool", "TF_GUARDED_BY", "TF_MUST_USE_RESULT") - .cppTypes() - .annotations()) - .put(new Info("TF_Buffer::data") - .javaText("public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) - .put(new Info("TF_Status") - .pointerTypes("TF_Status") - .base("org.tensorflow.internal.c_api.AbstractTF_Status")) - .put(new Info("TF_Buffer") - .pointerTypes("TF_Buffer") - .base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) - .put(new Info("TF_Tensor") - .pointerTypes("TF_Tensor") - .base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) - .put(new Info("TF_Session") - .pointerTypes("TF_Session") - .base("org.tensorflow.internal.c_api.AbstractTF_Session")) - .put(new Info("TF_SessionOptions") - .pointerTypes("TF_SessionOptions") - .base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) - .put(new Info("TF_Graph") - .pointerTypes("TF_Graph") - .base("org.tensorflow.internal.c_api.AbstractTF_Graph") - .purify()) - .put(new Info("TF_Function") - .pointerTypes("TF_Function") - .base("org.tensorflow.internal.c_api.AbstractTF_Function")) - .put(new Info("TF_ImportGraphDefOptions") - .pointerTypes("TF_ImportGraphDefOptions") - .base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) - .put(new Info("TFE_Context") - .pointerTypes("TFE_Context") - .base("org.tensorflow.internal.c_api.AbstractTFE_Context")) - .put(new Info("TFE_ContextOptions") - .pointerTypes("TFE_ContextOptions") - .base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) - .put(new Info("TFE_Op") - .pointerTypes("TFE_Op") - .base("org.tensorflow.internal.c_api.AbstractTFE_Op")) - .put(new Info("TFE_TensorHandle") - .pointerTypes("TFE_TensorHandle") - .base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) - .put(new Info("TF_WhileParams") - .purify()) - .put(new Info("TFE_CustomDeviceTensorHandle::deallocator") - .javaNames("cdt_deallocator") - ); + .put( + new Info("TF_CAPI_EXPORT", "TF_Bool", "TF_GUARDED_BY", "TF_MUST_USE_RESULT") + .cppTypes() + .annotations()) + .put( + new Info("TF_Buffer::data") + .javaText( + "public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) + .put( + new Info("TF_Status") + .pointerTypes("TF_Status") + .base("org.tensorflow.internal.c_api.AbstractTF_Status")) + .put( + new Info("TF_Buffer") + .pointerTypes("TF_Buffer") + .base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) + .put( + new Info("TF_Tensor") + .pointerTypes("TF_Tensor") + .base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) + .put( + new Info("TF_Session") + .pointerTypes("TF_Session") + .base("org.tensorflow.internal.c_api.AbstractTF_Session")) + .put( + new Info("TF_SessionOptions") + .pointerTypes("TF_SessionOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) + .put( + new Info("TF_Graph") + .pointerTypes("TF_Graph") + .base("org.tensorflow.internal.c_api.AbstractTF_Graph") + .purify()) + .put( + new Info("TF_Function") + .pointerTypes("TF_Function") + .base("org.tensorflow.internal.c_api.AbstractTF_Function")) + .put( + new Info("TF_ImportGraphDefOptions") + .pointerTypes("TF_ImportGraphDefOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) + .put( + new Info("TFE_Context") + .pointerTypes("TFE_Context") + .base("org.tensorflow.internal.c_api.AbstractTFE_Context")) + .put( + new Info("TFE_ContextOptions") + .pointerTypes("TFE_ContextOptions") + .base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) + .put( + new Info("TFE_Op") + .pointerTypes("TFE_Op") + .base("org.tensorflow.internal.c_api.AbstractTFE_Op")) + .put( + new Info("TFE_TensorHandle") + .pointerTypes("TFE_TensorHandle") + .base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) + .put(new Info("TF_WhileParams").purify()) + .put(new Info("TFE_CustomDeviceTensorHandle::deallocator").javaNames("cdt_deallocator")); - // TensorFlow is remapping all TSL symbols into its own namespace, so avoid generate bindings that requires linkage + // TensorFlow is remapping all TSL symbols into its own namespace, so avoid generate bindings + // that requires linkage // to TSL symbols directly (at this time 02/12/2024, this is still not possible in Windows, see // https://github.com/tensorflow/tensorflow/issues/62579) - infoMap.put(new Info("TSL_Status", "TSL_PayloadVisitor", "TF_PayloadVisitor", "TF_ForEachPayload").skip()); + infoMap.put( + new Info("TSL_Status", "TSL_PayloadVisitor", "TF_PayloadVisitor", "TF_ForEachPayload") + .skip()); // This C++-API dependent method appears somehow at the bottom of c/eager/c_api.h, skip it infoMap.put(new Info("TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip()); diff --git a/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/GradientTest.java b/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/GradientTest.java index 2b86829db3b..cd0ef6cc1dc 100644 --- a/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/GradientTest.java +++ b/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/GradientTest.java @@ -16,15 +16,15 @@ */ package org.tensorflow.internal.c_api; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.DisabledOnOs; -import org.junit.jupiter.api.condition.OS; - import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_HasGradient; import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_RegisterCustomGradient; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; + // WARNING: Gradient registry in native library is stateful across all tests @DisabledOnOs(OS.WINDOWS) public class GradientTest { diff --git a/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/HelloWorldTest.java b/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/HelloWorldTest.java index dbc48c7cf51..98f1c08f261 100644 --- a/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/HelloWorldTest.java +++ b/tensorflow-core/tensorflow-core-native/src/test/java/org/tensorflow/internal/c_api/HelloWorldTest.java @@ -15,11 +15,11 @@ package org.tensorflow.internal.c_api; -import org.junit.jupiter.api.Test; - import static org.junit.jupiter.api.Assertions.assertTrue; import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version; +import org.junit.jupiter.api.Test; + public class HelloWorldTest { @Test diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java index 026fb0a42ef..6267591c298 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java @@ -296,7 +296,11 @@ public static Dataset fromTensorSlices( public static Dataset tfRecordDataset( Ops tf, String filename, String compressionType, long bufferSize) { return new TFRecordDataset( - tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize), tf.constant(0L)); + tf, + tf.constant(filename), + tf.constant(compressionType), + tf.constant(bufferSize), + tf.constant(0L)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 6fd5424db3f..12f7d0cbb3d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -661,7 +661,7 @@ private static Operand filterTopK(Ops tf, Operand x, i Class type = x.type(); Shape xShape = x.shape(); // top has the same rank as x; the last dimension becomes indices of the topK features. - TopK top = tf.nn.topK(x, tf.constant(topK), new Options[]{TopK.sorted(false)}); + TopK top = tf.nn.topK(x, tf.constant(topK), new Options[] {TopK.sorted(false)}); // oneHot has an additional dimension: the one-hot representation of each topK index. OneHot oneHot = tf.oneHot( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index fc11f60e1f4..5901a28f25f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -144,7 +144,11 @@ private void createAdaGradSlot(Output v) { protected Op applyDense(Ops deps, Output gradient, Output variable) { Variable slot = getSlot(variable, ACCUMULATOR).get(); return deps.train.applyAdagrad( - variable, slot, deps.dtypes.cast(deps.constant(learningRate), gradient.type()), gradient, opts); + variable, + slot, + deps.dtypes.cast(deps.constant(learningRate), gradient.type()), + gradient, + opts); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index 9de52eb371b..680a6bdcfbc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -33,6 +33,7 @@ public class Nadam extends Optimizer { public static final String MOMENTUM = "momentum"; private static final float DECAY_BASE = 0.96f; private static final float DECAY = 0.004f; + /** The learning rate. */ private final float learningRate; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 59129e8c103..4cbee44d226 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -37,11 +37,14 @@ public abstract class Optimizer { public static final String VARIABLE_V2 = "VariableV2"; + /** Global state variables */ // TODO make this be used. protected final List> globals; + /** The Graph this optimizer is operating on. */ protected final Graph graph; + /** The ops builder for the graph. */ protected final Ops tf; @@ -168,7 +171,11 @@ public Op applyGradients(List> gradsAndVars, String gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); createSlots(variables); - List gradients = gradsAndVars.stream().map(GradAndVar::getGradient).filter(g -> !g.isClosed()).collect(Collectors.toList()); + List gradients = + gradsAndVars.stream() + .map(GradAndVar::getGradient) + .filter(g -> !g.isClosed()) + .collect(Collectors.toList()); Ops tfOpsGrads = tf.withControlDependencies(gradients); Optional prepOp = prepare(name + "/prepare"); @@ -275,7 +282,8 @@ private Op applyDense(Ops opDependencies, GradAndVar gradVa * @param The type of the variable. * @return An operand which applies the desired optimizer update to the variable. */ - protected abstract Op applyDense(Ops opDependencies, Output gradient, Output variable); + protected abstract Op applyDense( + Ops opDependencies, Output gradient, Output variable); /** * Gathers up the update operations into a single op that can be used as a run target. diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index d4bc0a7346f..90891ee748b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -9,7 +9,6 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Result;