@@ -2256,41 +2256,24 @@ def test_mixed_col_index_dtype(string_dtype_no_object):
22562256
22572257
22582258@pytest .mark .parametrize (
2259- "data_type ,fill_val, axis" ,
2259+ "dtype ,fill_val, axis" ,
22602260 [(dt , val , axis ) for axis in axes for dt , val in dt_params ],
22612261)
2262- def test_df_mul_array_fill_value (data_type , fill_val , axis ):
2262+ def test_df_mul_array_fill_value (dtype , fill_val , axis ):
22632263 # GH 61581
2264- base_data = np .arange (12 ).reshape (4 , 3 )
2265- df = DataFrame (base_data )
2266- mult_list = [np .nan , 1 , 5 , np .nan ]
2267- mult_list = mult_list [: df .shape [axis ]]
2268-
2269- if data_type in tm .ALL_INT_NUMPY_DTYPES :
2264+ if dtype == tm .ALL_INT_NUMPY_DTYPES [0 ]:
22702265 # Numpy int type cannot represent NaN
2271- mult_np = np .asarray (mult_list )
2272- mult_list = np .nan_to_num (mult_np , nan = fill_val )
2273-
2274- mult_data = pd .array (mult_list , dtype = data_type )
2266+ safe_null = fill_val
2267+ else :
2268+ safe_null = np .nan
22752269
2276- for i in range (df .shape [0 ]):
2277- try :
2278- df .iat [i , i ] = np .nan
2279- df .iat [i + 2 , i ] = pd .NA
2280- except IndexError :
2281- pass
2270+ df = DataFrame ([[safe_null , 1 , 2 ], [3 , safe_null , 5 ]], dtype = dtype )
22822271
2283- if axis == 0 :
2284- mult_mat = np .broadcast_to (mult_data .reshape (- 1 , 1 ), df .shape )
2285- mask = np .isnan (mult_mat )
2286- else :
2287- mult_mat = np .broadcast_to (mult_data .reshape (1 , - 1 ), df .shape )
2288- mask = np .isnan (mult_mat )
2289- mask = df .isna ().values & mask
2272+ mult = pd .array ([safe_null , 1.0 ], dtype = dtype )
22902273
2291- df_result = df .mul (mult_data , axis = axis , fill_value = fill_val )
2292- df_expected = ( df . fillna ( fill_val ). mul ( mult_data . fillna ( fill_val ), axis = axis )). mask (
2293- mask , np . nan
2294- )
2274+ result = df .mul (mult , axis = 0 , fill_value = fill_val )
2275+ expected = DataFrame (
2276+ [[ safe_null * safe_null , fill_val , fill_val * 2 ], [ 3.0 , fill_val , 5.0 ]]
2277+ ). astype ( dtype )
22952278
2296- tm .assert_frame_equal (df_result , df_expected )
2279+ tm .assert_frame_equal (result , expected )
0 commit comments