From a16d8a4feaf5263d3a21e81bc4217659f776f74b Mon Sep 17 00:00:00 2001 From: Aaron Virshup Date: Mon, 17 Jul 2017 14:49:45 -0700 Subject: [PATCH] Fix grid initialization --- moldesign/_tests/test_mathutils.py | 6 ++---- moldesign/mathutils/grids.py | 8 ++++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/moldesign/_tests/test_mathutils.py b/moldesign/_tests/test_mathutils.py index de63414..6d13dd2 100644 --- a/moldesign/_tests/test_mathutils.py +++ b/moldesign/_tests/test_mathutils.py @@ -173,8 +173,6 @@ def test_spherical_harmonics_orthonormal(l1, m1, l2, m2): assert abs(integ) < 1e-2 - - @pytest.fixture def ndarray_ranges(): ranges = np.array([[1,4], @@ -191,7 +189,7 @@ def ranges_with_units(ndarray_ranges): @pytest.mark.parametrize('key', ['ndarray_ranges', 'ranges_with_units']) def test_volumetric_grid_point_list(key, request): ranges = request.getfixturevalue(key) - grid = mathutils.VolumetricGrid(*ranges, 3, 4, 5) + grid = mathutils.VolumetricGrid(*ranges, xpoints=3, ypoints=4, zpoints=5) assert (grid.xpoints, grid.ypoints, grid.zpoints) == (3,4,5) pl = list(grid.iter_points()) pa = grid.allpoints() @@ -213,7 +211,7 @@ def test_volumetric_grid_point_list(key, request): @pytest.mark.parametrize('key', ['ndarray_ranges', 'ranges_with_units']) def test_volumetric_iteration(key, request): ranges = request.getfixturevalue(key) - grid = mathutils.VolumetricGrid(*ranges, 4) + grid = mathutils.VolumetricGrid(*ranges, npoints=4) grid_iterator = grid.iter_points() assert (grid.xpoints, grid.ypoints, grid.zpoints) == (4,4,4) diff --git a/moldesign/mathutils/grids.py b/moldesign/mathutils/grids.py index 6c260e2..e49c1a8 100644 --- a/moldesign/mathutils/grids.py +++ b/moldesign/mathutils/grids.py @@ -32,10 +32,10 @@ class VolumetricGrid(object): xrange (Tuple[len=2]): (min,max) in x direction yrange (Tuple[len=2]): (min,max) in y direction zrange (Tuple[len=2]): (min,max) in z direction - npoints (int): default number of grid lines in each direction (default: 32) xpoints (int): number of grid lines in x direction (default: npoints) ypoints (int): number of grid lines in y direction (default: npoints) zpoints (int): number of grid lines in z direction (default: npoints) + npoints (int): synonym for "xpoints" """ dx, dy, dz = (utils.IndexView('deltas', i) for i in range(3)) xr, yr, zr = (utils.IndexView('ranges', i) for i in range(3)) @@ -43,7 +43,11 @@ class VolumetricGrid(object): xspace, yspace, zspace = (utils.IndexView('spaces', i) for i in range(3)) def __init__(self, xrange, yrange, zrange, - npoints=32, xpoints=None, ypoints=None, zpoints=None): + xpoints=None, ypoints=None, zpoints=None, + npoints=32): + + if xpoints is not None: + npoints = xpoints self.points = np.array([xpoints if xpoints is not None else npoints, ypoints if ypoints is not None else npoints,