Skip to content

Commit 80b59a2

Browse files
Fix TF int64 promotion issue. (#21679)
* Fix TF int64 promotion issue. * Update. * Fix skipif. * Refine the comments. * Update.
1 parent 45b1039 commit 80b59a2

File tree

4 files changed

+52
-11
lines changed

4 files changed

+52
-11
lines changed

keras/src/backend/common/dtypes.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,16 +232,10 @@ def _resolve_weak_type(dtype, precision="32"):
232232
return f"float{precision}"
233233

234234

235-
BIT64_TO_BIT16_DTYPE = {
236-
"int32": "int16",
237-
"int64": "int16",
238-
"uint32": "uint16",
239-
"uint64": "uint16",
240-
"float32": "float16",
241-
"float64": "float16",
242-
}
243235
BIT64_TO_BIT32_DTYPE = {
244-
"int64": "int32",
236+
# Since TF variables require int64 to be placed on the GPU, we exclusively
237+
# enable the int64 dtype for TF.
238+
"int64": "int64" if config.backend() == "tensorflow" else "int32",
245239
"uint64": "uint32",
246240
"float64": "float32",
247241
"complex128": "complex64",
@@ -277,8 +271,8 @@ def _lattice_result_type(*args):
277271
if out_weak_type:
278272
out_dtype = _resolve_weak_type(out_dtype, precision=precision)
279273

280-
# Force to be 32-bit dtype when encountering 64-bit dtype.
281-
# TODO(hongyu): Add a config to enable 64-bit dtypes.
274+
# Force to be 32-bit dtype when encountering 64-bit dtype. This is to
275+
# be aligned with JAX's default behavior.
282276
out_dtype = BIT64_TO_BIT32_DTYPE.get(out_dtype, out_dtype)
283277
return out_dtype
284278

keras/src/backend/common/dtypes_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import patch
22

3+
import pytest
34
from absl.testing import parameterized
45

56
from keras.src import backend
@@ -27,6 +28,13 @@ class DtypesTest(test_case.TestCase):
2728
] + [None]
2829
if backend.backend() == "torch":
2930
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint16", "uint32")]
31+
elif backend.backend() == "tensorflow":
32+
# TODO(hongyu): Re-enable uint32 tests once we determine how to handle
33+
# dtypes.result_type(uint32, int*) -> int64 promotion.
34+
# Since TF variables require int64 to be placed on the GPU, we
35+
# exclusively enable the int64 dtype for TF. However, JAX does not
36+
# natively support int64, which prevents us from comparing the dtypes.
37+
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint32",)]
3038
elif backend.backend() == "openvino":
3139
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("complex64",)]
3240

@@ -55,6 +63,29 @@ def test_result_type_with_tensor(self, dtype1, dtype2):
5563
expected = jnp.result_type(x1_jax, x2_jax).name
5664
self.assertEqual(out, expected)
5765

66+
@parameterized.named_parameters(
67+
named_product(
68+
dtype=[
69+
"int8",
70+
"int16",
71+
"int32",
72+
"int64",
73+
"uint8",
74+
"uint16",
75+
"uint32",
76+
]
77+
)
78+
)
79+
@pytest.mark.skipif(
80+
backend.backend() != "tensorflow", reason="TensorFlow only"
81+
)
82+
def test_result_type_with_int64(self, dtype):
83+
# https://github.com/keras-team/keras/issues/21677
84+
x1 = ops.ones((1,), dtype="int64")
85+
x2 = ops.ones((1,), dtype=dtype)
86+
out = backend.result_type(x1.dtype, x2.dtype)
87+
self.assertEqual(out, "int64")
88+
5889
def test_result_type_with_none(self):
5990
import jax.numpy as jnp
6091

keras/src/backend/common/variables_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,14 @@ class VariableOpsDTypeTest(test_case.TestCase):
811811
x for x in ALL_DTYPES if x not in ("uint16", "uint32", "complex64")
812812
]
813813
INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")]
814+
elif backend.backend() == "tensorflow":
815+
# TODO(hongyu): Re-enable uint32 tests once we determine how to handle
816+
# dtypes.result_type(uint32, int*) -> int64 promotion.
817+
# Since TF variables require int64 to be placed on the GPU, we
818+
# exclusively enable the int64 dtype for TF. However, JAX does not
819+
# natively support int64, which prevents us from comparing the dtypes.
820+
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint32",)]
821+
INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint32",)]
814822
elif backend.backend() == "openvino":
815823
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("complex64",)]
816824
NON_COMPLEX_DTYPES = [

keras/src/ops/numpy_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5774,6 +5774,14 @@ class NumpyDtypeTest(testing.TestCase):
57745774
if backend.backend() == "torch":
57755775
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint16", "uint32")]
57765776
INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")]
5777+
elif backend.backend() == "tensorflow":
5778+
# TODO(hongyu): Re-enable uint32 tests once we determine how to handle
5779+
# dtypes.result_type(uint32, int*) -> int64 promotion.
5780+
# Since TF variables require int64 to be placed on the GPU, we
5781+
# exclusively enable the int64 dtype for TF. However, JAX does not
5782+
# natively support int64, which prevents us from comparing the dtypes.
5783+
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint32",)]
5784+
INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint32",)]
57775785

57785786
@parameterized.named_parameters(
57795787
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))

0 commit comments

Comments
 (0)