datasets.py
# -*- coding: utf-8 -*-
import logging
from typing import Tuple, List, Dict, Optional, Iterable, Callable, Sequence, Union, cast
from datetime import datetime
from dateutil.parser import isoparse
from functools import wraps
import pandas as pd
import numpy as np
from .data_provider.providers import RandomDataProvider, DataLakeProvider
from .exceptions import InsufficientDataError
from .base import GordoBaseDataset, ConfigurationError
from .data_provider.base import GordoBaseDataProvider
from .filter_rows import pandas_filter_rows, parse_pandas_filter_vars
from .filter_periods import FilterPeriods
from .sensor_tag import SensorTag
from .sensor_tag import normalize_sensor_tags
from .utils import capture_args, join_timeseries
from .validators import (
ValidTagList,
ValidDatetime,
ValidDatasetKwargs,
ValidDataProvider,
)
logger = logging.getLogger(__name__)
def compat(init):
"""
__init__ decorator for compatibility where the Gordo config file's ``dataset`` keys have
drifted from the kwargs actually expected by the given dataset. For example,
``from_ts`` is common in the configs, but :class:`~TimeSeriesDataset`
(as well as :class:`~RandomDataset`) takes this parameter as ``train_start_date``.
Renames old/other acceptable kwargs to the ones that the dataset type expects.
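A minimal, illustrative sketch (a plain function stands in for a real ``__init__``)::

    @compat
    def init(train_start_date=None, train_end_date=None, tag_list=None):
        return train_start_date, train_end_date, tag_list

    init(from_ts="2020-01-01", to_ts="2020-01-02", tags=["tag-1"])
    # -> ("2020-01-01", "2020-01-02", ["tag-1"])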
"""
@wraps(init)
def wrapper(*args, **kwargs):
renamings = {
"from_ts": "train_start_date",
"to_ts": "train_end_date",
"tags": "tag_list",
}
for old, new in renamings.items():
if old in kwargs:
kwargs[new] = kwargs.pop(old)
return init(*args, **kwargs)
return wrapper
TagList = List[Union[Dict, str, SensorTag]]
class TimeSeriesDataset(GordoBaseDataset):
train_start_date = ValidDatetime()
train_end_date = ValidDatetime()
tag_list = ValidTagList()
target_tag_list = ValidTagList()
data_provider = ValidDataProvider()
kwargs = ValidDatasetKwargs()
TAG_NORMALIZERS = {"default": normalize_sensor_tags}
@staticmethod
def create_default_data_provider() -> GordoBaseDataProvider:
return DataLakeProvider()
@staticmethod
def tag_normalizer(
sensors: TagList,
asset: str = None,
default_asset: str = None,
) -> List[SensorTag]:
"""
Converts a list of sensors in different formats into a list of SensorTag elements.
This function may be useful to override in a subclass.
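Illustrative sketch (assumes ``SensorTag`` is essentially a ``(name, asset)`` pair)::

    TimeSeriesDataset.tag_normalizer(["tag-1", "tag-2"], asset="asset-a")
    # -> [SensorTag(name="tag-1", asset="asset-a"),
    #     SensorTag(name="tag-2", asset="asset-a")]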
"""
return normalize_sensor_tags(sensors, asset, default_asset)
@compat
@capture_args
def __init__(
self,
train_start_date: Union[datetime, str],
train_end_date: Union[datetime, str],
tag_list: Sequence[Union[str, Dict, SensorTag]],
target_tag_list: Optional[Sequence[Union[str, Dict, SensorTag]]] = None,
data_provider: Optional[Union[GordoBaseDataProvider, dict]] = None,
resolution: Optional[str] = "10T",
row_filter: Union[str, list] = "",
known_filter_periods: Optional[list] = None,
aggregation_methods: Union[str, List[str], Callable] = "mean",
row_filter_buffer_size: int = 0,
asset: Optional[str] = None,
default_asset: Optional[str] = None,
n_samples_threshold: int = 0,
low_threshold: Optional[int] = -1000,
high_threshold: Optional[int] = 50000,
interpolation_method: str = "linear_interpolation",
interpolation_limit: str = "8H",
filter_periods: Optional[dict] = None,
process_metadata: bool = True,
):
"""
Creates a TimeSeriesDataset backed by a provided data provider.
A TimeSeriesDataset is a dataset backed by timeseries, but resampled,
aligned, and (optionally) filtered.
Parameters
----------
train_start_date: Union[datetime, str]
Earliest possible point in the dataset (inclusive)
train_end_date: Union[datetime, str]
Latest possible point in the dataset (exclusive)
tag_list: Sequence[Union[str, Dict, sensor_tag.SensorTag]]
List of tags to include in the dataset. The elements can be strings,
dictionaries or SensorTag namedtuples.
target_tag_list: Sequence[Union[str, Dict, sensor_tag.SensorTag]]
List of tags to set as the dataset y. These will be treated the same as
tag_list when fetching and pre-processing (resampling) but will be split
into the y return from ``.get_data()``
data_provider: Union[GordoBaseDataProvider, dict]
A data provider which can provide dataframes for tags from train_start_date to train_end_date.
It can also be a config definition produced by a data provider's ``.to_dict()`` method.
resolution: Optional[str]
The bucket size for grouping all incoming time data (e.g. "10T").
Available strings come from https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#dateoffset-objects
**Note**: If this parameter is ``None`` or ``False``, then _no_ aggregation/resampling is applied to the data.
row_filter: str or list
Filter on the rows. Only rows satisfying the filter will be in the dataset.
See :func:`gordo_dataset.filter_rows.pandas_filter_rows` for
further documentation of the filter format.
known_filter_periods: list
List of periods to drop in the format [~('2020-04-08 04:00:00+00:00' < index < '2020-04-08 10:00:00+00:00')].
Note the time-zone suffix (+00:00), which is required.
aggregation_methods: Union[str, List[str], Callable]
Aggregation method(s) to use for the resampled buckets. If a single
resample method is provided then the resulting dataframe will have column names
identical to the names of the input series. If several
aggregation methods are provided then the resulting dataframe will
have a multi-level column index, with the series name as the first level,
and the aggregation method as the second level.
See :meth:`pandas.core.resample.Resampler.aggregate` for more
information on possible aggregation methods.
row_filter_buffer_size: int
Whatever elements are selected for removal based on the ``row_filter`` will also
have this number of neighbouring elements removed before and after them.
Default is 0.
asset: Optional[str]
Asset with which the tags are associated.
default_asset: Optional[str]
Asset which will be used if `asset` is not provided and the tag is not
resolvable to a specific asset.
n_samples_threshold: int
If the generated DataFrame has this many rows or fewer, an ``InsufficientDataError`` is raised.
low_threshold: Optional[int]
Lower bound for global min/max filtering; rows containing any value at or below this are dropped.
Filtering is skipped unless both thresholds are integers.
high_threshold: Optional[int]
Upper bound for global min/max filtering; rows containing any value at or above this are dropped.
interpolation_method: str
How missing values should be interpolated: either forward fill (``ffill``) or linear
interpolation (``linear_interpolation``, the default).
interpolation_limit: str
How far from the last valid data point values will be interpolated/forward filled.
Default is eight hours (``8H``).
If None, all missing values are interpolated/forward filled.
filter_periods: dict
If specified, applies a series of algorithms that drop noisy data.
See the ``FilterPeriods`` class for details.
process_metadata: bool
If True, collect processing metadata (available via ``get_metadata()``).
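
Examples
--------
A minimal, hypothetical usage sketch (tag names, asset, and dates are made up;
any provider implementing ``GordoBaseDataProvider`` would work here)::

    from gordo_dataset.data_provider.providers import RandomDataProvider

    dataset = TimeSeriesDataset(
        train_start_date="2020-01-01T00:00:00+00:00",
        train_end_date="2020-01-02T00:00:00+00:00",
        tag_list=["tag-1", "tag-2"],
        asset="asset-a",
        data_provider=RandomDataProvider(),
        resolution="10T",
    )
    X, y = dataset.get_data()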
"""
self.train_start_date = self._validate_dt(train_start_date)
self.train_end_date = self._validate_dt(train_end_date)
if self.train_start_date >= self.train_end_date:
raise ValueError(
f"train_end_date ({self.train_end_date}) must be after train_start_date ({self.train_start_date})"
)
self.asset = asset
self.default_asset = default_asset
self.tag_list = self.tag_normalizer(list(tag_list), asset, default_asset)
self.target_tag_list = (
self.tag_normalizer(list(target_tag_list), asset, default_asset)
if target_tag_list
else self.tag_list.copy()
)
self.resolution = resolution
if data_provider is None:
data_provider = self.create_default_data_provider()
self.data_provider = (
data_provider
if not isinstance(data_provider, dict)
else GordoBaseDataProvider.from_dict(data_provider)
)
self.row_filter = row_filter
self.aggregation_methods = aggregation_methods
self.row_filter_buffer_size = row_filter_buffer_size
self.n_samples_threshold = n_samples_threshold
self.low_threshold = low_threshold
self.high_threshold = high_threshold
self.interpolation_method = interpolation_method
self.interpolation_limit = interpolation_limit
self.filter_periods = (
FilterPeriods(granularity=self.resolution, **filter_periods)
if filter_periods
else None
)
self.known_filter_periods = (
known_filter_periods if known_filter_periods is not None else []
)
self.process_metadata = process_metadata
if not self.train_start_date.tzinfo or not self.train_end_date.tzinfo:
raise ValueError(
f"Timestamps ({self.train_start_date}, {self.train_end_date}) need to include timezone "
f"information"
)
super().__init__()
def to_dict(self):
params = super().to_dict()
to_str = lambda dt: str(dt) if not hasattr(dt, "isoformat") else dt.isoformat()
params["train_start_date"] = to_str(params["train_start_date"])
params["train_end_date"] = to_str(params["train_end_date"])
return params
@staticmethod
def _validate_dt(dt: Union[str, datetime]) -> datetime:
dt = dt if isinstance(dt, datetime) else isoparse(dt)
if dt.tzinfo is None:
raise ValueError(
"Must provide an ISO formatted datetime string with timezone information"
)
return dt
def get_data(self) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
tag_list = set(self.tag_list + self.target_tag_list)
triggered_tags = set()
if self.row_filter:
pandas_filter_tags = set(
self.tag_normalizer(
cast(TagList, parse_pandas_filter_vars(self.row_filter)),
self.asset,
self.default_asset,
)
)
triggered_tags = pandas_filter_tags.difference(tag_list)
tag_list.update(triggered_tags)
series_iter: Iterable[pd.Series] = self.data_provider.load_series(
train_start_date=self.train_start_date,
train_end_date=self.train_end_date,
tag_list=list(tag_list),
resolution=self.resolution,
)
# Resample if we have a resolution set, otherwise simply join the series.
if self.resolution:
data, metadata = join_timeseries(
series_iter,
self.train_start_date,
self.train_end_date,
self.resolution,
aggregation_methods=self.aggregation_methods,
interpolation_method=self.interpolation_method,
interpolation_limit=self.interpolation_limit,
)
if self.process_metadata:
self._metadata["tag_loading_metadata"] = metadata
else:
data = pd.concat(series_iter, axis=1, join="inner")
if len(data) <= self.n_samples_threshold:
raise InsufficientDataError(
f"The length of the generated DataFrame ({len(data)}) does not exceed the "
f"specified required threshold for number of rows ({self.n_samples_threshold})."
)
if self.known_filter_periods:
data = pandas_filter_rows(data, self.known_filter_periods, buffer_size=0)
if len(data) <= self.n_samples_threshold:
raise InsufficientDataError(
f"The length of the filtered DataFrame ({len(data)}) does not exceed the "
f"specified required threshold for number of rows ({self.n_samples_threshold})"
f" after dropping known periods."
)
if self.row_filter:
data = pandas_filter_rows(
data, self.row_filter, buffer_size=self.row_filter_buffer_size
)
if len(data) <= self.n_samples_threshold:
raise InsufficientDataError(
f"The length of the filtered DataFrame ({len(data)}) does not exceed the "
f"specified required threshold for the number of rows ({self.n_samples_threshold}), "
f" after applying the specified numerical row-filter."
)
if triggered_tags:
triggered_columns = [tag.name for tag in triggered_tags]
data = data.drop(columns=triggered_columns)
if isinstance(self.low_threshold, int) and isinstance(self.high_threshold, int):
if self.low_threshold >= self.high_threshold:
raise ConfigurationError(
"Low threshold needs to be lower than the high threshold"
)
logger.info("Applying global min/max filtering")
mask = ((data > self.low_threshold) & (data < self.high_threshold)).all(1)
data = data[mask]
logger.info("Shape of data after global min/max filtering: %s", data.shape)
if len(data) <= self.n_samples_threshold:
raise InsufficientDataError(
f"The length of the filtered DataFrame ({len(data)}) does not exceed the "
f"specified required threshold for number of rows ({self.n_samples_threshold})"
f" after filtering global extrema."
)
if self.filter_periods:
data, drop_periods, _ = self.filter_periods.filter_data(data)
if self.process_metadata:
self._metadata["filtered_periods"] = drop_periods
if len(data) <= self.n_samples_threshold:
raise InsufficientDataError(
f"The length of the filtered DataFrame ({len(data)}) does not exceed the "
f"specified required threshold for number of rows ({self.n_samples_threshold})"
f" after applying nuisance filtering algorithm."
)
x_tag_names = [tag.name for tag in self.tag_list]
y_tag_names = [tag.name for tag in self.target_tag_list]
X = data[x_tag_names]
y = data[y_tag_names] if self.target_tag_list else None
if self.process_metadata:
if X.first_valid_index():
self._metadata["train_start_date_actual"] = X.index[0]
self._metadata["train_end_date_actual"] = X.index[-1]
self._metadata["summary_statistics"] = X.describe().to_dict()
hists = dict()
for tag in X.columns:
step = round((X[tag].max() - X[tag].min()) / 100, 6)
if step < 9e-07:
hists[str(tag)] = "{}"
continue
outs = pd.cut(
X[tag],
bins=np.arange(
round(X[tag].min() - step, 6),
round(X[tag].max() + step, 6),
step,
),
retbins=False,
)
hists[str(tag)] = (
outs.value_counts().sort_index().to_json(orient="index")
)
self._metadata["x_hist"] = hists
return X, y
def get_metadata(self):
return self._metadata.copy()
class RandomDataset(TimeSeriesDataset):
"""
A TimeSeriesDataset backed by
gordo_dataset.data_provider.providers.RandomDataProvider
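
Examples
--------
A hypothetical usage sketch (tag names, asset, and dates are made up)::

    dataset = RandomDataset(
        train_start_date="2020-01-01T00:00:00+00:00",
        train_end_date="2020-01-02T00:00:00+00:00",
        tag_list=["tag-1", "tag-2"],
        asset="asset-a",
    )
    X, y = dataset.get_data()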
"""
@compat
@capture_args
def __init__(
self,
train_start_date: Union[datetime, str],
train_end_date: Union[datetime, str],
tag_list: list,
**kwargs,
):
kwargs.pop("data_provider", None) # Don't care what you ask for, you get random
super().__init__(
data_provider=RandomDataProvider(),
train_start_date=train_start_date,
train_end_date=train_end_date,
tag_list=tag_list,
**kwargs,
)