Skip to content

Commit 1fce875

Browse files
committed
Increase the absolute tolerance between gradients in CTCLossDeterministicTest.testForwardAndBackward due to the floating-point arithmetic for gfx11xx and gfx12xx. Implement test_util.gpu_gcn_arch() which returns the information on the current GCN architecture. Implement a unit test for test_util.gpu_gcn_arch(). Modify GetShortDeviceDescription() in gpu_device.cc to return the GCN arch as well.
1 parent 8d26802 commit 1fce875

File tree

4 files changed

+40
-2
lines changed

4 files changed

+40
-2
lines changed

tensorflow/core/common_runtime/gpu/gpu_device.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,7 @@ static string GetShortDeviceDescription(
19701970
#elif TENSORFLOW_USE_ROCM
19711971
return strings::StrCat("device: ", platform_device_id.value(),
19721972
", name: ", desc.name(),
1973+
", gcn arch: ", desc.rocm_compute_capability().gcn_arch_name(),
19731974
", pci bus id: ", desc.pci_bus_id());
19741975
#endif
19751976
}

tensorflow/python/framework/test_util.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,29 @@ def gpu_device_name() -> str:
189189
return compat.as_str(x.name)
190190
return ""
191191

192+
@tf_export("test.gpu_gcn_arch")
def gpu_gcn_arch() -> str:
  """Returns the GCN architecture name of the first GPU, or an empty string.

  Only meaningful on ROCm builds, where the physical device description
  contains a "gcn arch: gfx..." entry (see GetShortDeviceDescription in
  gpu_device.cc). On non-ROCm builds, or when no GPU is present, this
  returns "".

  This method should only be used in tests written with `tf.test.TestCase`,
  for example:

  ```python
  class MyTest(tf.test.TestCase):

    def test_gcn_arch(self):
      if not test.is_built_with_rocm():
        self.skipTest("Test is only applicable for GPUs built with ROCm")
      self.assertNotEqual("", test_util.gpu_gcn_arch())
  ```

  Returns:
    The GCN architecture string (e.g. "gfx90a", "gfx1100") of the first
    local GPU device whose description carries one, else "".
  """
  for device in device_lib.list_local_devices():
    if device.device_type != "GPU":
      continue
    desc = getattr(device, "physical_device_desc", "")
    # GCN arch names may carry an alphanumeric suffix (e.g. "gfx90a",
    # "gfx1100"); match letters and digits after the "gfx" prefix rather
    # than digits alone, which would truncate "gfx90a" to "gfx90".
    match = re.search(r"gfx[0-9a-z]+", desc)
    if match:
      return compat.as_str(match.group(0))
  return ""
192215

193216
def assert_ops_in_graph(
194217
expected_ops: dict[str, str], graph: ops.Graph

tensorflow/python/framework/test_util_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from tensorflow.python.framework import tensor
4646
from tensorflow.python.framework import test_ops
4747
from tensorflow.python.framework import test_util
48+
from tensorflow.python.platform import test as tf_test
4849
from tensorflow.python.ops import array_ops
4950
from tensorflow.python.ops import control_flow_assert
5051
from tensorflow.python.ops import lookup_ops
@@ -1094,6 +1095,11 @@ def some_test(self):
10941095
some_test(None)
10951096
self.assertEqual(tested_codepaths, set(["present", "future"]))
10961097

1098+
def test_assert_gcn_arch(self):
  """gpu_gcn_arch() must report a non-empty arch on ROCm builds."""
  # The "gcn arch" entry only appears in the device description on ROCm
  # builds, so the check is meaningless elsewhere.
  if not tf_test.is_built_with_rocm():
    self.skipTest("Test is only applicable for GPUs built with ROCm")

  self.assertNotEqual(test_util.gpu_gcn_arch(), "")
10971103

10981104
class SkipTestTest(test_util.TensorFlowTestCase):
10991105

tensorflow/python/kernel_tests/nn_ops/ctc_loss_op_test.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,8 +1233,16 @@ def testForwardAndBackward(self, sparse_labels, logits_time_major):
12331233
loss_a, loss_b, gradient_a, gradient_b = self.evaluate(
12341234
(loss_a, loss_b, gradient_a, gradient_b))
12351235
self.assertAllEqual(loss_a, loss_b, "Loss mismatch")
1236-
# self.assertAllEqual(gradient_a, gradient_b, "Gradient mismatch")
1237-
self.assertAllClose(gradient_a, gradient_b, atol=5e-05)
1236+
# Determine which gcn architecture is the GPU and set the absolute
1237+
# tolerance based on that information.
1238+
# Needed on gfx11 and gfx12 due to the floating point arithmetic.
1239+
gcn_arch = test_util.gpu_gcn_arch()
1240+
if "gfx11" in gcn_arch or "gfx12" in gcn_arch:
1241+
abs_tolerance = 1e-4
1242+
else:
1243+
abs_tolerance = 5e-5
1244+
1245+
self.assertAllClose(gradient_a, gradient_b, atol=abs_tolerance)
12381246

12391247

12401248
if __name__ == "__main__":

0 commit comments

Comments
 (0)