Skip to content

Commit c5b0de5

Browse files
Convert IFRT's kString dtype to NumPy's variable length string dtype.
Previously, `xla::ifrt::DType::kString` would be translated to `NPY_OBJECT`. NumPy 2.0 has introduced a variable length string type, `NPY_VSTRING`, so convert to that when available. This is consistent with `xla::DtypeToIfRtDType`, which performs the opposite translation. PiperOrigin-RevId: 813472617
1 parent bf86259 commit c5b0de5

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tests/string_array_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def test_single_device_array(self, asarray):
6262
jax_string_array = jax.device_put(
6363
numpy_string_array, device=cpu_devices[0]
6464
)
65+
self.assertEqual(jax_string_array.dtype, np.dtypes.StringDType()) # type: ignore
6566
jax_string_array.block_until_ready()
6667

6768
array_read_back = jax.device_get(jax_string_array)
@@ -93,8 +94,13 @@ def test_multi_device_array(self, asarray):
9394
jax_string_array = jnp.asarray(numpy_string_array, device=sharding)
9495
else:
9596
jax_string_array = jax.device_put(numpy_string_array, device=sharding)
97+
self.assertEqual(jax_string_array.dtype, np.dtypes.StringDType()) # type: ignore
9698
jax_string_array.block_until_ready()
9799

100+
self.assertLen(jax_string_array._arrays, 2)
101+
for a in jax_string_array._arrays:
102+
self.assertEqual(a.dtype, np.dtypes.StringDType()) # type: ignore
103+
98104
array_read_back = jax.device_get(jax_string_array)
99105
self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore
100106
np.testing.assert_array_equal(array_read_back, numpy_string_array)

0 commit comments

Comments
 (0)