Skip to content

Commit a0aa16a

Browse files
committed
marder_impl: fix cuda version
1 parent e0fd7ac commit a0aa16a

1 file changed

Lines changed: 17 additions & 17 deletions

File tree

src/libpsc/psc_push_fields/marder_impl.hxx

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,20 @@ inline void correct(const Grid_t& grid, E1& efield, const Int3& efield_ib,
6363
#ifdef USE_CUDA
6464

6565
template <typename E1, typename E2>
66-
inline void cuda_marder_correct_yz(E1& efield, E2& res, Float3 fac, Int3 ly,
67-
Int3 ry, Int3 lz, Int3 rz)
66+
inline void cuda_marder_correct_yz(E1& efield, E2& res, Float3 fac, Int3 l,
67+
Int3 r, Int3 l, Int3 r)
6868
{
6969
auto k_efield = efield.to_kernel();
7070
auto k_res = res.to_kernel();
7171
gt::launch<2>(
7272
{k_efield.shape(1), k_efield.shape(2)}, GT_LAMBDA(int iy, int iz) {
73-
if ((iy >= ly[1] && iy < ry[1]) && (iz >= ly[2] && iz < ry[2])) {
73+
if ((iy >= l[1] && iy < r[1]) && (iz >= l[2] && iz < r[2])) {
7474
k_efield(0, iy, iz, 1) =
7575
k_efield(0, iy, iz, 1) +
7676
fac[1] * (k_res(0, iy + 1, iz) - k_res(0, iy, iz));
7777
}
7878

79-
if ((iy >= lz[1] && iy < rz[1]) && (iz >= lz[2] && iz < rz[2])) {
79+
if ((iy >= l[1] && iy < r[1]) && (iz >= l[2] && iz < r[2])) {
8080
k_efield(0, iy, iz, 2) =
8181
k_efield(0, iy, iz, 2) +
8282
fac[2] * (k_res(0, iy, iz + 1) - k_res(0, iy, iz));
@@ -86,30 +86,30 @@ inline void cuda_marder_correct_yz(E1& efield, E2& res, Float3 fac, Int3 ly,
8686
}
8787

8888
template <typename E1, typename E2>
89-
inline void cuda_marder_correct_xyz(E1& efield, E2& res, Float3 fac, Int3 lx,
90-
Int3 rx, Int3 ly, Int3 ry, Int3 lz, Int3 rz)
89+
inline void cuda_marder_correct_xyz(E1& efield, E2& res, Float3 fac, Int3 l,
90+
Int3 r)
9191
{
9292
auto k_efield = efield.to_kernel();
9393
auto k_res = res.to_kernel();
9494
gt::launch<3>(
9595
{k_efield.shape(0), k_efield.shape(1), k_efield.shape(2)},
9696
GT_LAMBDA(int ix, int iy, int iz) {
97-
if ((ix >= lx[0] && ix < rx[0]) && (iy >= lx[1] && iy < rx[1]) &&
98-
(iz >= lx[2] && iz < rx[2])) {
97+
if ((ix >= l[0] && ix < r[0]) && (iy >= l[1] && iy < r[1]) &&
98+
(iz >= l[2] && iz < r[2])) {
9999
k_efield(ix, iy, iz, 0) =
100100
k_efield(ix, iy, iz, 0) +
101101
fac[0] * (k_res(ix, iy + 1, iz) - k_res(ix, iy, iz));
102102
}
103103

104-
if ((ix >= ly[0] && ix < ry[0]) && (iy >= ly[1] && iy < ry[1]) &&
105-
(iz >= ly[2] && iz < ry[2])) {
104+
if ((ix >= l[0] && ix < r[0]) && (iy >= l[1] && iy < r[1]) &&
105+
(iz >= l[2] && iz < r[2])) {
106106
k_efield(ix, iy, iz, 1) =
107107
k_efield(ix, iy, iz, 1) +
108108
fac[1] * (k_res(ix, iy + 1, iz) - k_res(ix, iy, iz));
109109
}
110110

111-
if ((ix >= lz[0] && ix < rz[0]) && (iy >= lz[1] && iy < rz[1]) &&
112-
(iz >= lz[2] && iz < rz[2])) {
111+
if ((ix >= l[0] && ix < r[0]) && (iy >= l[1] && iy < r[1]) &&
112+
(iz >= l[2] && iz < r[2])) {
113113
k_efield(ix, iy, iz, 2) =
114114
k_efield(ix, iy, iz, 2) +
115115
fac[2] * (k_res(ix, iy, iz + 1) - k_res(ix, iy, iz));
@@ -132,15 +132,15 @@ inline void correct(const Grid_t& grid, E1& efield, const Int3& efield_ib,
132132
assert(mf_ib == -grid.ibn);
133133
// OPT, do all patches in one kernel
134134
for (int p = 0; p < grid.n_patches(); p++) {
135-
Int3 lx, rx, ly, ry, lz, rz;
136-
detail::find_limits(grid, p, lx, rx, ly, ry, lz, rz);
135+
Int3 l = grid.ibn;
136+
Int3 r = grid.ibn + grid.ldims;
137137

138138
auto p_efield = efield.view(_all, _all, _all, _all, p);
139139
auto p_res = mf.view(_all, _all, _all, 0, p);
140140
if (grid.isInvar(0)) {
141-
cuda_marder_correct_yz(p_efield, p_res, fac, ly, ry, lz, rz);
141+
cuda_marder_correct_yz(p_efield, p_res, fac, l, r);
142142
} else {
143-
cuda_marder_correct_xyz(p_efield, p_res, fac, lx, rx, ly, ry, lz, rz);
143+
cuda_marder_correct_xyz(p_efield, p_res, fac, l, r);
144144
}
145145
}
146146
}
@@ -161,7 +161,7 @@ public:
161161
using Bnd = BND;
162162
using real_t = typename storage_type::value_type;
163163

164-
// FIXME: checkpointing won't properly restore state
164+
// FIXME: checkpointing won't properl restore state
165165

166166
MarderCommon(const Grid_t& grid, real_t diffusion, int loop, bool dump)
167167
: diffusion_{diffusion}, loop_{loop}, dump_{dump}

0 commit comments

Comments
 (0)