diff --git a/tests/string_array_test.py b/tests/string_array_test.py index 797f85c63c7b..2762e5c14cfc 100644 --- a/tests/string_array_test.py +++ b/tests/string_array_test.py @@ -62,6 +62,7 @@ def test_single_device_array(self, asarray): jax_string_array = jax.device_put( numpy_string_array, device=cpu_devices[0] ) + self.assertEqual(jax_string_array.dtype, np.dtypes.StringDType()) # type: ignore jax_string_array.block_until_ready() array_read_back = jax.device_get(jax_string_array) @@ -93,8 +94,13 @@ def test_multi_device_array(self, asarray): jax_string_array = jnp.asarray(numpy_string_array, device=sharding) else: jax_string_array = jax.device_put(numpy_string_array, device=sharding) + self.assertEqual(jax_string_array.dtype, np.dtypes.StringDType()) # type: ignore jax_string_array.block_until_ready() + self.assertLen(jax_string_array._arrays, 2) + for a in jax_string_array._arrays: + self.assertEqual(a.dtype, np.dtypes.StringDType()) # type: ignore + array_read_back = jax.device_get(jax_string_array) self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore np.testing.assert_array_equal(array_read_back, numpy_string_array)