Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Commit

Permalink
Merge pull request #876 from marscher/fix_fragmented_reader_skip_handling
Browse files Browse the repository at this point in the history

* [FragmentedReader|FeatureReader] fixed a bug in time-lagged access.
* [featurizer|md-iterload] cache mdtraj.Topology objects if they have the same input file.
  • Loading branch information
marscher authored Jul 21, 2016
2 parents 1355d1a + 27c42d8 commit 502e948
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 35 deletions.
22 changes: 19 additions & 3 deletions pyemma/coordinates/data/feature_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,25 @@ def __init__(self, data_source, skip=0, chunk=0, stride=1, return_trajindex=Fals
#self._cols = cols
self._create_mditer()

@DataSourceIterator.chunksize.setter
def chunksize(self, val):
self._mditer._chunksize = int(val)
@property
def chunksize(self):
    # Current chunk size (frames per chunk) taken from the iterator state.
    return self.state.chunk

@chunksize.setter
def chunksize(self, value):
    # Keep the iterator state and the wrapped mdtraj iterator in sync.
    # The '_mditer' attribute may not exist yet (setter can run before
    # _create_mditer() in __init__), hence the hasattr guard.
    self.state.chunk = value
    if hasattr(self, '_mditer'):
        self._mditer._chunksize = int(value)

@property
def skip(self):
    # Number of leading frames to skip, taken from the iterator state.
    return self.state.skip

@skip.setter
def skip(self, value):
    # Propagate the new skip to the wrapped mdtraj iterator as well,
    # if it has already been created (hasattr guard for early setup).
    self.state.skip = value
    if hasattr(self, '_mditer'):
        self._mditer._skip = value

def close(self):
if hasattr(self, '_mditer') and self._mditer is not None:
Expand Down
8 changes: 6 additions & 2 deletions pyemma/coordinates/data/featurization/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

from .misc import CustomFeature
import numpy as np
from pyemma.coordinates.util.patches import load_topology_cached
from mdtraj import load_topology as load_topology_uncached


__author__ = 'Frank Noe, Martin Scherer'
Expand All @@ -38,18 +40,20 @@
class MDFeaturizer(Loggable):
r"""Extracts features from MD trajectories."""

def __init__(self, topfile):
def __init__(self, topfile, use_cache=True):
"""extracts features from MD trajectories.
Parameters
----------
topfile : str or mdtraj.Topology
a path to a topology file (pdb etc.) or an mdtraj Topology() object
use_cache : boolean, default=True
cache already loaded topologies, if file contents match.
"""
self.topologyfile = None
if isinstance(topfile, six.string_types):
self.topology = (mdtraj.load(topfile)).topology
self.topology = load_topology_cached(topfile) if use_cache else load_topology_uncached(topfile)
self.topologyfile = topfile
elif isinstance(topfile, mdtraj.Topology):
self.topology = topfile
Expand Down
16 changes: 12 additions & 4 deletions pyemma/coordinates/data/fragmented_trajectory_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from pyemma.util.annotators import fix_docs


@fix_docs
class _FragmentedTrajectoryIterator(object):
def __init__(self, fragmented_reader, readers, chunksize, stride, skip):
# global time variable
Expand Down Expand Up @@ -123,6 +122,7 @@ def __next__(self):
)
else:
self._reader_it = self._readers[self._reader_at].iterator(self._stride, return_trajindex=False)
self._reader_it.skip = skip
self._reader_overlap = self._calculate_new_overlap(self._stride,
self._reader_lengths[self._reader_at - 1],
self._reader_overlap)
Expand Down Expand Up @@ -197,7 +197,7 @@ def _read_full(self, skip):
_skip = overlap
# if stride doesn't divide length, one has to offset the next trajectory
overlap = self._calculate_new_overlap(self._stride, self._reader_lengths[idx], overlap)
chunksize = min(length, r.trajectory_length(0, self._stride))
chunksize = min(length, r.trajectory_length(0, self._stride, skip=_skip))
it = r._create_iterator(stride=self._stride, skip=_skip, chunk=chunksize, return_trajindex=True)
with it:
for itraj, data in it:
Expand Down Expand Up @@ -263,6 +263,9 @@ def close(self):


