1- from typing import TYPE_CHECKING , Hashable , Iterable , Optional , Union , overload
1+ from typing import TYPE_CHECKING , Generic , Hashable , Iterable , Optional , TypeVar , Union
22
33from . import duck_array_ops
44from .computation import dot
5- from .options import _get_keep_attrs
65from .pycompat import is_duck_dask_array
76
87if TYPE_CHECKING :
8+ from .common import DataWithCoords # noqa: F401
99 from .dataarray import DataArray , Dataset
1010
11+ T_DataWithCoords = TypeVar ("T_DataWithCoords" , bound = "DataWithCoords" )
12+
13+
1114_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
1215 Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
1316
5659 """
5760
5861
59- class Weighted :
62+ class Weighted ( Generic [ T_DataWithCoords ]) :
6063 """An object that implements weighted operations.
6164
6265 You should create a Weighted object by using the ``DataArray.weighted`` or
@@ -70,15 +73,7 @@ class Weighted:
7073
7174 __slots__ = ("obj" , "weights" )
7275
73- @overload
74- def __init__ (self , obj : "DataArray" , weights : "DataArray" ) -> None :
75- ...
76-
77- @overload
78- def __init__ (self , obj : "Dataset" , weights : "DataArray" ) -> None :
79- ...
80-
81- def __init__ (self , obj , weights ):
76+ def __init__ (self , obj : T_DataWithCoords , weights : "DataArray" ):
8277 """
8378 Create a Weighted object
8479
@@ -121,8 +116,8 @@ def _weight_check(w):
121116 else :
122117 _weight_check (weights .data )
123118
124- self .obj = obj
125- self .weights = weights
119+ self .obj : T_DataWithCoords = obj
120+ self .weights : "DataArray" = weights
126121
127122 @staticmethod
128123 def _reduce (
@@ -146,7 +141,6 @@ def _reduce(
146141
147142 # `dot` does not broadcast arrays, so this avoids creating a large
148143 # DataArray (if `weights` has additional dimensions)
149- # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
150144 return dot (da , weights , dims = dim )
151145
152146 def _sum_of_weights (
@@ -203,7 +197,7 @@ def sum_of_weights(
203197 self ,
204198 dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
205199 keep_attrs : Optional [bool ] = None ,
206- ) -> Union [ "DataArray" , "Dataset" ] :
200+ ) -> T_DataWithCoords :
207201
208202 return self ._implementation (
209203 self ._sum_of_weights , dim = dim , keep_attrs = keep_attrs
@@ -214,7 +208,7 @@ def sum(
214208 dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
215209 skipna : Optional [bool ] = None ,
216210 keep_attrs : Optional [bool ] = None ,
217- ) -> Union [ "DataArray" , "Dataset" ] :
211+ ) -> T_DataWithCoords :
218212
219213 return self ._implementation (
220214 self ._weighted_sum , dim = dim , skipna = skipna , keep_attrs = keep_attrs
@@ -225,7 +219,7 @@ def mean(
225219 dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
226220 skipna : Optional [bool ] = None ,
227221 keep_attrs : Optional [bool ] = None ,
228- ) -> Union [ "DataArray" , "Dataset" ] :
222+ ) -> T_DataWithCoords :
229223
230224 return self ._implementation (
231225 self ._weighted_mean , dim = dim , skipna = skipna , keep_attrs = keep_attrs
@@ -239,22 +233,15 @@ def __repr__(self):
239233 return f"{ klass } with weights along dimensions: { weight_dims } "
240234
241235
242- class DataArrayWeighted (Weighted ):
243- def _implementation (self , func , dim , ** kwargs ):
244-
245- keep_attrs = kwargs .pop ("keep_attrs" )
246- if keep_attrs is None :
247- keep_attrs = _get_keep_attrs (default = False )
248-
249- weighted = func (self .obj , dim = dim , ** kwargs )
250-
251- if keep_attrs :
252- weighted .attrs = self .obj .attrs
236+ class DataArrayWeighted (Weighted ["DataArray" ]):
237+ def _implementation (self , func , dim , ** kwargs ) -> "DataArray" :
253238
254- return weighted
239+ dataset = self .obj ._to_temp_dataset ()
240+ dataset = dataset .map (func , dim = dim , ** kwargs )
241+ return self .obj ._from_temp_dataset (dataset )
255242
256243
257- class DatasetWeighted (Weighted ):
244+ class DatasetWeighted (Weighted [ "Dataset" ] ):
258245 def _implementation (self , func , dim , ** kwargs ) -> "Dataset" :
259246
260247 return self .obj .map (func , dim = dim , ** kwargs )
0 commit comments