Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ public class HybridizationParameters {
*/
private boolean reduceRetracingParamExists;

/**
* True iff this {@link Function}'s {@link decoratorsType} has parameter experimental_compile.
*/
private boolean experimentalCompileParamExists;

/**
* True iff this {@link Function}'s {@link decoratorsType} has parameter experminetal_relax_shapes.
*/
private boolean experimentalRelaxShapesParamExists;

public HybridizationParameters(IProgressMonitor monitor) throws BadLocationException {
FunctionDefinition functionDefinition = Function.this.getFunctionDefinition();
decoratorsType[] decoratorArray = functionDefinition.getFunctionDef().decs;
Expand Down Expand Up @@ -139,16 +149,13 @@ else if (name.id.equals(INPUT_SIGNATURE))
else if (name.id.equals(AUTOGRAPH))
// Found parameter autograph
this.autoGraphParamExists = true;
// The version of the API we are using allows
// parameter names jit_compile and
// deprecated name experimental_compile
else if (name.id.equals(JIT_COMPILE) || name.id.equals(EXPERIMENTAL_COMPILE))
else if (name.id.equals(JIT_COMPILE))
// Found parameter jit_compile/experimental_compile
this.jitCompileParamExists = true;
// The version of the API we are using allows
// parameter names reduce_retracing
// and deprecated name experimental_relax_shapes
else if (name.id.equals(REDUCE_RETRACING) || name.id.equals(EXPERIMENTAL_RELAX_SHAPES))
else if (name.id.equals(REDUCE_RETRACING))
// Found parameter reduce_retracing
// or experimental_relax_shapes
this.reduceRetracingParamExists = true;
Expand All @@ -161,6 +168,12 @@ else if (name.id.equals(EXPERIMENTAL_AUTOGRAPH_OPTIONS))
else if (name.id.equals(EXPERIMENTAL_FOLLOW_TYPE_HINTS))
// Found parameter experimental_follow_type_hints
this.experimentaFollowTypeHintsParamExists = true;
else if (name.id.equals(EXPERIMENTAL_COMPILE))
// Found parameter experimental_compile
this.experimentalCompileParamExists = true;
else if (name.id.equals(EXPERIMENTAL_RELAX_SHAPES))
// Found parameter experimental_relax_shapes
this.experimentalRelaxShapesParamExists = true;
}
}
} // else, tf.function is used without parameters.
Expand Down Expand Up @@ -237,6 +250,24 @@ public boolean hasJitCompileParam() {
public boolean hasReduceRetracingParam() {
return this.reduceRetracingParamExists;
}

/**
* True iff this {@link Function}'s {@link decoratorsType} has parameter experimental_compile.
*
* @return True iff this {@link Function} has parameter experimental_compile.
*/
public boolean hasExperimentalCompileParam() {
return this.experimentalCompileParamExists;
}

/**
* True iff this {@link Function}'s {@link decoratorsType} has parameter experimental_relax_shapes.
*
* @return True iff this {@link Function} has parameter experimental_relax_shapes.
*/
public boolean hasExperimentalRelaxShapesParam() {
return this.experimentalRelaxShapesParamExists;
}
}

private static final String TF_FUNCTION_FQN = "tensorflow.python.eager.def_function.function";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import custom
import tensorflow as tf


@custom.decorator(input_signature=None)
def func(x):
print('Tracing with', x)
return x
@tf.function
def test(x):
return x



if __name__ == '__main__':
func(1)
x = tf.constant(1)
test(x)

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import custom
import tensorflow as tf


@custom.decorator(input_signature=None)
@tf.function(autograph=False)
def func():
pass
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),), autograph=False)
def func(x):
print('Tracing with', x)
return x


if __name__ == '__main__':
func()

number = tf.constant([1.0, 1.0])
func(number)
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import tensorflow as tf
import custom


@tf.function(autograph=False)
@tf.function(jit_compile=True)
@custom.decorator(input_signature=None)
def func(x):
print('Tracing with', x)
return x


if __name__ == '__main__':
func(tf.constant(1))

if __name__ == '__main__':
func(1)
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import custom
import tensorflow as tf


@custom.decorator(input_signature=None)
@tf.function(autograph=False)
def func():
pass


if __name__ == '__main__':
func()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import tensorflow as tf


@tf.function(autograph=False)
@tf.function(jit_compile=True)
def func(x):
return x


if __name__ == '__main__':
func(tf.constant(1))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tensorflow==2.9.3
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import tensorflow as tf


@tf.function
def test(x):
return x
@tf.function(experimental_compile=True)
def func():
print("Testing")


if __name__ == '__main__':
x = tf.constant(1)
test(x)

func()
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import tensorflow as tf


@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),), autograph=False)
def func(x):
print('Tracing with', x)
return x
@tf.function(experimental_relax_shapes=True)
def func():
print("Testing")


if __name__ == '__main__':
number = tf.constant([1.0, 1.0])
func(number)

func()
Loading