Skip to content

Commit deea22c

Browse files
authored
Merge pull request #384 from JamesMcClung/pr/polymorphic-marder
Polymorphic Marder
2 parents 2719fa7 + 124b0df commit deea22c

17 files changed

Lines changed: 106 additions & 1045 deletions
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
template <typename MfieldsState, typename Mparticles>
4+
struct GaussCorrectorBase
5+
{
6+
virtual void correct_gauss(MfieldsState& mflds, Mparticles& mprts) = 0;
7+
};

src/include/psc.hxx

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <setup_particles.hxx>
99

1010
#include "../libpsc/vpic/fields_item_vpic.hxx"
11+
#include "gauss_corrector_base.hxx"
1112
#include "diagnostic_base.hxx"
1213
#include "injector_base.hxx"
1314
#include "external_current_base.hxx"
@@ -111,13 +112,13 @@ struct Psc
111112
using Sort = typename PscConfig::Sort;
112113
using Collision = typename PscConfig::Collision;
113114
using Checks = typename PscConfig::Checks;
114-
using Marder = typename PscConfig::Marder;
115115
using PushParticles = typename PscConfig::PushParticles;
116116
using PushFields = typename PscConfig::PushFields;
117117
using Bnd = typename PscConfig::Bnd;
118118
using BndFields = typename PscConfig::BndFields;
119119
using BndParticles = typename PscConfig::BndParticles;
120120
using Dim = typename PscConfig::Dim;
121+
using GaussCorrectorBaseT = GaussCorrectorBase<MfieldsState, Mparticles>;
121122
using DiagnosticBaseT = DiagnosticBase<Mparticles, MfieldsState>;
122123
using ParticleDiagnosticBaseT = ParticleDiagnosticBase<Mparticles>;
123124
using InjectorBaseT = InjectorBase<Mparticles, MfieldsState>;
@@ -127,16 +128,14 @@ struct Psc
127128
// ctor
128129

