Skip to content

Commit cf6de58

Browse files
theaveyThomas Heavey
and
Thomas Heavey
authored
BUG: Series.dot for arrow and nullable dtypes returns object-dtyped series (#61376)
Co-authored-by: Thomas Heavey <[email protected]>
1 parent 65bf9cd commit cf6de58

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ Numeric
710710
^^^^^^^
711711
- Bug in :meth:`DataFrame.corr` where numerical precision errors resulted in correlations above ``1.0`` (:issue:`61120`)
712712
- Bug in :meth:`DataFrame.quantile` where the column type was not preserved when ``numeric_only=True`` with a list-like ``q`` produced an empty result (:issue:`59035`)
713+
- Bug in :meth:`Series.dot` returning ``object`` dtype for :class:`ArrowDtype` and nullable-dtype data (:issue:`61375`)
713714
- Bug in ``np.matmul`` with :class:`Index` inputs raising a ``TypeError`` (:issue:`57079`)
714715

715716
Conversion

pandas/core/series.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2951,8 +2951,9 @@ def dot(self, other: AnyArrayLike | DataFrame) -> Series | np.ndarray:
29512951
)
29522952

29532953
if isinstance(other, ABCDataFrame):
2954+
common_type = find_common_type([self.dtypes] + list(other.dtypes))
29542955
return self._constructor(
2955-
np.dot(lvals, rvals), index=other.columns, copy=False
2956+
np.dot(lvals, rvals), index=other.columns, copy=False, dtype=common_type
29562957
).__finalize__(self, method="dot")
29572958
elif isinstance(other, Series):
29582959
return np.dot(lvals, rvals)

pandas/tests/frame/methods/test_dot.py

+16
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,19 @@ def test_arrow_dtype(dtype, exp_dtype):
153153
expected = DataFrame([[1, 2], [3, 4], [5, 6]], dtype=exp_dtype)
154154

155155
tm.assert_frame_equal(result, expected)
156+
157+
158+
@pytest.mark.parametrize(
159+
"dtype,exp_dtype",
160+
[("Float32", "Float64"), ("Int16", "Int32"), ("float[pyarrow]", "double[pyarrow]")],
161+
)
162+
def test_arrow_dtype_series(dtype, exp_dtype):
163+
pytest.importorskip("pyarrow")
164+
165+
cols = ["a", "b"]
166+
series_a = Series([1, 2], index=cols, dtype="int32")
167+
df_b = DataFrame([[1, 0], [0, 1]], index=cols, dtype=dtype)
168+
result = series_a.dot(df_b)
169+
expected = Series([1, 2], dtype=exp_dtype)
170+
171+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)