Skip to content

Commit c79c210

Browse files
committed
Simplified test case for mul*array with fill_value
1 parent a321daf commit c79c210

File tree

1 file changed

+13
-30
lines changed

1 file changed

+13
-30
lines changed

pandas/tests/frame/test_arithmetic.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,41 +2256,24 @@ def test_mixed_col_index_dtype(string_dtype_no_object):
22562256

22572257

22582258
@pytest.mark.parametrize(
2259-
"data_type,fill_val, axis",
2259+
"dtype,fill_val, axis",
22602260
[(dt, val, axis) for axis in axes for dt, val in dt_params],
22612261
)
2262-
def test_df_mul_array_fill_value(data_type, fill_val, axis):
2262+
def test_df_mul_array_fill_value(dtype, fill_val, axis):
22632263
# GH 61581
2264-
base_data = np.arange(12).reshape(4, 3)
2265-
df = DataFrame(base_data)
2266-
mult_list = [np.nan, 1, 5, np.nan]
2267-
mult_list = mult_list[: df.shape[axis]]
2268-
2269-
if data_type in tm.ALL_INT_NUMPY_DTYPES:
2264+
if dtype == tm.ALL_INT_NUMPY_DTYPES[0]:
22702265
# Numpy int type cannot represent NaN
2271-
mult_np = np.asarray(mult_list)
2272-
mult_list = np.nan_to_num(mult_np, nan=fill_val)
2273-
2274-
mult_data = pd.array(mult_list, dtype=data_type)
2266+
safe_null = fill_val
2267+
else:
2268+
safe_null = np.nan
22752269

2276-
for i in range(df.shape[0]):
2277-
try:
2278-
df.iat[i, i] = np.nan
2279-
df.iat[i + 2, i] = pd.NA
2280-
except IndexError:
2281-
pass
2270+
df = DataFrame([[safe_null, 1, 2], [3, safe_null, 5]], dtype=dtype)
22822271

2283-
if axis == 0:
2284-
mult_mat = np.broadcast_to(mult_data.reshape(-1, 1), df.shape)
2285-
mask = np.isnan(mult_mat)
2286-
else:
2287-
mult_mat = np.broadcast_to(mult_data.reshape(1, -1), df.shape)
2288-
mask = np.isnan(mult_mat)
2289-
mask = df.isna().values & mask
2272+
mult = pd.array([safe_null, 1.0], dtype=dtype)
22902273

2291-
df_result = df.mul(mult_data, axis=axis, fill_value=fill_val)
2292-
df_expected = (df.fillna(fill_val).mul(mult_data.fillna(fill_val), axis=axis)).mask(
2293-
mask, np.nan
2294-
)
2274+
result = df.mul(mult, axis=0, fill_value=fill_val)
2275+
expected = DataFrame(
2276+
[[safe_null * safe_null, fill_val, fill_val * 2], [3.0, fill_val, 5.0]]
2277+
).astype(dtype)
22952278

2296-
tm.assert_frame_equal(df_result, df_expected)
2279+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)