129130
Psc(const PscParams& params, Grid_t& grid, MfieldsState& mflds,
130-
Mparticles& mprts, Balance& balance, Collision& collision, Checks& checks,
131-
Marder& marder)
131+
Mparticles& mprts, Balance& balance, Collision& collision, Checks& checks)
132132
: p_{params},
133133
grid_{&grid},
134134
mflds_{mflds},
135135
mprts_{mprts},
136136
balance_{balance},
137137
collision_{collision},
138138
checks_{checks},
139-
marder_{marder},
140139
bndp_{grid},
141140
checkpointing_{params.write_checkpoint_every_step}
142141
{
@@ -165,6 +164,13 @@ struct Psc
165164
// TODO: improve ownership model: we should own these objects (i.e., use
166165
// unique_ptr), but don't want to burden the user with C++ boilerplate.
167166

167+
void add_gauss_corrector(GaussCorrectorBaseT* corrector)
168+
{
169+
if (corrector) {
170+
gauss_correctors_.push_back(corrector);
171+
}
172+
}
173+
168174
void add_injector(InjectorBaseT* injector)
169175
{
170176
assert(injector);
@@ -429,6 +435,15 @@ struct Psc
429435
bndf.fill_ghosts_E(mflds_);
430436
bnd_.fill_ghosts(mflds_, EX, EX + 3);
431437
prof_stop(pr_bndf);
438+
439+
if (p_.marder_interval > 0 && timestep % p_.marder_interval == 0) {
440+
mpi_printf(comm, "***** Performing Marder correction...\n");
441+
prof_start(pr_marder);
442+
for (auto corrector : gauss_correctors_) {
443+
corrector->correct_gauss(mflds_, mprts_);
444+
}
445+
prof_stop(pr_marder);
446+
}
432447
// state is now: x^{n+3/2}, p^{n+1}, E^{n+3/2}, B^{n+1}
433448

434449
// === field propagation B^{n+1} -> B^{n+3/2}
@@ -451,16 +466,6 @@ struct Psc
451466
prof_stop(pr_checks);
452467
}
453468

454-
// E at t^{n+3/2}, particles at t^{n+3/2}
455-
// B at t^{n+3/2} (Note: that is not its natural time,
456-
// but div B should be == 0 at any time...)
457-
if (p_.marder_interval > 0 && timestep % p_.marder_interval == 0) {
458-
mpi_printf(comm, "***** Performing Marder correction...\n");
459-
prof_start(pr_marder);
460-
marder_(mflds_, mprts_);
461-
prof_stop(pr_marder);
462-
}
463-
464469
if (checks_.gauss.should_do_check(timestep)) {
465470
mpi_printf(comm, "***** Checking gauss...\n");
466471
prof_restart(pr_checks);
@@ -532,7 +537,7 @@ protected:
532537
Balance& balance_;
533538
Collision& collision_;
534539
Checks& checks_;
535-
Marder& marder_;
540+
std::vector<GaussCorrectorBaseT*> gauss_correctors_;
536541
std::vector<DiagnosticBaseT*> diagnostics_;
537542
std::vector<InjectorBaseT*> injectors_;
538543
std::vector<ExternalCurrentBaseT*> external_currents_;
@@ -559,14 +564,13 @@ protected:
559564
// makePscIntegrator
560565

561566
template <typename PscConfig, typename MfieldsState, typename Mparticles,
562-
typename Balance, typename Collision, typename Checks,
563-
typename Marder>
567+
typename Balance, typename Collision, typename Checks>
564568
Psc<PscConfig> makePscIntegrator(const PscParams& params, Grid_t& grid,
565569
MfieldsState& mflds, Mparticles& mprts,
566570
Balance& balance, Collision& collision,
567-
Checks& checks, Marder& marder)
571+
Checks& checks)
568572
{
569-
return {params, grid, mflds, mprts, balance, collision, checks, marder};
573+
return {params, grid, mflds, mprts, balance, collision, checks};
570574
}
571575

572576
// ======================================================================

src/libpsc/psc_push_fields/marder_impl.hxx

Lines changed: 43 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
#pragma once
33

4+
#include "gauss_corrector_base.hxx"
45
#include "fields.hxx"
56
#include "writer_mrc.hxx"
67
#include "mpi_dtype_traits.hxx"
@@ -66,45 +67,35 @@ inline void correct(const Grid_t& grid, E1& efield, const Int3& efield_ib,
6667
assert(efield_ib == -grid.ibn);
6768
assert(mf_ib == -grid.ibn);
6869

69-
Real3 fac = {.5f * grid.dt * diffusion / grid.domain.dx[0],
70-
.5f * grid.dt * diffusion / grid.domain.dx[1],
71-
.5f * grid.dt * diffusion / grid.domain.dx[2]};
70+
Real3 fac = .5f * real_t(grid.dt) * diffusion * Real3(grid.domain.dx_inv);
7271

7372
for (int p = 0; p < grid.n_patches(); p++) {
7473
Int3 lx, rx, ly, ry, lz, rz;
7574
detail::find_limits(grid, p, lx, rx, ly, ry, lz, rz);
7675

77-
if (!grid.isInvar(0)) {
78-
Int3 l = lx, r = rx;
79-
auto ex = efield.view(_all, _all, _all, 0, p);
80-
auto res = mf.view(_all, _all, _all, 0, p);
81-
ex.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2])) =
82-
ex.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2])) +
83-
(res.view(_s(l[0] + 1, r[0] + 1), _s(l[1], r[1]), _s(l[2], r[2])) -
84-
res.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2]))) *
85-
fac[0];
86-
}
76+
Int3 ls[3] = {lx, ly, lz};
77+
Int3 rs[3] = {rx, ry, rz};
8778

88-
{
89-
Int3 l = ly, r = ry;
90-
auto ey = efield.view(_all, _all, _all, 1, p);
91-
auto res = mf.view(_all, _all, _all, 0, p);
92-
ey.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2])) =
93-
ey.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2])) +
94-
(res.view(_s(l[0], r[0]), _s(l[1] + 1, r[1] + 1), _s(l[2], r[2])) -
95-
res.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2]))) *
96-
fac[1];
97-
}
79+
auto res = mf.view(_all, _all, _all, 0, p);
80+
for (int d = 0; d < 3; d++) {
81+
if (grid.isInvar(d)) {
82+
continue;
83+
}
84+
85+
Int3 l = ls[d];
86+
Int3 r = rs[d];
87+
auto e_comp = efield.view(_all, _all, _all, d, p);
88+
89+
gt::gslice s1x = _s(l[0], r[0]);
90+
gt::gslice s1y = _s(l[1], r[1]);
91+
gt::gslice s1z = _s(l[2], r[2]);
92+
93+
gt::gslice s2[3] = {_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2])};
94+
s2[d] = _s(l[d] + 1, r[d] + 1);
9895

