1
+ import tempfile
1
2
import unittest
3
+ import shutil
4
+ import os
5
+ from glob import glob
2
6
import numpy as np
3
7
4
8
from pyemma .coordinates .data import DataInMemory
5
9
from pyemma .util .contexts import settings
6
10
from pyemma .util .files import TemporaryDirectory
7
- import os
8
- from glob import glob
9
-
10
11
11
12
12
13
class TestCoordinatesIterator (unittest .TestCase ):
@@ -15,6 +16,12 @@ class TestCoordinatesIterator(unittest.TestCase):
15
16
def setUpClass (cls ):
16
17
cls .d = [np .random .random ((100 , 3 )) for _ in range (3 )]
17
18
19
+ def setUp (self ):
20
+ self .tempdir = tempfile .mktemp ()
21
+
22
+ def tearDown (self ):
23
+ shutil .rmtree (self .tempdir , ignore_errors = True )
24
+
18
25
def test_current_trajindex (self ):
19
26
r = DataInMemory (self .d )
20
27
expected_itraj = 0
@@ -273,5 +280,67 @@ def test_invalid_data_in_input_inf(self):
273
280
for itraj , X in it :
274
281
pass
275
282
283
+ def test_lagged_iterator (self ):
284
+ import pyemma .coordinates as coor
285
+ from pyemma .coordinates .tests .util import create_traj , get_top
286
+
287
+ trajectory_length = 4720
288
+ lagtime = 1000
289
+ n_trajs = 15
290
+
291
+ top = get_top ()
292
+ trajs_data = [create_traj (top = top , length = trajectory_length ) for _ in range (n_trajs )]
293
+ trajs = [t [0 ] for t in trajs_data ]
294
+ xyzs = [t [1 ].reshape (- 1 , 9 ) for t in trajs_data ]
295
+
296
+ reader = coor .source (trajs , top = top , chunksize = 5000 )
297
+
298
+ for chunk in [None , 0 , trajectory_length , trajectory_length + 1 , trajectory_length + 1000 ]:
299
+ it = reader .iterator (lag = lagtime , chunk = chunk , return_trajindex = True )
300
+ with it :
301
+ for itraj , X , Y in it :
302
+ np .testing .assert_equal (X .shape , Y .shape )
303
+ np .testing .assert_equal (X .shape [0 ], trajectory_length - lagtime )
304
+ np .testing .assert_array_almost_equal (X , xyzs [itraj ][:trajectory_length - lagtime ])
305
+ np .testing .assert_array_almost_equal (Y , xyzs [itraj ][lagtime :])
306
+
307
+ def test_lagged_iterator_optimized (self ):
308
+ import pyemma .coordinates as coor
309
+ from pyemma .coordinates .tests .util import create_traj , get_top
310
+ from pyemma .coordinates .util .patches import iterload
311
+
312
+ trajectory_length = 4720
313
+ lagtime = 20
314
+ n_trajs = 15
315
+ stride = iterload .MAX_STRIDE_SWITCH_TO_RA + 1
316
+
317
+ top = get_top ()
318
+ trajs_data = [create_traj (top = top , length = trajectory_length ) for _ in range (n_trajs )]
319
+ trajs = [t [0 ] for t in trajs_data ]
320
+ xyzs = [t [1 ].reshape (- 1 , 9 )[::stride ] for t in trajs_data ]
321
+ xyzs_lagged = [t [1 ].reshape (- 1 , 9 )[lagtime ::stride ] for t in trajs_data ]
322
+
323
+ reader = coor .source (trajs , stride = stride , top = top , chunksize = 5000 )
324
+
325
+ memory_cutoff = iterload .MEMORY_CUTOFF
326
+ try :
327
+ iterload .MEMORY_CUTOFF = 8
328
+ it = reader .iterator (stride = stride , lag = lagtime , chunk = 5000 , return_trajindex = True )
329
+ with it :
330
+ curr_itraj = 0
331
+ t = 0
332
+ for itraj , X , Y in it :
333
+ if itraj != curr_itraj :
334
+ curr_itraj = itraj
335
+ t = 0
336
+ np .testing .assert_equal (X .shape , Y .shape )
337
+ l = len (X )
338
+ np .testing .assert_array_almost_equal (X , xyzs [itraj ][t :t + l ])
339
+ np .testing .assert_array_almost_equal (Y , xyzs_lagged [itraj ][t :t + l ])
340
+ t += l
341
+ finally :
342
+ iterload .MEMORY_CUTOFF = memory_cutoff
343
+
344
+
276
345
if __name__ == '__main__' :
277
346
unittest .main ()
0 commit comments