diff --git a/traittypes/traittypes.py b/traittypes/traittypes.py index 516a91b..0346c15 100644 --- a/traittypes/traittypes.py +++ b/traittypes/traittypes.py @@ -213,6 +213,7 @@ class XarrayType(SciType): info_text = 'an xarray dataset or dataarray' klass = None + dtype = None def validate(self, obj, value): if value is None and not self.allow_none: @@ -234,7 +235,7 @@ def set(self, obj, value): not old_value.equals(new_value)): obj._notify_trait(self.name, old_value, new_value) - def __init__(self, default_value=Empty, allow_none=False, klass=None, **kwargs): + def __init__(self, default_value=Empty, allow_none=False, klass=None, dtype=None, **kwargs): if klass is None: klass = self.klass if (klass is not None) and inspect.isclass(klass): @@ -246,6 +247,8 @@ def __init__(self, default_value=Empty, allow_none=False, klass=None, **kwargs): default_value = klass() elif default_value is not None and default_value is not Undefined: default_value = klass(default_value) + + self.dtype = dtype super(XarrayType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs) def make_dynamic_default(self): @@ -274,7 +277,6 @@ class DataArray(XarrayType): """An xarray dataarray trait type.""" info_text = 'an xarray dataarray' - dtype = None def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs): if 'klass' not in kwargs and self.klass is None: @@ -282,4 +284,3 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs): kwargs['klass'] = xr.DataArray super(DataArray, self).__init__( default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs) - self.dtype = dtype