99-
{
100-
Int3 l = lz, r = rz;
101-
auto ez = efield.view(_all, _all, _all, 2, p);
102-
auto res = mf.view(_all, _all, _all, 0, p);
103-
ez.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2])) =
104-
ez.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2])) +
105-
(res.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2] + 1, r[2] + 1)) -
106-
res.view(_s(l[0], r[0]), _s(l[1], r[1]), _s(l[2], r[2]))) *
107-
fac[2];
96+
e_comp.view(s1x, s1y, s1z) =
97+
e_comp.view(s1x, s1y, s1z) +
98+
(res.view(s2[0], s2[1], s2[2]) - res.view(s1x, s1y, s1z)) * fac[d];
10899
}
109100
}
110101
}
@@ -198,12 +189,14 @@ inline void correct(const Grid_t& grid, E1& efield, const Int3& efield_ib,
198189
} // namespace marder
199190
} // namespace psc
200191

201-
template <typename S, typename D, typename ITEM_RHO, typename BND>
202-
class MarderCommon
192+
template <typename MFIELDS_STATE, typename MPARTICLES, typename ITEM_RHO,
193+
typename BND>
194+
class MarderCommon : public GaussCorrectorBase<MFIELDS_STATE, MPARTICLES>
203195
{
204196
public:
205-
using storage_type = S;
206-
using dim_t = D;
197+
using MfieldsState = MFIELDS_STATE;
198+
using Mparticles = MPARTICLES;
199+
using storage_type = typename MfieldsState::Storage;
207200
using Item_rho_t = ITEM_RHO;
208201
using Bnd = BND;
209202
using real_t = typename storage_type::value_type;
@@ -241,23 +234,16 @@ public:
241234
}
242235
}
243236

244-
// ----------------------------------------------------------------------
245-
// operator()
246-
//
247-
// Do the modified marder correction (See eq.(5, 7, 9, 10) in Mardahl and
248-
// Verboncoeur, CPC, 1997)
249-
250-
template <typename Mparticles>
251-
void operator()(const Grid_t& grid, storage_type& mflds, const Int3& mflds_ib,
252-
Mparticles& mprts)
237+
void correct_gauss(MfieldsState& mflds, Mparticles& mprts) override
253238
{
254-
auto efield = mflds.view(_all, _all, _all, _s(EX, EX + 3), _all);
255-
auto efield_ib = mflds_ib;
239+
const Grid_t& grid = mflds.grid();
240+
auto efield = mflds.storage().view(_all, _all, _all, _s(EX, EX + 3), _all);
241+
Int3 efield_ib = mflds.ib();
256242

257243
double inv_sum = 0.;
258244
for (int d = 0; d < 3; d++) {
259245
if (!grid.isInvar(d)) {
260-
inv_sum += 1. / sqr(grid.domain.dx[d]);
246+
inv_sum += sqr(grid.domain.dx_inv[d]);
261247
}
262248
}
263249
double diffusion_max = 1. / 2. / (.5 * grid.dt) / inv_sum;
@@ -270,7 +256,7 @@ public:
270256
// and expected to be filled when done, so we're playing it safe for the
271257
// time being.
272258
for (int i = 0; i < loop_; i++) {
273-
bnd_.fill_ghosts(grid, mflds, mflds_ib, EX, EX + 3);
259+
bnd_.fill_ghosts(mflds, EX, EX + 3);
274260
auto dive = psc::mflds::interior(grid, psc::item::div_nc(grid, efield));
275261

276262
Int3 res_ib = -grid.ibn;
@@ -282,21 +268,8 @@ public:
282268

283269
psc::marder::correct(grid, efield, efield_ib, res, res_ib, diffusion);
284270
}
285-
bnd_.fill_ghosts(grid, mflds, mflds_ib, EX, EX + 3);
286-
}
287-
288-
template <typename MfieldsState, typename Mparticles>
289-
void operator()(MfieldsState& mflds, Mparticles& mprts)
290-
{
291-
static int pr;
292-
if (!pr) {
293-
pr = prof_register("marder", 1., 0, 0);
294-
}
295-
296-
prof_start(pr);
297-
(*this)(mprts.grid(), mflds.storage(), mflds.ib(), mprts);
298-
prof_stop(pr);
299-
}
271+
bnd_.fill_ghosts(mflds, EX, EX + 3);
272+
};
300273

301274
// private:
302275
real_t diffusion_; //< diffusion coefficient for Marder correction
@@ -306,16 +279,18 @@ public:
306279
WriterMRC io_; //< for debug dumping
307280
};
308281

309-
template <typename S, typename D>
310-
using Marder_ = MarderCommon<S, D, Moment_rho_1st_nc<S, D>, Bnd_>;
282+
template <typename MFIELDS_STATE, typename MPARTICLES, typename D>
283+
using Marder_ =
284+
MarderCommon<MFIELDS_STATE, MPARTICLES,
285+
Moment_rho_1st_nc<typename MFIELDS_STATE::Storage, D>, Bnd_>;
311286

