diff --git a/Exec/GNUmakefile b/Exec/GNUmakefile index 0379b6f..e8cfea4 100644 --- a/Exec/GNUmakefile +++ b/Exec/GNUmakefile @@ -9,10 +9,51 @@ USE_HIP = FALSE COMP = gnu DIM = 3 USE_FFT = TRUE - +USE_ML = FALSE +TINY_PROFILE = FALSE +PROFILE = FALSE USE_SUNDIALS = FALSE SUNDIALS_HOME ?= ../../sundials/instdir +ifeq ($(USE_ML),TRUE) + + CPPFLAGS += -DAMREX_USE_ML + + # Define a macro for the C++ preprocessor + DEFINES += -DML_ENABLE -D_GLIBCXX_USE_CXX11_ABI=1 + + # Pytorch root directory selection + ifeq ($(USE_CUDA),TRUE) + PYTORCH_ROOT := ../../libtorch_cuda + else + PYTORCH_ROOT := ../../libtorch_cpu + endif + + TORCH_LIBPATH = $(PYTORCH_ROOT)/lib + + # Library definitions + ifeq ($(USE_CUDA),TRUE) + # Note: Modern LibTorch often requires both torch_cuda and torch_cpu + TORCH_LIBS = -ltorch -ltorch_cuda -ltorch_cpu -lc10 -lc10_cuda -lcuda + else + TORCH_LIBS = -ltorch -ltorch_cpu -lc10 + endif + + # Header search paths + INCLUDE_LOCATIONS += $(PYTORCH_ROOT)/include \ + $(PYTORCH_ROOT)/include/torch/csrc/api/include + + # Library search paths + LIBRARY_LOCATIONS += $(TORCH_LIBPATH) + + # Linker flags (rpath ensures the .so files are found at runtime) + ifeq ($(USE_CUDA),TRUE) + LDFLAGS += -Xlinker "--no-as-needed,-rpath,$(TORCH_LIBPATH)" $(TORCH_LIBS) + else + LDFLAGS += -Wl,--no-as-needed,-rpath=$(TORCH_LIBPATH) $(TORCH_LIBS) + endif +endif + include $(AMREX_HOME)/Tools/GNUMake/Make.defs include ../Source/Make.package diff --git a/Exec/README b/Exec/README deleted file mode 100644 index ce5b9b2..0000000 --- a/Exec/README +++ /dev/null @@ -1,72 +0,0 @@ -------------------------------------------- - -inputs_compare_std4 - -Standard problem 4. -Entire domain is magnetic. 
-Matches comapre_std.m -==================== Initial Setup ==================== - demag_coupling = 1 - M_normalization = 1 - exchange_coupling = 1 - DMI_coupling = 0 - anisotropy_coupling = 0 - TimeIntegratorOption = 1 - -------------------------------------------- - -inputs_compare_subdomain - -Test comparison problem of demag solver; to compare initial H_demag to compare_subdomain.m -A cubic block of material within the domain is magnetic. -==================== Initial Setup ==================== - demag_coupling = 1 - M_normalization = 1 - exchange_coupling = 1 - DMI_coupling = 0 - anisotropy_coupling = 0 - TimeIntegratorOption = 1 - -------------------------------------------- - -inputs_exchange - -A block of magnetic material in the center of the domain initialized -so that My = Ms. Only exchange physics is enabled, so the system does -not change due to the interface boundary conditions. -==================== Initial Setup ==================== - demag_coupling = 0 - M_normalization = 1 - exchange_coupling = 1 - DMI_coupling = 0 - anisotropy_coupling = 0 - TimeIntegratorOption = 2 - -------------------------------------------- - -inputs_PSSW - -A block of magnetic material in the center of the domain initialized so that My = Ms. -==================== Initial Setup ==================== - demag_coupling = 1 - M_normalization = 1 - exchange_coupling = 1 - DMI_coupling = 0 - anisotropy_coupling = 0 - TimeIntegratorOption = 1 - -------------------------------------------- - -inputs_restart - -This is the same as inputs_PSSW but with lower resolution and restart -hooks more easily enabled to test this capability with these physics. 
-==================== Initial Setup ==================== - demag_coupling = 1 - M_normalization = 1 - exchange_coupling = 1 - DMI_coupling = 0 - anisotropy_coupling = 0 - TimeIntegratorOption = 1 - -------------------------------------------- diff --git a/Exec/README_md.pytorch b/Exec/README_md.pytorch new file mode 100644 index 0000000..e84f4c2 --- /dev/null +++ b/Exec/README_md.pytorch @@ -0,0 +1,36 @@ +# PyTorch (libtorch) Download and Setup + +This guide downloads the libtorch CUDA 11.8 C++ distribution and unzips it in the same directory that contains `MagneX`, then renames the extracted folder to `libtorch_cuda`. + +## Steps + +1. Change to the parent directory of `MagneX`: + +```bash +cd <parent-of-MagneX> +``` + +2. Download the libtorch archive with `wget`: + +```bash +wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.7.1%2Bcu118.zip +``` + +3. Unzip the archive in the same directory: + +```bash +unzip libtorch-cxx11-abi-shared-with-deps-2.7.1+cu118.zip +``` + +4. Rename the extracted folder to `libtorch_cuda`: + +```bash +mv libtorch libtorch_cuda +``` + +## Result + +After the steps above, you should have: + +- `<parent>/MagneX` +- `<parent>/libtorch_cuda` diff --git a/Exec/README_md.sundials b/Exec/README_md.sundials new file mode 100644 index 0000000..04f7cdd --- /dev/null +++ b/Exec/README_md.sundials @@ -0,0 +1,6 @@ +# SUNDIALS Setup + +Refer to https://amrex-codes.github.io/amrex/docs_html/TimeIntegration_Chapter.html +for SUNDIALS installation, build, and usage instructions. + +Make sure that `SUNDIALS_HOME` in the `GNUmakefile` points to the installation directory. diff --git a/Exec/README_sundials b/Exec/README_sundials deleted file mode 100644 index 3b6830a..0000000 --- a/Exec/README_sundials +++ /dev/null @@ -1,4 +0,0 @@ -Refer to https://amrex-codes.github.io/amrex/docs_html/TimeIntegration_Chapter.html -for SNUDIALS installation, build, and usage instructions. 
- -Make sure that SUNDIALS_HOME in the GNUmakefile points to the installation directory. diff --git a/Source/Demag_ml.cpp b/Source/Demag_ml.cpp new file mode 100644 index 0000000..c6f5cc5 --- /dev/null +++ b/Source/Demag_ml.cpp @@ -0,0 +1,325 @@ +#ifdef AMREX_USE_ML + +// MagneX_ML_Infer_Dynamic.cpp +#include "MagneX.H" +#include +#include +// #include +#include +#include +#include +using namespace amrex; + +// ------------------------------------------------------------ +// Helper: Move all parameters + buffers of a TorchScript module +// to the same device (e.g., cuda:3) to avoid device mismatch. +// ------------------------------------------------------------ +void MoveModuleToDevice(torch::jit::script::Module& m, + const torch::Device& device) +{ + m.to(device); +// for (auto& p : m.named_parameters(true)) { +// p.value().set_data(p.value().to(device)); +// } +// for (auto& b : m.named_buffers(true)) { +// b.value().set_data(b.value().to(device)); +// } +} + +// ------------------------------------------------------------ +// Helper: read expected_spatial = [nx, ny, nz] from normalizer. +// normalizer must have exported method get_expected_spatial(). +// ------------------------------------------------------------ +amrex::IntVect GetExpectedSpatial(torch::jit::script::Module& x_norm_module) +{ + at::Tensor t = x_norm_module.get_method("get_expected_spatial")({}).toTensor(); + t = t.to(torch::kCPU).to(torch::kLong).contiguous(); + + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(t.numel() == 3, + "get_expected_spatial() must return a tensor of 3 elements: [nx, ny, nz]"); + + const int nx = static_cast(t[0].item()); + const int ny = static_cast(t[1].item()); + const int nz = static_cast(t[2].item()); + + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(nx > 0 && ny > 0 && nz > 0, + "expected_spatial invalid (must be >0)"); + + return amrex::IntVect(nx, ny, nz); +} + +// ------------------------------------------------------------ +// Bundle: holds model + normalizers + expected shape meta. 
+// ------------------------------------------------------------ +struct MLBundle +{ + torch::jit::script::Module x_norm; + torch::jit::script::Module y_norm; + torch::jit::script::Module model; + + amrex::IntVect expected_spatial{0}; // (nx,ny,nz) + int device_id = 0; + + bool initialized = false; +}; + +// ------------------------------------------------------------ +// Load bundle from .pt files; move everything to cuda:device_id; +// read expected_spatial from x_norm. +// ------------------------------------------------------------ +static inline MLBundle LoadMLBundle(const std::string& x_norm_pt, + const std::string& y_norm_pt, + const std::string& model_pt, + int device_id) +{ + MLBundle b; + b.device_id = device_id; + + // Load modules (default loads to CPU) + b.x_norm = torch::jit::load(x_norm_pt); + b.y_norm = torch::jit::load(y_norm_pt); + b.model = torch::jit::load(model_pt); + + // Move to target GPU + torch::Device dev(torch::kCUDA, device_id); + MoveModuleToDevice(b.x_norm, dev); + MoveModuleToDevice(b.y_norm, dev); + MoveModuleToDevice(b.model, dev); + + // Read expected spatial shape from normalizer + b.expected_spatial = GetExpectedSpatial(b.x_norm); + + b.initialized = true; + + amrex::Print() << "[MLBundle] Loaded modules on cuda:" << device_id + << " expected_spatial = (" + << b.expected_spatial[0] << ", " + << b.expected_spatial[1] << ", " + << b.expected_spatial[2] << ")\n"; + + return b; +} + +// ------------------------------------------------------------ +// Pack Mfield MultiFab (Mx,My,Mz) into Torch tensor [1,3,nx,ny,nz] +// Dynamically sized using bx.size() and expected_spatial. 
+// ------------------------------------------------------------ +at::Tensor PackMfieldToTensorDynamic( + const Array& Mfield, + const MFIter& mfi, + const Box& bx, + const amrex::IntVect& expected_spatial, + int device_id) +{ +#if AMREX_SPACEDIM != 3 + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(false, "PackMfieldToTensorDynamic expects 3D"); +#endif + + const auto& Mx = Mfield[0].const_array(mfi); + const auto& My = Mfield[1].const_array(mfi); + const auto& Mz = Mfield[2].const_array(mfi); + + const IntVect bx_lo = bx.smallEnd(); + const IntVect nbox = bx.size(); // (nx,ny,nz) + + const int nx = nbox[0]; + const int ny = nbox[1]; + const int nz = nbox[2]; + const int ncell = nx * ny * nz; + + AMREX_ALWAYS_ASSERT_WITH_MESSAGE( + nx == expected_spatial[0] && ny == expected_spatial[1] && nz == expected_spatial[2], + "PackMfieldToTensorDynamic: bx size != expected_spatial from normalizer" + ); + + amrex::Gpu::ManagedVector aux_Mx(ncell), aux_My(ncell), aux_Mz(ncell); + Real* AMREX_RESTRICT auxPtr_Mx = aux_Mx.dataPtr(); + Real* AMREX_RESTRICT auxPtr_My = aux_My.dataPtr(); + Real* AMREX_RESTRICT auxPtr_Mz = aux_Mz.dataPtr(); + + amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept { + const int ii = i - bx_lo[0]; + const int jj = j - bx_lo[1]; + const int kk = k - bx_lo[2]; + + // flatten index consistent with (nx,ny,nz) reshape below + const int index = kk + jj * nz + ii * nz * ny; + + auxPtr_Mx[index] = Mx(i, j, k); + auxPtr_My[index] = My(i, j, k); + auxPtr_Mz[index] = Mz(i, j, k); + }); + + amrex::Gpu::streamSynchronize(); + + // Wrap managed memory; then we will immediately copy to CUDA tensor + at::Tensor tMx = torch::from_blob(auxPtr_Mx, {ncell}, torch::kFloat64); + at::Tensor tMy = torch::from_blob(auxPtr_My, {ncell}, torch::kFloat64); + at::Tensor tMz = torch::from_blob(auxPtr_Mz, {ncell}, torch::kFloat64); + + at::Tensor rMx = tMx.reshape({nx, ny, nz}); + at::Tensor rMy = tMy.reshape({nx, ny, nz}); + at::Tensor rMz = tMz.reshape({nx, ny, nz}); + + 
at::Tensor M = torch::stack({rMx, rMy, rMz}, 0); // [3,nx,ny,nz] + + torch::Device dev(torch::kCUDA, device_id); + M = M.to(dev).to(torch::kFloat32).unsqueeze(0); // [1,3,nx,ny,nz] + + return M; +} + +// ------------------------------------------------------------ +// Normalize, Forward, Denormalize +// ------------------------------------------------------------ +at::Tensor NormalizeInput(const at::Tensor& M_cuda_f32, + torch::jit::script::Module& x_norm_module) +{ + BL_PROFILE("NormalizeInput"); + return x_norm_module.get_method("encode")({M_cuda_f32}).toTensor(); +} + +// at::Tensor MLForwardOnly(const at::Tensor& norm_tensor, +// torch::jit::script::Module& ml_module) +// { +// BL_PROFILE("MLForwardOnly"); +// auto out = ml_module.forward({norm_tensor}).toTensor(); + +// #ifdef AMREX_USE_CUDA +// // 把 forward 的 GPU 时间“结算”在这里(用于 profiling 验证) +// amrex::Gpu::streamSynchronize(); +// #endif + +// return out; +// } +at::Tensor MLForwardOnly(const at::Tensor& norm_tensor, + torch::jit::script::Module& ml_module) +{ + BL_PROFILE("MLForwardOnly"); + +#ifdef AMREX_USE_CUDA + at::cuda::CUDAEvent start(/*enable_timing=*/true); + at::cuda::CUDAEvent stop (/*enable_timing=*/true); + + auto stream = at::cuda::getDefaultCUDAStream(); + + stream.synchronize(); + + start.record(stream); + + auto out = ml_module.forward({norm_tensor}).toTensor(); + + stop.record(stream); + stop.synchronize(); + + float ms = start.elapsed_time(stop); + amrex::Print() << "Forward-only time (ms) = " << ms << "\n"; + + return out; +#else + return ml_module.forward({norm_tensor}).toTensor(); +#endif +} + + +at::Tensor DenormalizeOutput(const at::Tensor& y_tensor, + torch::jit::script::Module& y_norm_module) +{ + BL_PROFILE("DenormalizeOutput"); + // Keep device, convert to float64 for AMReX Real=double path + return y_norm_module.get_method("decode")({y_tensor}).toTensor().to(torch::kFloat64); +} + +// ------------------------------------------------------------ +// Unpack Torch tensor 
[1,3,nx,ny,nz] into H_demagfield MultiFabs. +// ------------------------------------------------------------ +void UnpackTensorToHfieldDynamic( + const at::Tensor& denorm_torch_f64, // [1,3,nx,ny,nz] + Array& H_demagfield, + const MFIter& mfi, + const Box& bx, + const amrex::IntVect& expected_spatial) +{ +#if AMREX_SPACEDIM != 3 + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(false, "UnpackTensorToHfieldDynamic expects 3D"); +#endif + + auto Hx_demag = H_demagfield[0].array(mfi); + auto Hy_demag = H_demagfield[1].array(mfi); + auto Hz_demag = H_demagfield[2].array(mfi); + + const IntVect bx_lo = bx.smallEnd(); + const IntVect nbox = bx.size(); + + const int nx = nbox[0]; + const int ny = nbox[1]; + const int nz = nbox[2]; + + AMREX_ALWAYS_ASSERT_WITH_MESSAGE( + nx == expected_spatial[0] && ny == expected_spatial[1] && nz == expected_spatial[2], + "UnpackTensorToHfieldDynamic: bx size != expected_spatial from normalizer" + ); + + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(denorm_torch_f64.dim() == 5, "denorm must be [1,3,nx,ny,nz]"); + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(denorm_torch_f64.size(0) == 1, "batch must be 1"); + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(denorm_torch_f64.size(1) == 3, "channel must be 3"); + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(denorm_torch_f64.size(2) == nx && + denorm_torch_f64.size(3) == ny && + denorm_torch_f64.size(4) == nz, + "denorm tensor spatial mismatch"); + + // Flatten each component to 1D for packed_accessor + at::Tensor Hx = denorm_torch_f64.select(0, 0).select(0, 0).contiguous().view({-1}); + at::Tensor Hy = denorm_torch_f64.select(0, 0).select(0, 1).contiguous().view({-1}); + at::Tensor Hz = denorm_torch_f64.select(0, 0).select(0, 2).contiguous().view({-1}); + +#ifdef AMREX_USE_CUDA + auto Hx_acc = Hx.packed_accessor64(); + auto Hy_acc = Hy.packed_accessor64(); + auto Hz_acc = Hz.packed_accessor64(); +#endif + + amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept { + const int ii = i - bx_lo[0]; + const int jj = j - bx_lo[1]; + const int kk 
= k - bx_lo[2]; + + const int index = kk + jj * nz + ii * nz * ny; + + Hx_demag(i, j, k) = Hx_acc[index]; + Hy_demag(i, j, k) = Hy_acc[index]; + Hz_demag(i, j, k) = Hz_acc[index]; + }); + + amrex::Gpu::streamSynchronize(); +} + +// ------------------------------------------------------------ +// One-call wrapper: pack -> encode -> forward -> decode -> unpack +// ------------------------------------------------------------ +void RunMLDemagOnBox( + MLBundle& b, + const Array& Mfield, + Array& H_demagfield, + const MFIter& mfi, + const Box& bx) +{ + AMREX_ALWAYS_ASSERT_WITH_MESSAGE(b.initialized, "MLBundle not initialized"); + + // 1) Pack raw M to [1,3,nx,ny,nz] on cuda:device_id + at::Tensor M = PackMfieldToTensorDynamic(Mfield, mfi, bx, b.expected_spatial, b.device_id); + + // 2) Normalize (encode) on same GPU + at::Tensor norm = NormalizeInput(M, b.x_norm); + + // 3) Forward + at::Tensor pred = MLForwardOnly(norm, b.model); + + // 4) Denormalize + at::Tensor denorm = DenormalizeOutput(pred, b.y_norm); + + // 5) Unpack back to MultiFab + UnpackTensorToHfieldDynamic(denorm, H_demagfield, mfi, bx, b.expected_spatial); +} + +#endif diff --git a/Source/MagneX.H b/Source/MagneX.H index c09c3f6..7c441b2 100644 --- a/Source/MagneX.H +++ b/Source/MagneX.H @@ -1,3 +1,7 @@ +#ifdef AMREX_USE_ML +#include +#endif + #include #include "MagneX_namespace.H" @@ -175,3 +179,40 @@ void WritePlotfile(MultiFab& Ms, const Geometry& geom, const Real& time, const int& plt_step); + +#ifdef AMREX_USE_ML + +at::Tensor PackMfieldToTensorDynamic( + const amrex::Array& Mfield, + const amrex::MFIter& mfi, + const amrex::Box& bx, + const amrex::IntVect& expected_spatial, + int device_id); + +at::Tensor NormalizeInput( + const at::Tensor& M_cuda_f32, + torch::jit::script::Module& x_norm_module); + +at::Tensor MLForwardOnly( + const at::Tensor& norm_tensor, + torch::jit::script::Module& ml_module); + +at::Tensor DenormalizeOutput( + const at::Tensor& y_tensor, + torch::jit::script::Module& 
y_norm_module); + +void UnpackTensorToHfieldDynamic( + const at::Tensor& denorm_torch_f64, + amrex::Array& H_demagfield, + const amrex::MFIter& mfi, + const amrex::Box& bx, + const amrex::IntVect& expected_spatial); + +// Move all parameters + buffers of a TorchScript module to a device +void MoveModuleToDevice(torch::jit::script::Module& m, + const torch::Device& device); + +// Read expected_spatial = [nx, ny, nz] from normalizer module +amrex::IntVect GetExpectedSpatial(torch::jit::script::Module& x_norm_module); + +#endif diff --git a/Source/MagneX.cpp b/Source/MagneX.cpp index ba35f1c..515e376 100644 --- a/Source/MagneX.cpp +++ b/Source/MagneX.cpp @@ -132,6 +132,10 @@ AMREX_GPU_MANAGED int MagneX::demag_coupling; // 0 = FFTW (single-MPI), 1 = heFFTe (distributed) AMREX_GPU_MANAGED int MagneX::FFT_solver; +// ML flag +int MagneX::ml_enable; + + void InitializeMagneXNamespace() { BL_PROFILE_VAR("InitializeMagneXNamespace()",InitializeMagneXNameSpace); @@ -228,6 +232,14 @@ void InitializeMagneXNamespace() { restart = -1; pp.query("restart",restart); + ml_enable = 0; + pp.query("ml_enable",ml_enable); +#ifndef AMREX_USE_ML + if (ml_enable == 1) { + amrex::Abort("ml_enable=1 requires USE_ML=TRUE"); + } +#endif + diag_type = -1; pp.query("diag_type",diag_type); diff --git a/Source/MagneX_namespace.H b/Source/MagneX_namespace.H index 0660dce..18bfe82 100644 --- a/Source/MagneX_namespace.H +++ b/Source/MagneX_namespace.H @@ -52,6 +52,8 @@ namespace MagneX { extern int diag_type; + extern int ml_enable; + extern int timedependent_Hbias; extern int timedependent_alpha; diff --git a/Source/Make.package b/Source/Make.package index f99179f..24f7df0 100644 --- a/Source/Make.package +++ b/Source/Make.package @@ -1,6 +1,7 @@ CEXE_sources += Checkpoint.cpp CEXE_sources += ComputeLLGRHS.cpp CEXE_sources += Demagnetization.cpp +CEXE_sources += Demag_ml.cpp CEXE_sources += Diagnostics.cpp CEXE_sources += EffectiveAnisotropyField.cpp CEXE_sources += EffectiveDMIField.cpp diff --git 
a/Source/main.cpp b/Source/main.cpp index 3c91de2..d39ec68 100644 --- a/Source/main.cpp +++ b/Source/main.cpp @@ -1,15 +1,21 @@ #include "MagneX.H" #include "Demagnetization.H" - #include #include - +#include +#include #ifdef AMREX_USE_SUNDIALS #include #endif #include +#ifdef AMREX_USE_ML +#include +#include // for at::cuda::setDevice +#include +#endif + using namespace amrex; using namespace MagneX; @@ -57,6 +63,11 @@ void main_main () Array LLG_RHS; Array LLG_RHS_pre; Array LLG_RHS_avg; +#ifdef AMREX_USE_ML + torch::jit::script::Module ml_module; + torch::jit::script::Module x_norm_module; + torch::jit::script::Module y_norm_module; +#endif // Declare variables for hysteresis Real normalized_Mx; @@ -83,6 +94,10 @@ void main_main () // Count how many times we have incremented Hbias int increment_count = 0; + // ML related variables (Declared here so they are visible to the whole function) + amrex::IntVect expected_spatial(0,0,0); + int device_id = -1; + BoxArray ba; DistributionMapping dm; @@ -100,6 +115,66 @@ void main_main () } + // ********************************** + // // LOAD PYTORCH MODEL + if (ml_enable == 1) { +#ifdef AMREX_USE_ML + BL_PROFILE_VAR("LoadPytorch",LoadPytorch); + + // Load pytorch module via torch script + + + std::string ml_model_name; + std::string x_normalizer_name; + std::string y_normalizer_name; + + ParmParse pp_ml; + pp_ml.query("ml_model_name", ml_model_name); + pp_ml.query("x_normalizer_name", x_normalizer_name); + pp_ml.query("y_normalizer_name", y_normalizer_name); + + amrex::Print()<<"\n"< tensor + at::Tensor M_cuda_f32 = PackMfieldToTensorDynamic( + Mfield_old, mfi, bx, expected_spatial, device_id + ); + + at::Tensor norm = NormalizeInput(M_cuda_f32, x_norm_module); + at::Tensor pred = MLForwardOnly(norm, ml_module); + at::Tensor denorm_f64 = DenormalizeOutput(pred, y_norm_module); + + // unpack: tensor -> MultiFab + UnpackTensorToHfieldDynamic(denorm_f64, H_demagfield, mfi, bx, expected_spatial); + } +#endif + } else { + 
demag_solver.CalculateH_demag(Mfield_old, H_demagfield); + } } if (exchange_coupling == 1) { @@ -496,7 +593,30 @@ void main_main () // Poisson solve and H_demag computation with Mfield if (demag_coupling == 1) { - demag_solver.CalculateH_demag(Mfield, H_demagfield); + if (ml_enable == 1) { +#ifdef AMREX_USE_ML + for (amrex::MFIter mfi(Mfield_old[0], amrex::TilingIfNotGPU()); + mfi.isValid(); ++mfi) + { + const amrex::Box& bx = mfi.validbox(); + + // pack: MultiFab -> tensor + at::Tensor M_cuda_f32 = PackMfieldToTensorDynamic( + Mfield_old, mfi, bx, expected_spatial, device_id + ); + + at::Tensor norm = NormalizeInput(M_cuda_f32, x_norm_module); + at::Tensor pred = MLForwardOnly(norm, ml_module); + at::Tensor denorm_f64 = DenormalizeOutput(pred, y_norm_module); + + // unpack: tensor -> MultiFab + UnpackTensorToHfieldDynamic(denorm_f64, H_demagfield, mfi, bx, expected_spatial); + } +#endif + } else { + demag_solver.CalculateH_demag(Mfield, H_demagfield); + } + } if (exchange_coupling == 1) { @@ -599,7 +719,11 @@ void main_main () H_demagfield[idim].setVal(0.); } } else { - demag_solver.CalculateH_demag(ar_state, H_demagfield); + if (ml_enable == 1) { + amrex::Abort("add ML demag to SUNDIALS rhs"); + } else { + demag_solver.CalculateH_demag(ar_state, H_demagfield); + } } } @@ -674,7 +798,11 @@ void main_main () // H_demag if (demag_coupling == 1) { if (fast_demag==1) { - demag_solver.CalculateH_demag(ar_state, H_demagfield); + if (ml_enable == 1) { + amrex::Abort("add ML demag to fast dynamics"); + } else { + demag_solver.CalculateH_demag(ar_state, H_demagfield); + } } else { for (int idim=0; idim