Skip to content

Commit 40cd3a8

Browse files
committed
do not make newly created arrays immutable
1 parent b997ba3 commit 40cd3a8

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

src/methods.c

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ is_objectable(PyObject *m, PyObject *a) {
247247
Py_RETURN_TRUE;
248248
}
249249

250-
250+
// Convert array to the dtype provided. NOTE: mutable arrays will be returned unless the input array is immutable and no dtype change is needed
251251
PyObject*
252252
astype_array(PyObject* m, PyObject* args) {
253253

@@ -272,12 +272,12 @@ astype_array(PyObject* m, PyObject* args) {
272272

273273
if (PyArray_EquivTypes(PyArray_DESCR(array), dtype)) {
274274
Py_DECREF(dtype);
275+
275276
if (PyArray_ISWRITEABLE(array)) {
276277
PyObject* result = PyArray_NewCopy(array, NPY_ANYORDER);
277278
if (!result) {
278279
return NULL;
279280
}
280-
PyArray_CLEARFLAGS((PyArrayObject *)result, NPY_ARRAY_WRITEABLE);
281281
return result;
282282
}
283283
else { // already immutable
@@ -319,7 +319,6 @@ astype_array(PyObject* m, PyObject* args) {
319319
PyArray_ITER_NEXT(it);
320320
}
321321
Py_DECREF(it);
322-
PyArray_CLEARFLAGS((PyArrayObject *)result, NPY_ARRAY_WRITEABLE);
323322
return result;
324323
}
325324
}
@@ -330,7 +329,6 @@ astype_array(PyObject* m, PyObject* args) {
330329
Py_DECREF(dtype);
331330
return NULL;
332331
}
333-
PyArray_CLEARFLAGS((PyArrayObject *)result, NPY_ARRAY_WRITEABLE);
334332
return result;
335333
}
336334

test/test_astype_array.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ def test_astype_array_a3(self) -> None:
2828

2929
a2 = astype_array(a1, np.int8)
3030
self.assertEqual(a2.dtype, np.dtype(np.int8))
31-
self.assertFalse(a2.flags.writeable)
31+
self.assertTrue(a2.flags.writeable)
3232

3333
def test_astype_array_b1(self) -> None:
3434
a1 = np.array(['2021', '2024'], dtype=np.datetime64)
3535

3636
a2 = astype_array(a1, np.object_)
3737
self.assertEqual(a2.dtype, np.dtype(np.object_))
38-
self.assertFalse(a2.flags.writeable)
38+
self.assertTrue(a2.flags.writeable)
3939
self.assertEqual(list(a2), [np.datetime64('2021'), np.datetime64('2024')])
4040

4141

@@ -44,7 +44,7 @@ def test_astype_array_b2(self) -> None:
4444

4545
a2 = astype_array(a1, np.object_)
4646
self.assertEqual(a2.dtype, np.dtype(np.object_))
47-
self.assertFalse(a2.flags.writeable)
47+
self.assertTrue(a2.flags.writeable)
4848
self.assertEqual(list(a2), [np.datetime64('2021'), np.datetime64('1642')])
4949

5050

@@ -53,7 +53,7 @@ def test_astype_array_b3(self) -> None:
5353

5454
a2 = astype_array(a1, np.object_)
5555
self.assertEqual(a2.dtype, np.dtype(np.object_))
56-
self.assertFalse(a2.flags.writeable)
56+
self.assertTrue(a2.flags.writeable)
5757
self.assertEqual(
5858
list(list(a) for a in a2),
5959
[[np.datetime64('2021'), np.datetime64('2024')], [np.datetime64('1984'), np.datetime64('1642')]])
@@ -64,7 +64,7 @@ def test_astype_array_b4(self) -> None:
6464
a2 = astype_array(a1, np.object_)
6565
self.assertEqual(a2.dtype, np.dtype(np.object_))
6666
self.assertEqual(a2.shape, (2, 3))
67-
self.assertFalse(a2.flags.writeable)
67+
self.assertTrue(a2.flags.writeable)
6868
self.assertEqual(
6969
list(list(a) for a in a2),
7070
[[np.datetime64('2021'), np.datetime64('2024'), np.datetime64('1532')],
@@ -81,7 +81,7 @@ def test_astype_array_d1(self) -> None:
8181

8282
self.assertEqual(a2.dtype, np.dtype(np.float64))
8383
self.assertEqual(a2.shape, (3,))
84-
self.assertFalse(a2.flags.writeable)
84+
self.assertTrue(a2.flags.writeable)
8585

8686

8787
def test_astype_array_d2(self) -> None:
@@ -90,7 +90,7 @@ def test_astype_array_d2(self) -> None:
9090

9191
self.assertEqual(a2.dtype, np.dtype(np.float64))
9292
self.assertEqual(a2.shape, (3,))
93-
self.assertFalse(a2.flags.writeable)
93+
self.assertTrue(a2.flags.writeable)
9494

9595

9696

@@ -100,7 +100,7 @@ def test_astype_array_d3(self) -> None:
100100

101101
self.assertEqual(a2.dtype, np.dtype(np.int64))
102102
self.assertEqual(a2.shape, (3,))
103-
self.assertFalse(a2.flags.writeable)
103+
self.assertTrue(a2.flags.writeable)
104104

105105
self.assertNotEqual(id(a1), id(a2))
106106

@@ -110,7 +110,7 @@ def test_astype_array_e(self) -> None:
110110
a2 = astype_array(a1, np.object_)
111111
self.assertEqual(a2.dtype, np.dtype(np.object_))
112112
self.assertEqual(a2.shape, (2, 3))
113-
self.assertFalse(a2.flags.writeable)
113+
self.assertTrue(a2.flags.writeable)
114114
self.assertEqual(
115115
list(list(a) for a in a2),
116116
[[np.datetime64('2021-01-01T00:00:00.000000000'),

0 commit comments

Comments
 (0)