312287
#ifdef USE_CUDA
313288

314289
#include "psc_particles_single.h"
315290
#include "mparticles_cuda.hxx"
316291
#include "fields_item_moments_1st_cuda.hxx"
317292

318-
template <typename D>
319-
using MarderCuda = MarderCommon<MfieldsStateCuda::Storage, D,
293+
template <typename Mparticles, typename D>
294+
using MarderCuda = MarderCommon<MfieldsStateCuda, Mparticles,
320295
Moment_rho_1st_nc_cuda<D>, BndCuda3>;
321296
#endif

src/libpsc/tests/test_boundary_injector.cxx

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ using Mparticles = PscConfig::Mparticles;
3838
using Balance = PscConfig::Balance;
3939
using Collision = PscConfig::Collision;
4040
using Checks = PscConfig::Checks;
41-
using Marder = PscConfig::Marder;
4241
using OutputParticles = PscConfig::OutputParticles;
4342

4443
Grid_t* setupGrid()
@@ -133,10 +132,9 @@ TEST(BoundaryInjectorTest, Integration1Particle)
133132

134133
Balance balance{.1};
135134
Collision collision{grid, 0, 0.1};
136-
Marder marder(grid, 0.9, 3, false);
137135

138136
auto psc = makePscIntegrator<PscConfig>(psc_params, grid, mflds, mprts,
139-
balance, collision, checks, marder);
137+
balance, collision, checks);
140138

141139
psc.add_injector(
142140
new BoundaryInjector<ParticleGenerator, typename PscConfig::PushParticles>(
@@ -191,10 +189,9 @@ TEST(BoundaryInjectorTest, IntegrationManyParticles)
191189

192190
Balance balance{.1};
193191
Collision collision{grid, 0, 0.1};
194-
Marder marder(grid, 0.9, 3, false);
195192

196193
auto psc = makePscIntegrator<PscConfig>(psc_params, grid, mflds, mprts,
197-
balance, collision, checks, marder);
194+
balance, collision, checks);
198195

199196
psc.add_injector(
200197
new BoundaryInjector<ParticleGenerator, PscConfig::PushParticles>(
@@ -249,7 +246,6 @@ TEST(BoundaryInjectorTest, IntegrationManySpecies)
249246

250247
Balance balance{.1};
251248
Collision collision{grid, 0, 0.1};
252-
Marder marder(grid, 0.9, 3, false);
253249

254250
auto inject_electrons =
255251
BoundaryInjector<ParticleGenerator, PscConfig::PushParticles>{
@@ -259,7 +255,7 @@ TEST(BoundaryInjectorTest, IntegrationManySpecies)
259255
ParticleGenerator(-1, 1), grid};
260256

261257
auto psc = makePscIntegrator<PscConfig>(psc_params, grid, mflds, mprts,
262-
balance, collision, checks, marder);
258+
balance, collision, checks);
263259

264260
psc.add_injector(&inject_ions);
265261
psc.add_injector(&inject_electrons);

src/libpsc/tests/test_open_bcs_integration.cxx

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ using Mparticles = PscConfig::Mparticles;
2626
using Balance = PscConfig::Balance;
2727
using Collision = PscConfig::Collision;
2828
using Checks = PscConfig::Checks;
29-
using Marder = PscConfig::Marder;
3029
using OutputParticles = PscConfig::OutputParticles;
3130

3231
// ======================================================================
@@ -104,10 +103,9 @@ TEST(OpenBcsTest, IntegrationY)
104103

105104
Balance balance{.1};
106105
Collision collision{grid, 0, 0.1};
107-
Marder marder(grid, 0.9, 3, false);
108106

109107
auto psc = makePscIntegrator<PscConfig>(psc_params, *grid_ptr, mflds, mprts,
110-
balance, collision, checks, marder);
108+
balance, collision, checks);
111109

112110
// ----------------------------------------------------------------------
113111
// set up initial conditions

src/libpsc/tests/test_reflective_bcs.cxx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ using Mparticles = PscConfig::Mparticles;
2626
using Balance = PscConfig::Balance;
2727
using Collision = PscConfig::Collision;
2828
using Checks = PscConfig::Checks;
29-
using Marder = PscConfig::Marder;
3029
using OutputParticles = PscConfig::OutputParticles;
3130

3231
// ======================================================================

0 commit comments

Comments
 (0)