Skip to content

Commit

Permalink
Fix grid initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
avirshup committed Jul 17, 2017
1 parent 1a03b44 commit a16d8a4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 2 additions & 4 deletions moldesign/_tests/test_mathutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ def test_spherical_harmonics_orthonormal(l1, m1, l2, m2):
assert abs(integ) < 1e-2




@pytest.fixture
def ndarray_ranges():
ranges = np.array([[1,4],
Expand All @@ -191,7 +189,7 @@ def ranges_with_units(ndarray_ranges):
@pytest.mark.parametrize('key', ['ndarray_ranges', 'ranges_with_units'])
def test_volumetric_grid_point_list(key, request):
ranges = request.getfixturevalue(key)
grid = mathutils.VolumetricGrid(*ranges, 3, 4, 5)
grid = mathutils.VolumetricGrid(*ranges, xpoints=3, ypoints=4, zpoints=5)
assert (grid.xpoints, grid.ypoints, grid.zpoints) == (3,4,5)
pl = list(grid.iter_points())
pa = grid.allpoints()
Expand All @@ -213,7 +211,7 @@ def test_volumetric_grid_point_list(key, request):
@pytest.mark.parametrize('key', ['ndarray_ranges', 'ranges_with_units'])
def test_volumetric_iteration(key, request):
ranges = request.getfixturevalue(key)
grid = mathutils.VolumetricGrid(*ranges, 4)
grid = mathutils.VolumetricGrid(*ranges, npoints=4)

grid_iterator = grid.iter_points()
assert (grid.xpoints, grid.ypoints, grid.zpoints) == (4,4,4)
Expand Down
8 changes: 6 additions & 2 deletions moldesign/mathutils/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,22 @@ class VolumetricGrid(object):
xrange (Tuple[len=2]): (min,max) in x direction
yrange (Tuple[len=2]): (min,max) in y direction
zrange (Tuple[len=2]): (min,max) in z direction
npoints (int): default number of grid lines in each direction (default: 32)
xpoints (int): number of grid lines in x direction (default: npoints)
ypoints (int): number of grid lines in y direction (default: npoints)
zpoints (int): number of grid lines in z direction (default: npoints)
npoints (int): synonym for "xpoints"
"""
dx, dy, dz = (utils.IndexView('deltas', i) for i in range(3))
xr, yr, zr = (utils.IndexView('ranges', i) for i in range(3))
xpoints, ypoints, zpoints = (utils.IndexView('points', i) for i in range(3))
xspace, yspace, zspace = (utils.IndexView('spaces', i) for i in range(3))

def __init__(self, xrange, yrange, zrange,
npoints=32, xpoints=None, ypoints=None, zpoints=None):
xpoints=None, ypoints=None, zpoints=None,
npoints=32):

if xpoints is not None:
npoints = xpoints

self.points = np.array([xpoints if xpoints is not None else npoints,
ypoints if ypoints is not None else npoints,
Expand Down

0 comments on commit a16d8a4

Please sign in to comment.