class FragmentIterator(DataSourceIterator):
"""
outer iterator, which encapsulates _FragmentedTrajectoryIterator
"""

def __init__(self, data_source, skip=0, chunk=0, stride=1, return_trajindex=False, cols=None):
super(FragmentIterator, self).__init__(data_source, skip=skip, chunk=chunk,
Expand All @@ -285,7 +288,7 @@ def _next_chunk(self):
if X is None:
raise StopIteration()
self._t += len(X)
if self._t >= self._data_source.trajectory_length(self._itraj, stride=self.stride):
if self._t >= self._data_source.trajectory_length(self._itraj, stride=self.stride, skip=self.skip):
self._itraj += 1
self._it.close()
self._it = None
Expand All @@ -300,6 +303,7 @@ def close(self):
self._it.close()


@fix_docs
class FragmentedTrajectoryReader(DataSource):
"""
Parameters
Expand All @@ -308,7 +312,7 @@ class FragmentedTrajectoryReader(DataSource):
topologyfile, str, default None
chunksize: int, default 100
chunksize: int, default 1000
featurizer: MDFeaturizer, default None
Expand Down Expand Up @@ -367,6 +371,10 @@ def __init__(self, trajectories, topologyfile=None, chunksize=1000, featurizer=N
self._trajectories = trajectories
self._filenames = trajectories

# random-accessible
#self._is_random_accessible = all(r._is_random_accessible for r in self._readers[itraj]
# for itraj in range(0, self._ntraj))

@property
def filenames_flat(self):
flat_readers = itertools.chain.from_iterable(self._readers)
Expand Down
4 changes: 2 additions & 2 deletions pyemma/coordinates/data/util/traj_info_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def set(self, traj_info):
"VALUES (?, ?, ?, ?, ?, ?, ?)", values)
try:
self._database.execute(*statement)
except sqlite3.IntegrityError:
logger.exception()
except sqlite3.IntegrityError as ie:
logger.exception("insert failed: %s " % ie)
return
self._database.commit()

Expand Down
75 changes: 51 additions & 24 deletions pyemma/coordinates/util/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@

from __future__ import absolute_import
import numpy as np
from mdtraj import Topology
from mdtraj.utils.validation import cast_indices
from mdtraj.core.trajectory import load, _parse_topology, _TOPOLOGY_EXTS, _get_extension, open
from mdtraj.core.trajectory import load, _TOPOLOGY_EXTS, _get_extension, open as md_open, load_topology

from itertools import groupby
from operator import itemgetter
Expand All @@ -34,6 +35,30 @@
from six.moves import map


def _cache_mdtraj_topology(loader):
    """Wrap a topology-loading function with a content-addressed cache.

    Parameters
    ----------
    loader : callable
        Function mapping a topology file name to an mdtraj.Topology
        (e.g. ``mdtraj.load_topology``).

    Returns
    -------
    wrap : callable
        Same signature as ``loader``. Topology instances are passed
        through unchanged; files whose byte contents hash equally share
        a single cached Topology object.
    """
    import hashlib
    # md5 hexdigest of the file contents -> mdtraj.Topology
    # NOTE(review): cache is unbounded; fine for the usual handful of
    # topology files per session, but verify for long-running processes.
    _top_cache = {}

    def wrap(top_file):
        # Already a Topology object: nothing to load or cache.
        if isinstance(top_file, Topology):
            return top_file
        hasher = hashlib.md5()
        with open(top_file, 'rb') as f:
            hasher.update(f.read())
        digest = hasher.hexdigest()  # renamed: do not shadow builtin 'hash'

        try:
            top = _top_cache[digest]
        except KeyError:
            # Cache miss: delegate to the wrapped loader. The original
            # code ignored its parameter and re-imported load_topology;
            # using the injected callable removes that dead parameter.
            top = loader(top_file)
            _top_cache[digest] = top
        return top
    return wrap

load_topology_cached = _cache_mdtraj_topology(load_topology)


class iterload(object):

def __init__(self, filename, chunk=100, **kwargs):
Expand Down Expand Up @@ -90,46 +115,47 @@ def __init__(self, filename, chunk=100, **kwargs):
self._chunksize = chunk
self._extension = _get_extension(self._filename)
self._closed = False
self._seeked = False
if self._extension not in _TOPOLOGY_EXTS:
self._topology = _parse_topology(self._top)
self._topology = load_topology_cached(self._top)
else:
self._topology = self._top

self._mode = None
if self._chunksize > 0 and self._extension in ('.pdb', '.pdb.gz'):
self._mode = 'pdb'
self._t = load(self._filename, stride=self._stride, atom_indices=self._atom_indices)
self._i = 0
elif isinstance(self._stride, np.ndarray):
if isinstance(self._stride, np.ndarray):
self._mode = 'random_access'
self._f = (lambda x:
open(x, n_atoms=self._topology.n_atoms)
md_open(x, n_atoms=self._topology.n_atoms)
if self._extension in ('.crd', '.mdcrd')
else open(self._filename))(self._filename)
else md_open(self._filename))(self._filename)
self._ra_it = self._random_access_generator(self._f)
else:
self._mode = 'traj'
self._f = (
lambda x: open(x, n_atoms=self._topology.n_atoms)
lambda x: md_open(x, n_atoms=self._topology.n_atoms)
if self._extension in ('.crd', '.mdcrd')
else open(self._filename)
else md_open(self._filename)
)(self._filename)

# offset array handling
offsets = kwargs.pop('offsets', None)
if hasattr(self._f, 'offsets') and offsets is not None:
self._f.offsets = offsets

if self._skip > 0:
self._f.seek(self._skip)
@property
def skip(self):
    # Frames to skip at the start of the trajectory. The actual seek is
    # deferred to the first next() call (see the '_seeked' flag), so skip
    # can still be re-set after construction.
    return self._skip

@skip.setter
def skip(self, value):
    # Only plain 'traj' mode honors skip via seek.
    # NOTE(review): assert is stripped under 'python -O'; consider an
    # explicit raise if this must hold in production.
    assert self._mode == 'traj'
    self._skip = value

def __iter__(self):
    # iterload is its own iterator; iteration is driven by next()/__next__.
    return self

def close(self):
if hasattr(self, '_t'):
self._t.close()
elif hasattr(self, '_f'):
if hasattr(self, '_f'):
self._f.close()
self._closed = True

Expand All @@ -139,28 +165,29 @@ def __next__(self):
def next(self):
if self._closed:
raise StopIteration()

# apply skip offset only once.
# (we want to do this here, since we want to be able to re-set self.skip)
if not self._seeked:
self._f.seek(self.skip)
self._seeked = True

if not isinstance(self._stride, np.ndarray) and self._chunksize == 0:
# If chunk was 0 then we want to avoid filetype-specific code
# in case of undefined behavior in various file parsers.
# TODO: this will first apply stride, then skip!
if self._extension not in _TOPOLOGY_EXTS:
self._kwargs['top'] = self._top
return load(self._filename, stride=self._stride, **self._kwargs)[self._skip:]
elif self._mode is 'pdb':
# the PDBTrajectortFile class doesn't follow the standard API. Fixing it
# to support iterload could be worthwhile, but requires a deep refactor.
X = self._t[self._i:self._i+self._chunksize]
self._i += self._chunksize
return X
elif isinstance(self._stride, np.ndarray):
return next(self._ra_it)
else:
if self._extension not in _TOPOLOGY_EXTS:
traj = self._f.read_as_traj(self._topology, n_frames=self._chunksize*self._stride,
stride=self._stride, atom_indices=self._atom_indices, **self._kwargs)
stride=self._stride, atom_indices=self._atom_indices, **self._kwargs)
else:
traj = self._f.read_as_traj(n_frames=self._chunksize*self._stride,
stride=self._stride, atom_indices=self._atom_indices, **self._kwargs)
stride=self._stride, atom_indices=self._atom_indices, **self._kwargs)

if len(traj) == 0:
raise StopIteration()
Expand Down

0 comments on commit 502e948

Please sign in to comment.