|
15 | 15 | import tables |
16 | 16 | except ImportError: |
17 | 17 | tables = None |
| 18 | +import numpy as np |
18 | 19 | import warnings |
19 | 20 | from os.path import isfile |
20 | 21 | from pylearn2.compat import OrderedDict |
@@ -86,9 +87,7 @@ def __new__(cls, filename, X=None, topo_view=None, y=None, load_all=False, |
86 | 87 | return HDF5DatasetDeprecated(filename, X, topo_view, y, load_all, |
87 | 88 | cache_size, **kwargs) |
88 | 89 | else: |
89 | | - return super(HDF5Dataset, cls).__new__( |
90 | | - cls, filename, sources, spaces, aliases, load_all, cache_size, |
91 | | - use_h5py, **kwargs) |
| 90 | + return super(HDF5Dataset, cls).__new__(cls) |
92 | 91 |
|
93 | 92 | def __init__(self, filename, sources, spaces, aliases=None, load_all=False, |
94 | 93 | cache_size=None, use_h5py='auto', **kwargs): |
@@ -204,7 +203,7 @@ def iterator(self, mode=None, data_specs=None, batch_size=None, |
204 | 203 | provided when the dataset object has been created will be used. |
205 | 204 | """ |
206 | 205 | if data_specs is None: |
207 | | - data_specs = (self._get_sources, self._get_spaces) |
| 206 | + data_specs = (self._get_spaces(), self._get_sources()) |
208 | 207 |
|
209 | 208 | [mode, batch_size, num_batches, rng, data_specs] = self._init_iterator( |
210 | 209 | mode, batch_size, num_batches, rng, data_specs) |
@@ -240,7 +239,7 @@ def _get_spaces(self): |
240 | 239 | ------- |
241 | 240 | A Space or a list of Spaces. |
242 | 241 | """ |
243 | | - space = [self.spaces[s] for s in self._get_sources] |
| 242 | + space = [self.spaces[s] for s in self._get_sources()] |
244 | 243 | return space[0] if len(space) == 1 else tuple(space) |
245 | 244 |
|
246 | 245 | def get_data_specs(self, source_or_alias=None): |
@@ -310,16 +309,16 @@ def get(self, sources, indexes): |
310 | 309 | sources[s], *e.args)) |
311 | 310 | if (isinstance(indexes, (slice, py_integer_types)) or |
312 | 311 | len(indexes) == 1): |
313 | | - rval.append(sdata[indexes]) |
| 312 | + val = sdata[indexes] |
314 | 313 | else: |
315 | 314 | warnings.warn('Accessing non sequential elements of an ' |
316 | 315 | 'HDF5 file will be at best VERY slow. Avoid ' |
317 | 316 | 'using iteration schemes that access ' |
318 | 317 | 'random/shuffled data with hdf5 datasets!!') |
319 | | - val = [] |
320 | | - [val.append(sdata[idx]) for idx in indexes] |
321 | | - rval.append(val) |
322 | | - return tuple(rval) |
| 318 | + val = [sdata[idx] for idx in indexes] |
| 319 | + val = tuple(tuple(row) for row in val) |
| 320 | + rval.append(val) |
| 321 | + return [np.array(val) for val in rval] |
323 | 322 |
|
324 | 323 | @wraps(Dataset.get_num_examples, assigned=(), updated=()) |
325 | 324 | def get_num_examples(self, source_or_alias=None): |
|
0 commit comments