diff --git a/.gitignore b/.gitignore
index 1887d60e..6ed58fc9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -53,3 +53,6 @@ examples/.virtual_documents/*
.virtual_documents/
# Local Netlify folder
.netlify
+
+examples/control_pickles/*
+freegsnke/control_loop/control_test_files/*
diff --git a/README.md b/README.md
index 8ce85d40..0b59e8fb 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ Static Grad-Shafranov problems are solved using **fourth-order accurate finite d
-In the left panel above we show an example of a dynamic equilibrium calculated using FreeGSNKE's forward solver, simulating the flat-phase of a **MAST-U** plasma discharge.On the right is the sequence of equilibrium reconstructions for the actual MAST-U shot. The agreement between the simulation and the real shot is very good in both the plasma shape targets and the currents in the poloidal field coils, illustrating FreeGSNKE's accuracy. The contours represent constant poloidal flux and the different tokamak features are plotted in various colours (refer back to table above).
+In the left panel above we show an example of a dynamic equilibrium calculated using FreeGSNKE's forward solver, simulating the flat-phase of a **MAST-U** plasma discharge. On the right is the sequence of equilibrium reconstructions for the actual MAST-U shot. The agreement between the simulation and the real shot is very good in both the plasma shape targets and the currents in the poloidal field coils, illustrating FreeGSNKE's accuracy. The contours represent constant poloidal flux and the different tokamak features are plotted in various colours (refer back to table above).
## Feature roadmap
FreeGSNKE is constantly evolving and so we hope to provide users with more advanced features over time:
@@ -187,6 +187,7 @@ Here are a list of FreeGSNKE papers that describe or use the code:
- A. Agnello et al, "Emulation techniques for scenario and classical control design of tokamak plasmas", Physics of Plasmas, **31**, 043091 (2024). DOI: [10.1063/5.0187822](https://doi.org/10.1063/5.0187822).
- K. Pentland et al, "Validation of the static forward Grad-Shafranov equilibrium solvers in FreeGSNKE and Fiesta using EFIT++ reconstructions from MAST-U", Physica Scripta, **100**, 025608 (2025). DOI: [10.1088/1402-4896/ada192](https://iopscience.iop.org/article/10.1088/1402-4896/ada192).
- K. Pentland et al, "Multiple solutions to the static forward free-boundary Grad-Shafranov problem on MAST-U", Nuclear Fusion (2025). DOI: [10.1088/1741-4326/adf3cc](https://iopscience.iop.org/article/10.1088/1741-4326/adf3cc).
+- P. Cavestany et al, "Real-Time Applicability of Emulated Virtual Circuits for Tokamak Plasma Shape Control", IEEE Conference on Control Technology and Applications (CCTA), San Diego, CA, USA, (2025), pp. 826-831, DOI: [10.1109/CCTA53793.2025.11151371](https://ieeexplore.ieee.org/document/11151371).
If you would like your FreeGSNKE-related paper to be added, please let us know!
diff --git a/examples/README.md b/examples/README.md
index 670e3004..79a54cd8 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,6 +1,6 @@
# Examples
-These example Jupyter notebooks are intended to be the **first port of call for new users** trying to get up to speed with the basics of simulating Grad-Shafranov equilibria using FreeGSNKE (the numberings represent approximately increasing complexity!). Most examples will use a MAST-U-like tokamak machine, unless otherwise specified.
+These example Jupyter notebooks are intended to be the **first port of call for new users** trying to get up to speed with the basics of simulating Grad-Shafranov (GS) equilibria using FreeGSNKE (the numberings represent approximately increasing complexity!). Most examples will use a MAST-U-like tokamak machine, unless otherwise specified.
| Example notebook | Purpose | Who can use it? |
| ------ | ------ | ------ |
@@ -10,11 +10,14 @@ These example Jupyter notebooks are intended to be the **first port of call for
| Example 02 | Learn how to use the static forward solver. | Anyone |
| Example 03 | Learn how to extract results from a calculated equilibrium. | Anyone |
| Example 04 | Learn how to use the magnetic probes object. | Anyone |
-| Example 05 | Learn how to use the evolutive solver to simulate time-dependent equilibria. | Anyone |
+| Example 05a | Learn how to use the nonlinear and linear (with GS) evolutive solver to simulate time-dependent equilibria. | Anyone |
+| Example 05b | Learn how to use the linear evolutive solver without solving GS at each timestep. | Anyone |
+| Example 05c | Learn how to use the linear evolutive solver (with or without solving GS) with automatic relinearisation enabled. | Anyone |
| Example 06a/b | Simulate (static) MAST-U equilibria over an entire shot using inputs from EFIT++ (requires internal UKAEA MAST-U database). | UKAEA employees + collaborators |
| Example 07 | Static inverse solve in a SPARC-like tokamak. | Anyone |
| Example 08 | Static inverse solve in an ITER-like tokamak. | Anyone |
| Example 09 | Learn how to use and build virtual circuits for plasma shape control. | Anyone |
| Example 10 | Learn how to calculate growth rates associated with vertically unstable modes. | Anyone |
+| Example 11 | Learn how to use the evolutive solver alongside a virtual plasma control system (FreeGSNKE Pulse Design Tool). | Anyone |
If a new example has been created, please add a new line to the table explaining its purpose!
\ No newline at end of file
diff --git a/examples/example03 - extracting_equilibrium_quantites.ipynb b/examples/example03 - extracting_equilibrium_quantites.ipynb
index dfceb3ff..f31069d8 100644
--- a/examples/example03 - extracting_equilibrium_quantites.ipynb
+++ b/examples/example03 - extracting_equilibrium_quantites.ipynb
@@ -486,8 +486,8 @@
" \n",
" if type(eq._pgreen[name]) is dict:\n",
" num_coils = len(eq._pgreen[name])\n",
- " for i, ind in enumerate(eq._pgreen[name]):\n",
- " greens_matrix += eq._pgreen[name][ind]*scaling[i*(len(scaling)//num_coils)]\n",
+ " for j, ind in enumerate(eq._pgreen[name]):\n",
+ " greens_matrix += eq._pgreen[name][ind]*scaling[j*(len(scaling)//num_coils)]\n",
" else:\n",
" num_coils = 1\n",
" greens_matrix = eq._pgreen[name]*scaling[0]\n",
@@ -542,7 +542,7 @@
"name = \"P6\"\n",
"ax3.grid(True, which='both')\n",
"eq.tokamak.plot(axis=ax3,show=False)\n",
- "# ax3.plot(tokamak.limiter.R, tokamak.limiter.Z, color='k', linewidth=1.2, linestyle=\"--\")\n",
+ "# ax3.plot(tokamak.limiter.R, tokamak.limiter.Z, color=' k', linewidth=1.2, linestyle=\"--\")\n",
"ax3.plot(tokamak.wall.R, tokamak.wall.Z, color='k', linewidth=1.2, linestyle=\"-\")\n",
"im3 = ax3.contour(eq.R, eq.Z, psi_coils[name], levels=50) \n",
"ax3.set_aspect('equal')\n",
@@ -1179,7 +1179,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.16"
+ "version": "3.10.13"
}
},
"nbformat": 4,
diff --git a/examples/example05 - evolutive_forward_solve.ipynb b/examples/example05a - nonlinear_and_linear_evolution_with_GS.ipynb
similarity index 99%
rename from examples/example05 - evolutive_forward_solve.ipynb
rename to examples/example05a - nonlinear_and_linear_evolution_with_GS.ipynb
index c0486a7a..0bffdbbc 100644
--- a/examples/example05 - evolutive_forward_solve.ipynb
+++ b/examples/example05a - nonlinear_and_linear_evolution_with_GS.ipynb
@@ -1007,7 +1007,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.16"
+ "version": "3.10.13"
}
},
"nbformat": 4,
diff --git a/examples/example05b - linear_evolution_without_GS.ipynb b/examples/example05b - linear_evolution_without_GS.ipynb
new file mode 100644
index 00000000..e14670c1
--- /dev/null
+++ b/examples/example05b - linear_evolution_without_GS.ipynb
@@ -0,0 +1,394 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Example: linear evolution without solving the Grad-Shafranov equation\n",
+ "\n",
+ "This notebook follows on from the previous one, except that the following (linear) evolutive solves **will not** solve the GS equation at each timestep. \n",
+ "\n",
+ "This mode uses the linearisations from the initial equilibrium configuration at time $t=0$ (i.e. the Jacobians inside the evolutive solver) to evolve plasma shape parameters of interest (e.g. midplane radii, X-point positions) forward in time rapidly.\n",
+ "\n",
+ "This can be useful for doing very fast scans forward in time into how certain shape parameters will evolve (linearly at least). Note, however, the equilibria for these simulations cannot be visualised as GS has not been solved!\n",
+ "\n",
+ "The simulation proceeds in almost exactly the same way as the previous notebook so please do check that out first. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Static forward simulation\n",
+ "\n",
+ "Begin by building the initial plasma configuration as usual. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "# build machine\n",
+ "from freegsnke import build_machine\n",
+ "tokamak = build_machine.tokamak(\n",
+ " active_coils_path=f\"../machine_configs/MAST-U/MAST-U_like_active_coils.pickle\",\n",
+ " passive_coils_path=f\"../machine_configs/MAST-U/MAST-U_like_passive_coils.pickle\",\n",
+ " limiter_path=f\"../machine_configs/MAST-U/MAST-U_like_limiter.pickle\",\n",
+ " wall_path=f\"../machine_configs/MAST-U/MAST-U_like_wall.pickle\",\n",
+ ")\n",
+ "\n",
+ "# initialise equilibrium object\n",
+ "from freegsnke import equilibrium_update\n",
+ "eq = equilibrium_update.Equilibrium(\n",
+ " tokamak=tokamak,\n",
+ " Rmin=0.1, Rmax=2.0, # radial range\n",
+ " Zmin=-2.2, Zmax=2.2, # vertical range\n",
+ " nx=65, # number of grid points in the radial direction (needs to be of the form (2**n + 1) with n being an integer)\n",
+ " ny=129, # number of grid points in the vertical direction (needs to be of the form (2**n + 1) with n being an integer)\n",
+ " # psi=plasma_psi\n",
+ ") \n",
+ "\n",
+ "# initialise profile object\n",
+ "from freegsnke.jtor_update import ConstrainPaxisIp\n",
+ "profiles = ConstrainPaxisIp(\n",
+ " eq=eq,\n",
+ " paxis=8.1e3,\n",
+ " Ip=6.2e5,\n",
+ " fvac=0.5,\n",
+ " alpha_m=1.8,\n",
+ " alpha_n=1.2\n",
+ ")\n",
+ "\n",
+ "# initialise solver\n",
+ "from freegsnke import GSstaticsolver\n",
+ "GSStaticSolver = GSstaticsolver.NKGSsolver(eq) \n",
+ "\n",
+ "# set coil currents\n",
+ "import pickle\n",
+ "with open('data/simple_diverted_currents_PaxisIp.pk', 'rb') as f:\n",
+ " current_values = pickle.load(f)\n",
+ "\n",
+ "for key in current_values.keys():\n",
+ " eq.tokamak[key].current = current_values[key]\n",
+ "eq.tokamak[\"P6\"].current += 100\n",
+ "\n",
+ "# carry out forward solve\n",
+ "GSStaticSolver.solve(eq=eq, \n",
+ " profiles=profiles, \n",
+ " constrain=None, \n",
+ " target_relative_tolerance=1e-9)\n",
+ "\n",
+ "# plot the resulting equilbrium\n",
+ "fig1, ax1 = plt.subplots(1, 1, figsize=(4, 8), dpi=80)\n",
+ "ax1.grid(True, which='both')\n",
+ "eq.plot(axis=ax1, show=False)\n",
+ "eq.tokamak.plot(axis=ax1, show=False)\n",
+ "ax1.set_xlim(0.1, 2.15)\n",
+ "ax1.set_ylim(-2.25, 2.25)\n",
+ "plt.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Define the plasma descriptors\n",
+ "\n",
+ "Here we define the \"plasma descriptors\" (i.e. the shape parameters) that we want to track over time. These descriptors will be evolved according to the following equation:\n",
+ "\n",
+ "$$ \n",
+ " \\vec{s}(t) = \\vec{s}(0) + \\frac{\\partial \\vec{s}}{\\partial \\vec{I}_e} (\\vec{I}_e(t) - \\vec{I}_e(0)) + \\frac{\\partial \\vec{s}}{\\partial \\vec{\\theta}} (\\vec{\\theta}(t) - \\vec{\\theta}(0))\n",
+ "$$\n",
+ "\n",
+ "where\n",
+ " - $\\vec{s}(t)$ are the plasma descriptors (defined in the following cells) approximated via the linearisation, and $\\vec{s}(0)$ are the values defined by the initial equilibrium configuration. \n",
+ " - $\\vec{I}_e(t) = (\\vec{I}_m(t), I_p(t))$ is a vector of the currents in the metals and total plasma current at time $t$. \n",
+ " - $\\vec{\\theta}(t)$ are the (plasma current density) profile parameters at time $t$. \n",
+ " - $\\frac{\\partial \\vec{s}}{\\partial \\vec{I}_e}$ and $\\frac{\\partial \\vec{s}}{\\partial \\vec{\\theta}}$ are the Jacobians calculated wrt the initial equilibrium. \n",
+ "\n",
+ "The currents come from solving the circuit equations the same way as before while the plasma profile parameters are inputs provided by the user. \n",
+ "\n",
+ "This enables a rapid (approximate) evolution of these descriptors over short periods of time - noting that the Jacobians lose accuracy as they diverge from the equilibrium from which they constructed."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The plasma descriptors are defined via a function that must take a FreeGSNKE equilibrium object as it's sole input. \n",
+ "\n",
+ "For example, below we define a function that will return an array containing: the average $Z$ position of the $J_{\\phi}$ map; the radial and vertical position of the lower X-point; and the inboard midplane radius. \n",
+ "\n",
+ "This can be customsied to add any descriptors of interest to you. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the descriptors function (it should return an array of values and take in an eq object)\n",
+ "def plasma_descriptors(eq):\n",
+ "\n",
+ " # find lower X-point\n",
+ " # define a \"box\" in which to search for the lower X-point\n",
+ " XPT_BOX = [[0.33, -0.88], [0.95, -1.38]]\n",
+ "\n",
+ " # mask those points\n",
+ " xpt_mask = (\n",
+ " (eq.xpt[:, 0] >= XPT_BOX[0][0])\n",
+ " & (eq.xpt[:, 0] <= XPT_BOX[1][0])\n",
+ " & (eq.xpt[:, 1] <= XPT_BOX[0][1])\n",
+ " & (eq.xpt[:, 1] >= XPT_BOX[1][1])\n",
+ " )\n",
+ " xpts = eq.xpt[xpt_mask, 0:2].squeeze()\n",
+ " if xpts.ndim > 1 and xpts.shape[0] > 1:\n",
+ " opt = eq.opt[0, 0:2]\n",
+ " dists = np.linalg.norm(xpts - opt, axis=1)\n",
+ " idx = np.argmin(dists) # index of closest point\n",
+ " Rx, Zx = xpts[idx, :]\n",
+ " else:\n",
+ " Rx, Zx = xpts\n",
+ "\n",
+ " # find avg. Z position of jtor\n",
+ " Zcurrent = eq.Zcurrent()\n",
+ "\n",
+ " # inboard midplane radius\n",
+ " Rin = eq.innerOuterSeparatrix()[0]\n",
+ "\n",
+ " return np.array([Zcurrent, Rx, Zx, Rin])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Time evolution\n",
+ "\n",
+ "Having defined the plasma descriptors, we can now instantiate the evolutive solver object. By including the `plasma_descriptors` function as an argument, the relevant Jacobians will be calculated to enable this evolution. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from freegsnke import nonlinear_solve\n",
+ "\n",
+ "stepping = nonlinear_solve.nl_solver(\n",
+ " eq=eq, \n",
+ " profiles=profiles, \n",
+ " GSStaticSolver=GSStaticSolver,\n",
+ " full_timestep=5e-4, \n",
+ " plasma_resistivity=1e-6,\n",
+ " max_mode_frequency=10**2.5,\n",
+ " plasma_descriptor_function=plasma_descriptors\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Carry out time stepping\n",
+ "\n",
+ "As we did in the previous notebook, we define the number of steps we wish to take (based on the timestep chosen above) and create some storage lists. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# number of time steps to simulate\n",
+ "max_count = 50\n",
+ "\n",
+ "# initialising some variables for iteration and logging\n",
+ "counter = 0\n",
+ "t = 0\n",
+ "\n",
+ "# initialise object\n",
+ "stepping.initialize_from_ICs(eq, profiles)\n",
+ "\n",
+ "# storage\n",
+ "history_times = [t]\n",
+ "history_currents = [stepping.currents_vec]\n",
+ "history_plasma_descriptors = [plasma_descriptors(eq)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# constant voltages (i.e. no control or external drive)\n",
+ "voltages = (stepping.vessel_currents_vec*stepping.evol_metal_curr.coil_resist)[:stepping.evol_metal_curr.n_active_coils] \n",
+ "\n",
+ "# # time-dependent plasma current density profile parameters\n",
+ "# alpha_m = np.tile(profiles.alpha_m, max_count+1)\n",
+ "# alpha_m -= (0.1 * np.sin(0.05 * np.pi * np.arange(max_count+1))) # we add some perturbation\n",
+ "\n",
+ "# alpha_n = np.tile(profiles.alpha_n, max_count+1)\n",
+ "# alpha_n += (0.1 * np.sin(0.1 * np.pi * np.arange(max_count+1))) # we add some perturbation\n",
+ "\n",
+ "# paxis = np.tile(profiles.paxis, max_count+1)\n",
+ "# paxis += (0.1 * np.sin(0.01 * np.pi * np.arange(max_count+1))) # we add some perturbation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can now call the timestepper with both the `linear_only=True` and `no_GS=True` options. This ensures that a linear solve is carried out **without** solving GS at each timestep. Note that you will **not** be able to extract all of the usual equilibrium information at each timestep as the internal `eq` objects no longer contain valid GS solutions! Switch back to `no_GS=False` if that's what you need. \n",
+ "\n",
+ "This means that in this mode of evolution, only the the metal currents, the plasma current, and the plasma descriptors are evolved over time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# loop over time steps\n",
+ "while counter < max_count:\n",
+ " print(f'Step: {counter}/{max_count-1}')\n",
+ " print(f'--- t = {t:.2e}')\n",
+ "\n",
+ " # carry out the time step (feed in the voltages and profile parameters)\n",
+ " stepping.nlstepper(\n",
+ " active_voltage_vec=voltages,\n",
+ " linear_only=True,\n",
+ " no_GS=True,\n",
+ " verbose=False,\n",
+ " )\n",
+ "\n",
+ " # store time-advanced currents and plasma descriptors\n",
+ " history_currents.append(stepping.currents_vec)\n",
+ " history_plasma_descriptors.append(stepping.plasma_descriptors_vec)\n",
+ "\n",
+ " t += stepping.dt_step\n",
+ " history_times.append(t)\n",
+ " counter += 1\n",
+ "\n",
+ "# transform lists to arrays\n",
+ "history_currents = np.array(history_currents)\n",
+ "history_times = np.array(history_times)\n",
+ "history_plasma_descriptors = np.array(history_plasma_descriptors)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us carry out the same evolution, this time with the linear solver actually solving GS at each step (to compare the difference). Note that we change the way we extract the plasma descriptors as `no_GS=False`!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# re-initialising some variables for iteration and logging\n",
+ "counter = 0\n",
+ "t = 0\n",
+ "\n",
+ "# re-initialise object\n",
+ "stepping.initialize_from_ICs(eq, profiles)\n",
+ "\n",
+ "# storage\n",
+ "history_times_with_GS = [t]\n",
+ "history_currents_with_GS = [stepping.currents_vec]\n",
+ "history_plasma_descriptors_with_GS = [plasma_descriptors(eq)]\n",
+ "\n",
+ "# loop over the time steps\n",
+ "while counter < max_count:\n",
+ " print(f'Step: {counter}/{max_count-1}')\n",
+ " print(f'--- t = {t:.2e}')\n",
+ " \n",
+ " # carry out the time step (feed in the voltages and profile parameters)\n",
+ " stepping.nlstepper(\n",
+ " active_voltage_vec=voltages,\n",
+ " linear_only=True,\n",
+ " no_GS=False,\n",
+ " verbose=False,\n",
+ " )\n",
+ "\n",
+ " # store time-advanced currents and plasma descriptors\n",
+ " history_currents_with_GS.append(stepping.currents_vec)\n",
+ " history_plasma_descriptors_with_GS.append(plasma_descriptors(stepping.eq1)) # <-- changed!\n",
+ "\n",
+ " t += stepping.dt_step\n",
+ " history_times_with_GS.append(t)\n",
+ " counter += 1\n",
+ "\n",
+ "# transform lists to arrays\n",
+ "history_currents_with_GS = np.array(history_currents_with_GS)\n",
+ "history_times_with_GS = np.array(history_times_with_GS)\n",
+ "history_plasma_descriptors_with_GS = np.array(history_plasma_descriptors_with_GS)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we can compare the difference between the linear evolution with and without GS solves at each step. We see that for some parameters the evolution diverges very quickly with solving GS!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(4,1, figsize=(10, 12), dpi=80)\n",
+ "ax = ax.flatten()\n",
+ "\n",
+ "labels = [\"Zcurrent [m]\", \"Rx [m]\", \"Zx [m]\", \"Rin [m]\"]\n",
+ "\n",
+ "for i in range(4):\n",
+ " ax[i].plot(history_times_with_GS, history_plasma_descriptors_with_GS[:, i], 'k+', label=\"Linear (with GS)\")\n",
+ " ax[i].plot(history_times, history_plasma_descriptors[:, i], 'm3', label=\"Linear (without GS)\")\n",
+ " ax[i].set_xlabel(\"Time (s)\")\n",
+ " ax[i].set_ylabel(labels[i])\n",
+ " ax[i].legend()\n",
+ " ax[i].grid()\n",
+ "\n",
+ "fig.tight_layout()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/example05c - linear_evolution_with_relinearisation.ipynb b/examples/example05c - linear_evolution_with_relinearisation.ipynb
new file mode 100644
index 00000000..2584c00f
--- /dev/null
+++ b/examples/example05c - linear_evolution_with_relinearisation.ipynb
@@ -0,0 +1,940 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Example: linear evolution with re-linearisation during evolutive equilibrium calculations\n",
+ "\n",
+ "This notebook also follows on from the previous two. \n",
+ "\n",
+ "Here, we demonstrate how to re-calculate the linearisation (e.g. Jacobians) of the dynamics (\"relinearise\") when using the linear solver. As we know, the accuracy of the Jacobians (calculated using the initial equilibrium, before evolution) will degrade as the plasma \"moves away\" from the initial equilibrium around which it was calculated. This means that we need to re-linearise every so often in order to maintain accuracy of the plasma evolution over time. \n",
+ "\n",
+ "The more often a relinearisation occurs, the more accurate (in theory) a linear simulation will be (when compared to the full nonlinear simulation). It is not practical or efficient to relinearise each timestep, but doing so once a given criterion is met can significantly improve agreement with the nonlinear solver. To trigger relinearisation, we monitor the relative change in plasma current density $J_{\\phi}$ at the current time $t$ and the initial time $t_0$. If this threshold exceeds tolerance $\\epsilon$, then we re-linearise around the equilibrium at time $t$ (and set $t_0 = t$). \n",
+ "\n",
+ "The criterion checks whether\n",
+ "$$\n",
+ "\\frac{\\left \\lVert J_{\\phi}(t) - J_{\\phi}(t_0) \\right \\rVert}{\\left \\lVert J_{\\phi}(t_0) \\right \\rVert} < \\epsilon,\n",
+ "$$\n",
+ "is met or not at each timestep $t$."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Import packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from IPython.display import display, clear_output\n",
+ "import pickle\n",
+ "import time"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Build the machine"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "# build machine\n",
+ "from freegsnke import build_machine\n",
+ "tokamak = build_machine.tokamak(\n",
+ " active_coils_path=f\"../machine_configs/MAST-U/MAST-U_like_active_coils.pickle\",\n",
+ " passive_coils_path=f\"../machine_configs/MAST-U/MAST-U_like_passive_coils.pickle\",\n",
+ " limiter_path=f\"../machine_configs/MAST-U/MAST-U_like_limiter.pickle\",\n",
+ " wall_path=f\"../machine_configs/MAST-U/MAST-U_like_wall.pickle\",\n",
+ ")\n",
+ "\n",
+ "# initialise equilibrium object\n",
+ "from freegsnke import equilibrium_update\n",
+ "eq = equilibrium_update.Equilibrium(\n",
+ " tokamak=tokamak,\n",
+ " Rmin=0.1, Rmax=2.0, # radial range\n",
+ " Zmin=-2.2, Zmax=2.2, # vertical range\n",
+ " nx=65, # number of grid points in the radial direction (needs to be of the form (2**n + 1) with n being an integer)\n",
+ " ny=129, # number of grid points in the vertical direction (needs to be of the form (2**n + 1) with n being an integer)\n",
+ " # psi=plasma_psi\n",
+ ") \n",
+ "\n",
+ "# initialise profile object\n",
+ "from freegsnke.jtor_update import ConstrainPaxisIp\n",
+ "profiles = ConstrainPaxisIp(\n",
+ " eq=eq,\n",
+ " paxis=8.1e3,\n",
+ " Ip=6.2e5,\n",
+ " fvac=0.5,\n",
+ " alpha_m=1.8,\n",
+ " alpha_n=1.2\n",
+ ")\n",
+ "\n",
+ "# initialise solver\n",
+ "from freegsnke import GSstaticsolver\n",
+ "GSStaticSolver = GSstaticsolver.NKGSsolver(eq) \n",
+ "\n",
+ "# set coil currents\n",
+ "import pickle\n",
+ "with open('data/simple_diverted_currents_PaxisIp.pk', 'rb') as f:\n",
+ " current_values = pickle.load(f)\n",
+ "\n",
+ "for key in current_values.keys():\n",
+ " eq.tokamak[key].current = current_values[key]\n",
+ "\n",
+ "# carry out forward solve\n",
+ "GSStaticSolver.solve(eq=eq, \n",
+ " profiles=profiles, \n",
+ " constrain=None, \n",
+ " target_relative_tolerance=1e-9)\n",
+ "\n",
+ "# plot the resulting equilbrium\n",
+ "fig1, ax1 = plt.subplots(1, 1, figsize=(4, 8), dpi=80)\n",
+ "ax1.grid(True, which='both')\n",
+ "eq.plot(axis=ax1, show=False)\n",
+ "eq.tokamak.plot(axis=ax1, show=False)\n",
+ "ax1.set_xlim(0.1, 2.15)\n",
+ "ax1.set_ylim(-2.25, 2.25)\n",
+ "plt.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the descriptors function (it should return an array of values and take in an eq object)\n",
+ "def plasma_descriptors(eq):\n",
+ "\n",
+ " # find lower X-point\n",
+ " # define a \"box\" in which to search for the lower X-point\n",
+ " XPT_BOX = [[0.33, -0.88], [0.95, -1.38]]\n",
+ "\n",
+ " # mask those points\n",
+ " xpt_mask = (\n",
+ " (eq.xpt[:, 0] >= XPT_BOX[0][0])\n",
+ " & (eq.xpt[:, 0] <= XPT_BOX[1][0])\n",
+ " & (eq.xpt[:, 1] <= XPT_BOX[0][1])\n",
+ " & (eq.xpt[:, 1] >= XPT_BOX[1][1])\n",
+ " )\n",
+ " xpts = eq.xpt[xpt_mask, 0:2].squeeze()\n",
+ " if xpts.ndim > 1 and xpts.shape[0] > 1:\n",
+ " opt = eq.opt[0, 0:2]\n",
+ " dists = np.linalg.norm(xpts - opt, axis=1)\n",
+ " idx = np.argmin(dists) # index of closest point\n",
+ " Rx, Zx = xpts[idx, :]\n",
+ " else:\n",
+ " Rx, Zx = xpts\n",
+ "\n",
+ " # find avg. Z position of jtor\n",
+ " Zcurrent = eq.Zcurrent()\n",
+ "\n",
+ " # inboard midplane radius\n",
+ " Rin = eq.innerOuterSeparatrix()[0]\n",
+ "\n",
+ " return np.array([Zcurrent, Rx, Zx, Rin])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# initialise the nonlinear solver object\n",
+ "from freegsnke import nonlinear_solve\n",
+ "\n",
+ "stepping = nonlinear_solve.nl_solver(\n",
+ " eq=eq, \n",
+ " profiles=profiles, \n",
+ " GSStaticSolver=GSStaticSolver,\n",
+ " full_timestep=5e-4, \n",
+ " plasma_resistivity=1e-6,\n",
+ " max_mode_frequency=10**2.5,\n",
+ " plasma_descriptor_function=plasma_descriptors\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Carry out standard linear solve\n",
+ "\n",
+ "As we have done before, we carry out a standard linear simulation so that we can compare how the relinearisation performs compared to it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# initialise with the initial condition\n",
+ "stepping.initialize_from_ICs(eq, profiles)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# number of time steps to simulate\n",
+ "max_count = 50\n",
+ "\n",
+ "# initialising some variables for iteration and logging\n",
+ "counter = 0\n",
+ "t = 0\n",
+ "\n",
+ "history_times = [t]\n",
+ "history_currents = [stepping.currents_vec]\n",
+ "history_equilibria = [stepping.eq1.create_auxiliary_equilibrium()]\n",
+ "history_o_points = [stepping.eq1.opt[0]]\n",
+ "history_elongation = [stepping.eq1.geometricElongation()]\n",
+ "history_triangularity = [stepping.eq1.triangularity()]\n",
+ "history_squareness = [stepping.eq1.squareness()[1]]\n",
+ "history_area = [stepping.eq1.separatrix_area()]\n",
+ "history_length = [stepping.eq1.separatrix_length()]\n",
+ "history_jtor = [stepping.profiles1.jtor]\n",
+ "history_jtor_norm = []\n",
+ "history_timings = []\n",
+ "history_plasma_descriptors = [plasma_descriptors(eq)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# use constant voltages\n",
+ "voltages = (stepping.vessel_currents_vec*stepping.evol_metal_curr.coil_resist)[:stepping.evol_metal_curr.n_active_coils] "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Call the solver (linear)\n",
+ "We call the linear solver without performing any re-linearisations. This will use the linearisation (calculated when `initialize_from_ICs` is called) to evolve the plasma for all iterations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# initialise the solver with the initial equilibrium/profiles\n",
+ "stepping.initialize_from_ICs(eq, profiles)\n",
+ "\n",
+ "# loop over time steps\n",
+ "while counter threshold]\n",
+ "for i in range(0,9):\n",
+ " axs_flat[i].vlines(relin_times, ymin=[axs_flat[i].get_ylim()[0]]*len(relin_times), ymax=[axs_flat[i].get_ylim()[1]]*len(relin_times), linestyles=\"--\", color=\"k\", linewidths=0.7)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plot the equilibria at the final time step\n",
+ "fig1, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 8), dpi=60)\n",
+ "\n",
+ "ax1.grid(True, which='both')\n",
+ "history_equilibria[-1].plot(axis=ax1, show=False)\n",
+ "eq.tokamak.plot(axis=ax1, show=False)\n",
+ "ax1.set_xlim(0.1, 2.15)\n",
+ "ax1.set_ylim(-2.25, 2.25)\n",
+ "ax1.set_title(\"Linear\")\n",
+ "\n",
+ "ax2.grid(True, which='both')\n",
+ "history_equilibria_rl[-1].plot(axis=ax2, show=False)\n",
+ "eq.tokamak.plot(axis=ax2, show=False)\n",
+ "ax2.set_xlim(0.1, 2.15)\n",
+ "ax2.set_ylim(-2.25, 2.25)\n",
+ "ax2.set_title(\"Relinearised\")\n",
+ "\n",
+ "ax3.grid(True, which='both')\n",
+ "history_equilibria_nl[-1].plot(axis=ax3, show=False)\n",
+ "eq.tokamak.plot(axis=ax3, show=False)\n",
+ "ax3.set_xlim(0.1, 2.15)\n",
+ "ax3.set_ylim(-2.25, 2.25)\n",
+ "ax3.set_title(\"Nonlinear\")\n",
+ "\n",
+ "plt.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plot relinearisation criteria evolution\n",
+ "fig, axs = plt.subplots(1, 1, figsize=(12, 4), dpi=80)\n",
+ "axs.plot(history_times[0:-1], history_jtor_norm, 'k+-', label=\"linear\")\n",
+ "axs.plot(history_times_rl[0:-1], history_jtor_norm_rl, 'b.-', label=\"relinearised\")\n",
+ "axs.hlines(threshold, xmin=history_times[0], xmax=history_times[-1], linestyle='--', color='k', label=\"Threshold\")\n",
+ "axs.set_xlabel('Time')\n",
+ "axs.set_ylabel(r\"Relative $J_{\\phi}$ norm\")\n",
+ "axs.set_yscale('log')\n",
+ "axs.legend()\n",
+ "axs.grid()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plot runtimes\n",
+ "fig, axs = plt.subplots(2, 1, figsize=(10, 6), dpi=80)\n",
+ "axs = axs.flat\n",
+ "\n",
+ "axs[0].plot(history_times[0:-1], history_timings, 'k+-', label=\"linear\")\n",
+ "axs[0].plot(history_times_rl[0:-1], history_timings_rl, 'b.-', label=\"relinearised\")\n",
+ "axs[0].plot(history_times_nl[0:-1], history_timings_nl, 'rx-', label=\"nonlinear\")\n",
+ "axs[0].set_ylabel('Runtime per step [s]')\n",
+ "axs[0].set_yscale('log')\n",
+ "axs[0].grid()\n",
+ "axs[0].set_xticklabels([])\n",
+ "\n",
+ "axs[1].plot(history_times[0:-1], np.cumsum(history_timings)/60, 'k+-', label=\"linear\")\n",
+ "axs[1].plot(history_times_rl[0:-1], np.cumsum(history_timings_rl)/60, 'b.-', label=\"relinearised\")\n",
+ "axs[1].plot(history_times_nl[0:-1], np.cumsum(history_timings_nl)/60, 'rx-', label=\"nonlinear\")\n",
+ "\n",
+ "\n",
+ "# plot relinearisation times\n",
+ "for i in range(0,2):\n",
+ " axs[i].vlines(relin_times, ymin=[axs[i].get_ylim()[0]]*len(relin_times), ymax=[axs[i].get_ylim()[1]]*len(relin_times), linestyles=\"--\", color=\"k\", linewidths=0.9)\n",
+ "\n",
+ "# axs.plot(jtor_norm_nl)\n",
+ "axs[1].set_xlabel('Simulation time [s]')\n",
+ "axs[1].set_ylabel('Cumulative runtime [mins]')\n",
+ "axs[1].legend()\n",
+ "axs[1].grid()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Re-linearising when using the `no_GS` option\n",
+ "\n",
+ "Next, we will do the same simulations (with and without relinearsiation) but this time without solving Grad-Shafranov at each timestep (see prior example notebook). This will enable a very fast simulation but, due to the lack of a control scheme, the accuracy of the simulation will be lost very quickly (less so when relinearisation is switched on). \n",
+ "\n",
+ "We should note that given GS is no longer solved, the criterion for relinearisation is no longer valid (as we will not have access to $J_{\\phi}$ at each timestep). We therefore carry out relinearisation when the absolute change in the `plasma_descriptors` (with respect to the values at the last linearisation) exceeds $\\epsilon$. \n",
+ "\n",
+ "A value for $\\epsilon$ can be set for each descriptor individually (as a list of floats) or for all of them (as a single float)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First let us carry out a standard linear simulation without solving GS at each timestep (no relinearisation is carried out). Here we set the `relinearise_threshold` as a list of `None` (or large) values to indicate that we do not relinearise at all. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# reset the solver object by resetting the initial conditions\n",
+ "stepping.dIydI_ICs = None\n",
+ "stepping.dIydtheta_ICs = None\n",
+ "stepping.initialize_from_ICs(eq, profiles)\n",
+ "\n",
+ "counter = 0\n",
+ "t = 0\n",
+ "\n",
+ "# storage\n",
+ "history_times_nogs = [t]\n",
+ "history_currents_nogs = [stepping.currents_vec]\n",
+ "history_plasma_descriptors_nogs = [plasma_descriptors(eq)]\n",
+ "history_criteria_nogs = []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# simulate\n",
+ "while counter < max_count:\n",
+ " print(f'Step: {counter}/{max_count-1}')\n",
+ " print(f'--- t = {t:.2e}')\n",
+ "\n",
+ " # carry out the time step (feed in the voltages and profile parameters)\n",
+ " stepping.nlstepper(\n",
+ " active_voltage_vec=voltages,\n",
+ " linear_only=True,\n",
+ " no_GS=True,\n",
+ " verbose=False,\n",
+ " relinearise_threshold=[None, 1.0, 1.0, 1.0],\n",
+ " )\n",
+ "\n",
+ " # store time-advanced currents and plasma descriptors\n",
+ " history_currents_nogs.append(stepping.currents_vec)\n",
+ " history_plasma_descriptors_nogs.append(stepping.plasma_descriptors_vec)\n",
+ " history_criteria_nogs.append(stepping.relinearise_criteria)\n",
+ "\n",
+ " t += stepping.dt_step\n",
+ " history_times_nogs.append(t)\n",
+ " counter += 1\n",
+ "\n",
+ "# transform lists to arrays\n",
+ "history_currents_nogs = np.array(history_currents_nogs)\n",
+ "history_times_nogs = np.array(history_times_nogs)\n",
+ "history_plasma_descriptors_nogs = np.array(history_plasma_descriptors_nogs)\n",
+ "history_criteria_nogs = np.array(history_criteria_nogs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we'll do the same simulation but this time we will enable re-linearisation when one of the shape parameter changes exceeds a few centimetres."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# reset the solver object by resetting the initial conditions\n",
+ "stepping.dIydI_ICs = None\n",
+ "stepping.dIydtheta_ICs = None\n",
+ "stepping.initialize_from_ICs(eq, profiles)\n",
+ "\n",
+ "counter = 0\n",
+ "t = 0\n",
+ "\n",
+ "# storage\n",
+ "history_times_nogs_rl = [t]\n",
+ "history_currents_nogs_rl = [stepping.currents_vec]\n",
+ "history_plasma_descriptors_nogs_rl = [plasma_descriptors(eq)]\n",
+ "history_criteria_nogs_rl = []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# simulate\n",
+ "thresholds = [0.01, 0.02, 0.02, None]\n",
+ "\n",
+ "currents = []\n",
+ "profs = []\n",
+ "params = []\n",
+ "while counter < max_count:\n",
+ " print(f'Step: {counter}/{max_count-1}')\n",
+ " print(f'--- t = {t:.2e}')\n",
+ "\n",
+ " # carry out the time step (feed in the voltages and profile parameters)\n",
+ " stepping.nlstepper(\n",
+ " active_voltage_vec=voltages,\n",
+ " linear_only=True,\n",
+ " no_GS=True,\n",
+ " verbose=False,\n",
+ " relinearise_threshold=thresholds,\n",
+ " )\n",
+ "\n",
+ " # store time-advanced currents and plasma descriptors\n",
+ " history_currents_nogs_rl.append(stepping.currents_vec)\n",
+ " history_plasma_descriptors_nogs_rl.append(stepping.plasma_descriptors_vec)\n",
+ " history_criteria_nogs_rl.append(stepping.relinearise_criteria)\n",
+ "\n",
+ " t += stepping.dt_step\n",
+ " history_times_nogs_rl.append(t)\n",
+ " counter += 1\n",
+ "\n",
+ " currents.append(stepping.initial_currents_plasma_descriptor)\n",
+ " profs.append(stepping.initial_profiles_plasma_descriptor) \n",
+ " params.append(stepping.initial_plasma_descriptors)\n",
+ "\n",
+ "# transform lists to arrays\n",
+ "history_currents_nogs_rl = np.array(history_currents_nogs_rl)\n",
+ "history_times_nogs_rl = np.array(history_times_nogs_rl)\n",
+ "history_plasma_descriptors_nogs_rl = np.array(history_plasma_descriptors_nogs_rl)\n",
+ "history_criteria_nogs_rl = np.array(history_criteria_nogs_rl)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(4,1, figsize=(10, 12), dpi=80)\n",
+ "ax = ax.flatten()\n",
+ "\n",
+ "labels = [\"Zcurrent [m]\", \"Rx [m]\", \"Zx [m]\", \"Rin [m]\"]\n",
+ "\n",
+ "for i in range(4):\n",
+ " ax[i].plot(history_times, history_plasma_descriptors_nl[:, i], 'rx-', label=\"Nonlinear\")\n",
+ " # ax[i].plot(history_times, history_plasma_descriptors[:, i], 'k+-', label=\"Linear (with GS)\")\n",
+ " ax[i].plot(history_times_nogs, history_plasma_descriptors_nogs[:, i], 'm3-', label=\"Linear (without GS)\")\n",
+ " ax[i].plot(history_times_nogs_rl, history_plasma_descriptors_nogs_rl[:, i], 'g.-', label=\"Linear (without GS, with relinearisation)\")\n",
+ " ax[i].set_xlabel(\"Time (s)\")\n",
+ " ax[i].set_ylabel(labels[i])\n",
+ " ax[i].legend()\n",
+ " ax[i].grid()\n",
+ "\n",
+ "fig.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plot relinearisation criteria evolution\n",
+ "fig, ax = plt.subplots(4,1, figsize=(10, 12), dpi=80)\n",
+ "ax = ax.flatten()\n",
+ "\n",
+ "for i in range(4):\n",
+ " ax[i].plot(history_times_nogs_rl[0:-1], history_criteria_nogs_rl[:, i], 'b.-', label=\"Linear (without GS, with relinearisation)\")\n",
+ " ax[i].plot(history_times_nogs[0:-1], history_criteria_nogs[:, i], 'k+-', label=\"Linear (without GS)\")\n",
+ " ax[i].hlines(thresholds[i], xmin=history_times_nogs_rl[0], xmax=history_times_nogs_rl[-1], linestyle='--', color='k', label=\"Threshold\")\n",
+ " ax[i].set_xlabel(\"Time (s)\")\n",
+ " ax[i].set_ylabel(f\"Relative change in {labels[i]}\")\n",
+ " ax[i].set_yscale('log')\n",
+ " ax[i].legend()\n",
+ " ax[i].grid()\n",
+ "\n",
+ "fig.tight_layout()\n",
+ "\n",
+ "\n",
+ "# fig, axs = plt.subplots(1, 1, figsize=(12, 4), dpi=80)\n",
+ "# # axs.plot(history_times_nogs[0:-1], history_jtor_norm, 'k+-', label=\"linear\")\n",
+ "# axs.plot(history_times_nogs[0:-1], history_criteria_nogs, 'b.-', label=\"relinearised\")\n",
+ "# axs.hlines(threshold, xmin=history_times[0], xmax=history_times[-1], linestyle='--', color='k', label=\"Threshold\")\n",
+ "# axs.set_xlabel('Time')\n",
+ "# axs.set_ylabel(r\"Relative $J_{\\phi}$ norm\")\n",
+ "# axs.set_yscale('log')\n",
+ "# axs.legend()\n",
+ "# axs.grid()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(12,1, figsize=(10, 24), dpi=80)\n",
+ "ax = ax.flatten()\n",
+ "\n",
+ "labels = eq.tokamak.coil_names[0:12]\n",
+ "\n",
+ "for i in range(len(labels)):\n",
+ " ax[i].plot(history_times, history_currents_nl[:, i], 'rx-', label=\"Nonlinear\")\n",
+ " # ax[i].plot(history_times, history_plasma_descriptors[:, i], 'k+-', label=\"Linear (with GS)\")\n",
+ " ax[i].plot(history_times_nogs, history_currents_nogs[:, i], 'm3-', label=\"Linear (without GS)\")\n",
+ " ax[i].plot(history_times_nogs_rl, history_currents_nogs_rl[:, i], 'g.-', label=\"Linear (without GS, with relinearisation)\")\n",
+ " ax[i].set_xlabel(\"Time (s)\")\n",
+ " ax[i].set_ylabel(labels[i])\n",
+ " ax[i].legend()\n",
+ " ax[i].grid()\n",
+ "\n",
+ "fig.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/example11 - pulse_design_tool.ipynb b/examples/example11 - pulse_design_tool.ipynb
new file mode 100644
index 00000000..b719379e
--- /dev/null
+++ b/examples/example11 - pulse_design_tool.ipynb
@@ -0,0 +1,1547 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Example: Pulse Design Tool\n",
+ "\n",
+ "This notebook demonstrates how to use the **FreeGSNKE Pulse Design Tool** (FPDT). \n",
+ "\n",
+ "The FPDT combines the **evolutive equilibrium solver** introduced in previous notebooks with a **virtual plasma control system (PCS)** that governs the time evolution of the plasma. It enables **closed-loop** plasma scenario and control design studies.\n",
+ "\n",
+ "The virtual PCS generates voltage requests for the active poloidal field coils, which are then passed to the evolutive solver. At present, FPDT is primarily intended for use during the **flat-top phase** of a discharge. While it may function during ramp-up or ramp-down phases, its performance in those regimes is not guaranteed.\n",
+ "\n",
+ "The **controllers** implemented in the virtual PCS are inspired by the MAST-U PCS and allow control of the following quantities:\n",
+ "- Plasma current.\n",
+ "- Geometric shape parameters (derived from the equilibrium).\n",
+ "- Vertical plasma position.\n",
+ "- Coil activation times.\n",
+ "- Coil current and voltage limits.\n",
+ "\n",
+ "For **more detailed information on these controllers and features**, please refer to the **FPDT paper on arXiv**.\n",
+ "\n",
+ "The FPDT can be used to:\n",
+ "- Plan challenging new plasma scenarios.\n",
+ "- Test new control schemes and parameter settings.\n",
+ "- Serve as a training tool for plasma experimentalists.\n",
+ "\n",
+ "In the following sections, we demonstrate how to set up the FPDT for the **controlled evolution of a MAST-U–like plasma**.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Generate a starting equilibrum\n",
+ "\n",
+ "As in prior notebooks, we need a starting equilibrium. \n",
+ "\n",
+ "This equilibrium will be the initial condition from which the FPDT will evolve the plasma."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# package\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import time"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# build machine\n",
+ "from freegsnke import build_machine\n",
+ "tokamak = build_machine.tokamak(\n",
+ " active_coils_path=f\"../machine_configs/MAST-U/MAST-U_like_active_coils.pickle\",\n",
+ " passive_coils_path=f\"../machine_configs/MAST-U/MAST-U_like_passive_coils.pickle\",\n",
+ " limiter_path=f\"../machine_configs/MAST-U/MAST-U_like_limiter.pickle\",\n",
+ " wall_path=f\"../machine_configs/MAST-U/MAST-U_like_wall.pickle\",\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# initialise equilibrium\n",
+ "from freegsnke import equilibrium_update\n",
+ "eq = equilibrium_update.Equilibrium(\n",
+ " tokamak=tokamak,\n",
+ " Rmin=0.1, Rmax=2.0, # Radial range\n",
+ " Zmin=-2.2, Zmax=2.2, # Vertical range\n",
+ " nx=65, # Number of grid points in the radial direction (needs to be of the form (2**n + 1) with n being an integer)\n",
+ " ny=65, # Number of grid points in the vertical direction (needs to be of the form (2**n + 1) with n being an integer)\n",
+ ") \n",
+ "\n",
+ "# initialise profiles\n",
+ "from freegsnke.jtor_update import Lao85\n",
+ "profiles = Lao85(\n",
+ " eq=eq, # equilibrium object\n",
+ " Ip=7.59e5, # plasma current\n",
+ " fvac=-0.52, # fvac = rB_{tor}\n",
+ " alpha=[362685, 17696], # Lao profiles parameters\n",
+ " beta=[0.103, 0.753], # Lao profiles parameters\n",
+ " alpha_logic=True,\n",
+ " beta_logic=True,\n",
+ ")\n",
+ "\n",
+ "# set initial coil currents (passive structure currents start at zero here)\n",
+ "coil_currents = np.array([2766.10782056, 193.12768191, 4019.86873871, 4857.42357466,\n",
+ " -726.99282376, -1696.92521166, -77.7252396 , 266.38207493,\n",
+ " -73.01483716, -4169.86874271, -4518.28147417, 5.2 ])\n",
+ "for i, key in enumerate(tokamak.coil_names[0:12]):\n",
+ " eq.tokamak.set_coil_current(coil_label=key, current_value=coil_currents[i])\n",
+ "\n",
+ "# initialise the static solver\n",
+ "from freegsnke import GSstaticsolver\n",
+ "GSStaticSolver = GSstaticsolver.NKGSsolver(eq) \n",
+ "\n",
+ "# solve for the equilbirium\n",
+ "GSStaticSolver.solve(\n",
+ " eq=eq, \n",
+ " profiles=profiles, \n",
+ " constrain=None, \n",
+ " target_relative_tolerance=1e-8,\n",
+ " )\n",
+ "\n",
+ "# plot the equilbirium\n",
+ "fig1, ax1 = plt.subplots(1, 1, figsize=(4, 8), dpi=70)\n",
+ "ax1.grid(True, which='both')\n",
+ "eq.plot(axis=ax1, show=False)\n",
+ "eq.tokamak.plot(axis=ax1, show=False)\n",
+ "ax1.set_xlim(0.1, 2.15)\n",
+ "ax1.set_ylim(-2.25, 2.25)\n",
+ "plt.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Prepare inputs for the PCS class\n",
+ "\n",
+ "Here, we will set up the inputs (lists and dictionaries) for each of the (internal) controller classes within the virtual PCS (e.g. plasma, shape, virtual circuits). \n",
+ "\n",
+ "We start by defining which coils have which purpose."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# COIL NAMES AND SHAPE TARGET NAMES\n",
+ "\n",
+ "# active coils: must match the order they appear in FreeGSNKE machine description\n",
+ "active_coils = ['Solenoid', 'PX', 'D1', 'D2', 'D3', 'Dp', 'D5', 'D6', 'D7', 'P4', 'P5', 'P6']\n",
+ "\n",
+ "# control coils: a subset of the active coils to be used for plasma and shape parameter control (i.e. we exclude vertical control coil P6 here)\n",
+ "ctrl_coils = ['Solenoid', 'PX', 'D1', 'D2', 'D3', 'Dp', 'D5', 'D6', 'D7', 'P4', 'P5']\n",
+ "\n",
+ "# Ohmic coil (for plasma current control) \n",
+ "solenoid_coils=[\"Solenoid\"]\n",
+ "\n",
+ "# vertical control coil\n",
+ "vertical_coils=[\"P6\"]\n",
+ "\n",
+ "# names of the plasma targets (this will be clear later, equal to the number of solenoid coils)\n",
+ "plasma_targets = ['plasma']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "All of the following inputs are going to be specified as **waveforms**, i.e. they will be a dicitonary of **times** and **vals**. \n",
+ "\n",
+ "For each array/list of **times** (in seconds), we must specify a corresponding list of list/arrays in **vals**. \n",
+ "\n",
+ "Later on, the PCS class will take these point values and interpolate them (linearly) internally, for querying at a given time $t$ within the simulation much later on. \n",
+ "\n",
+ "The following cell provides an example of how to set a generic waveform for a scalar quantity vs. time (for example, this could be the desired radial position of the X-point over time). "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# waveform dictionary structure\n",
+ "waveform = {\n",
+ " \"times\": np.array([0.0, 0.05, 0.15, 0.35, 0.45, 0.5]),\n",
+ " \"vals\": np.array([0.57, 0.57, 0.5, 0.5, 0.57, 0.57]),\n",
+ "}\n",
+ "\n",
+ "# plot what it looks like \n",
+ "fig, ax = plt.subplots(\n",
+ " nrows=1,\n",
+ " ncols=1,\n",
+ " figsize=(8, 4),\n",
+ " dpi=80\n",
+ ")\n",
+ "\n",
+ "ax.plot(waveform['times'], waveform['vals'], color='k', linewidth=1, linestyle=\"--\", marker=\"x\", markersize=7, label=\"waveform\")\n",
+ "\n",
+ "ax.set_xlabel(r\"Shot time [$s$]\")\n",
+ "ax.set_ylabel(\"Scalar quantity [units]\")\n",
+ "ax.grid()\n",
+ "ax.legend()\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before we define each of the controller settings, let us define some global times, i.e. the start and end of the simulation say. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# choose some min amd max simulation times\n",
+ "tmin = 0.0\n",
+ "tmax = 0.5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First, we define the desired control settings for the **plasma category**.\n",
+ "\n",
+ "These waveforms govern how the plasma current, $I_p$, is controlled—either via feedback (FB), feedforward (FF), or a combination of both.\n",
+ "\n",
+ "The following quantities must be specified:\n",
+ "\n",
+ "- **`ip_ref`**: Plasma current feedback reference waveform \\[A\\].\n",
+ "- **`vloop_ff`**: Loop voltage feedforward reference waveform \\[V·s⁻¹\\].\n",
+ "- **`blend`**: Control blending parameter [dimensionless]:\n",
+ " - `1` → purely FB control \n",
+ " - `0` → purely FF control \n",
+ " - values in `(0, 1)` → mixed FB/FF control.\n",
+ "- **`k_prop`**: Proportional gain waveform for the FB PID controller \\[s⁻¹\\].\n",
+ "- **`k_int`**: Integral gain waveform for the FB PID controller \\[s⁻²\\].\n",
+ "- **`M_solenoid`**: Mutual inductance between the plasma and solenoid (required for FF control) \\[V·s·A⁻¹\\].\n",
+ "\n",
+ "\n",
+ "In this particular example, we will tell the PCS to hold the plasma current constant at it's starting value. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# PLASMA CATEGORY DATA\n",
+ "\n",
+ "# build the dictionary which will be passed to the virtual PCS (all keys below are required)\n",
+ "plasma_data = {}\n",
+ "\n",
+ "# plasma current feedback (FB) reference waveform \n",
+ "# here we hold the value for the initial equilibrium constant\n",
+ "plasma_data[\"ip_ref\"] = {\n",
+ " 'times': np.array([tmin, tmax]),\n",
+ " 'vals': np.array([eq._profiles.Ip, eq._profiles.Ip])\n",
+ " }\n",
+ "\n",
+ "# loop voltage feedforward (FF) reference waveform (can specify the desired loop voltage on the plasma)\n",
+ "# (here we don't use a loop voltage as we will use the FB reference waveform above)\n",
+ "plasma_data[\"vloop_ff\"] = {\n",
+ " 'times': np.array([tmin, tmax]),\n",
+ " 'vals': np.array([0, 0])\n",
+ " }\n",
+ "\n",
+ "# blending waveform: tells controller to use purely FB control (1), purely FF control (0), or a mixture (between 0 and 1)\n",
+ "# we are doing pure FB control here\n",
+ "plasma_data[\"ip_blend\"] = {\n",
+ " 'times': np.array([tmin, tmax]),\n",
+ " 'vals': np.array([1, 1])\n",
+ " }\n",
+ "\n",
+ "# proportional gain for the FB PID controller\n",
+ "plasma_data[\"k_prop\"] = {\n",
+ " 'times': np.array([tmin, tmax]),\n",
+ " 'vals': np.array([-5*4, -5*4])\n",
+ " }\n",
+ "\n",
+ "# integral gain for the FB PID controller\n",
+ "plasma_data[\"k_int\"] = {\n",
+ " 'times': np.array([tmin, tmax]),\n",
+ " 'vals': np.array([-50*2, -50*2])\n",
+ " }\n",
+ "\n",
+ "# mutual inductance between plasma and solenoid: required for FF control\n",
+ "# not used here as we do not set a loop voltage\n",
+ "plasma_data[\"M_solenoid\"] = {\"times\": np.array([0]), \"vals\": np.array([1.0])}\n",
+ "\n",
+ "plasma_data.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we define the desired control settings for the **shape category**.\n",
+ "\n",
+ "First, we define a function that extracts the relevant shape parameters from the equilibrium object (here referred to as `plasma_descriptors`). In this example, we extract the following quantities:\n",
+ "\n",
+ "- **`Rin`**: Inboard midplane radius.\n",
+ "- **`Rout`**: Outboard midplane radius.\n",
+ "- **`Rx`**: Radial position of the lower X-point.\n",
+ "- **`Zx`**: Vertical position of the lower X-point.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# define the descriptors function (it should return an array of values and take in an eq object)\n",
+ "# outputs (\"measurements\") from this function will be passed to the PCS later on during simulation\n",
+ "def plasma_descriptors(eq):\n",
+ "\n",
+ " # inboard/outboard midplane radii\n",
+ " RinRout = eq.innerOuterSeparatrix()\n",
+ "\n",
+ " # find lower X-point\n",
+ " # define a \"box\" in which to search for the lower X-point\n",
+ " XPT_BOX = [[0.33, -0.88], [0.95, -1.38]]\n",
+ "\n",
+ " # mask those points\n",
+ " xpt_mask = (\n",
+ " (eq.xpt[:, 0] >= XPT_BOX[0][0])\n",
+ " & (eq.xpt[:, 0] <= XPT_BOX[1][0])\n",
+ " & (eq.xpt[:, 1] <= XPT_BOX[0][1])\n",
+ " & (eq.xpt[:, 1] >= XPT_BOX[1][1])\n",
+ " )\n",
+ " xpts = eq.xpt[xpt_mask, 0:2].squeeze()\n",
+ " if xpts.ndim > 1 and xpts.shape[0] > 1:\n",
+ " opt = eq.opt[0, 0:2]\n",
+ " dists = np.linalg.norm(xpts - opt, axis=1)\n",
+ " idx = np.argmin(dists) # index of closest point\n",
+ " Rx, Zx = xpts[idx, :]\n",
+ " else:\n",
+ " Rx, Zx = xpts\n",
+ "\n",
+ " return np.array([RinRout[0], RinRout[1], Rx, Zx])\n",
+ "\n",
+ "# give these descriptors some names for use in the PCS\n",
+ "ctrl_targets = ['Rin','Rout', 'Rx', 'Zx']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we specify the waveforms used to control these shape parameters throughout the simulation. \n",
+ "\n",
+ "Not all shape parameters need to be defined in units of metres—these are simply illustrative examples.\n",
+ "\n",
+ "For each shape parameter, the following waveforms must be specified:\n",
+ "\n",
+ "- **`ff`**: Feedforward (FF) waveform. \n",
+ " Not used in this example, but can be employed to impose a fixed offset on the parameter \\[m\\].\n",
+ "\n",
+ "- **`ref`**: Feedback (FB) reference waveform. \n",
+ " These are the desired target values of the parameter over the course of the simulation \\[m\\].\n",
+ "\n",
+ "- **`blend`**: Control blending waveform specifying the relative contribution of FF and FB control [dimensionless]:\n",
+ " - `0` → purely FF control \n",
+ " - `1` → purely FB control \n",
+ " - values in `(0, 1)` → mixed FF/FB control. \n",
+ "\n",
+ "- **`k_prop`**: Proportional gain waveform for the FB PID controller \\[s⁻¹\\].\n",
+ "\n",
+ "- **`k_int`**: Integral gain waveform for the FB PID controller (not used here) \\[s⁻²\\].\n",
+ "\n",
+ "- **`damping`**: Damping waveform (not used here) [dimensionless]. \n",
+ "\n",
+ "In the following, we tell will tell the controllers to:\n",
+ "\n",
+ "- hold both **`Rin`** and **`Rout`** constant. \n",
+ "- hold **`Rx`** constant for a short period, before ramping down (to increase plasma triangularity), then ramp back up to the original value.\n",
+ "- allow **`Zx`** to be uncontrolled. \n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# SHAPE/DIVERTOR CATEGORY\n",
+ "\n",
+ "# build the dictionary which will be passed to the virtual PCS (all keys below are required)\n",
+ "shape_data = {}\n",
+ "\n",
+ "# waveforms for Rin\n",
+ "shape_data[\"Rin\"] = {}\n",
+ "shape_data[\"Rin\"][\"ff\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])} \n",
+ "shape_data[\"Rin\"][\"ref\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.28, 0.28])} \n",
+ "shape_data[\"Rin\"][\"blend\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "shape_data[\"Rin\"][\"k_prop\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([120, 120])}\n",
+ "shape_data[\"Rin\"][\"k_int\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])}\n",
+ "shape_data[\"Rin\"][\"damping\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "\n",
+ "# waveforms for Rout\n",
+ "shape_data[\"Rout\"] = {}\n",
+ "shape_data[\"Rout\"][\"ff\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])} \n",
+ "shape_data[\"Rout\"][\"ref\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.35, 1.35])} \n",
+ "shape_data[\"Rout\"][\"blend\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "shape_data[\"Rout\"][\"k_prop\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([170, 170])}\n",
+ "shape_data[\"Rout\"][\"k_int\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])}\n",
+ "shape_data[\"Rout\"][\"damping\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "\n",
+ "# waveforms for Rx\n",
+ "shape_data[\"Rx\"] = {}\n",
+ "shape_data[\"Rx\"][\"ff\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])} \n",
+ "shape_data[\"Rx\"][\"ref\"] = {\"times\": np.array([tmin, 0.05, 0.15, 0.35, 0.45, tmax]), \"vals\": np.array([0.57, 0.57, 0.5, 0.5, 0.57, 0.57])} # we vary Rx\n",
+ "shape_data[\"Rx\"][\"blend\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "shape_data[\"Rx\"][\"k_prop\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([120, 120])}\n",
+ "shape_data[\"Rx\"][\"k_int\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])}\n",
+ "shape_data[\"Rx\"][\"damping\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "\n",
+ "# waveforms for Zx\n",
+ "shape_data[\"Zx\"] = {}\n",
+ "shape_data[\"Zx\"][\"ff\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])} \n",
+ "shape_data[\"Zx\"][\"ref\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([-1.24, -1.24])} \n",
+ "shape_data[\"Zx\"][\"blend\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])}\n",
+ "shape_data[\"Zx\"][\"k_prop\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([120, 120])}\n",
+ "shape_data[\"Zx\"][\"k_int\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])}\n",
+ "shape_data[\"Zx\"][\"damping\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "\n",
+ "shape_data.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we define the schedule of virtual circuits (VCs) in the **virtual circuits category**.\n",
+ "\n",
+ "For each shape parameter in the **shape category**, a VC is required to map the discrepancy between the target value (from the shape controller) and the measured value (from the equilibrium) to requests for the poloidal field (PF) coil currents. The VCs used here were obtained using the procedure described in the example notebook on VC construction. Recall that VCs are defined with units \\[m/A\\].\n",
+ "\n",
+ "Each VC is an array with length equal to the number of `ctrl_coils`, with entries ordered consistently with that list. When specified in the waveform dictionary, VCs are **previous-value interpolated** by the PCS class. For example, for the `Rin` controller, the VC defined at time `tmin` is used over the interval $ [t_\\mathrm{min},\\, t_\\mathrm{max})$, after which the VC defined at `tmax` is applied.\n",
+ "\n",
+ "Note that the first entry of each shape-parameter VC is zero. This ensures that the solenoid does not contribute to shape control and is reserved exclusively for plasma current, $I_p$, control.\n",
+ "\n",
+ "In addition to the shape VCs, a final VC is defined for the **plasma category**. This VC maps the requested change in $I_p$ from the plasma controller to PF coil current requests (well in this case, only a request to the solenoid).\n",
+ "\n",
+ "Finally, optional pre-programmed feedforward (FF) requests for the PF coil currents may be specified. These allow the coil currents to be set directly if desired. In this example, no such FF requests are used, and the coil currents are driven purely via feedback control.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# VIRTUAL CIRCUITS CATEGORY\n",
+ "\n",
+ "# build the dictionary which will be passed to the virtual PCS (all keys below are required)\n",
+ "circuits_data = {}\n",
+ "\n",
+ "circuits_data[\"Rin\"] = {\n",
+ " \"times\": np.array([tmin, tmax]), \n",
+ " \"vals\": [\n",
+ " np.array([0, 3.1394e+04, 8.5990e+03, -8.8300e+02, -1.2460e+03, -6.6690e+03, 1.2660e+03, 2.9770e+03, 3.5350e+03, 8.3110e+03, -7.8520e+03]),\n",
+ " np.array([0, 3.1394e+04, 8.5990e+03, -8.8300e+02, -1.2460e+03, -6.6690e+03, 1.2660e+03, 2.9770e+03, 3.5350e+03, 8.3110e+03, -7.8520e+03]),\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "circuits_data[\"Rout\"] = {\n",
+ " \"times\": np.array([tmin, tmax]), \n",
+ " \"vals\": [\n",
+ " np.array([0, -9.9700e+02, 1.1790e+03, 3.9700e+02, -2.0000e+02, -1.8970e+03, -5.5500e+02, -2.2640e+03, -1.5860e+03, -1.9750e+03, 5.5800e+03]),\n",
+ " np.array([0, -9.9700e+02, 1.1790e+03, 3.9700e+02, -2.0000e+02, -1.8970e+03, -5.5500e+02, -2.2640e+03, -1.5860e+03, -1.9750e+03, 5.5800e+03]),\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "circuits_data[\"Rx\"] = {\n",
+ " \"times\": np.array([tmin, tmax]), \n",
+ " \"vals\": [\n",
+ " np.array([0, -3.0021e+04, 3.3980e+03, 8.6900e+03, 5.8310e+03, 1.5858e+04, 2.4000e+01, -1.2100e+02, -2.2160e+03, -9.5290e+03, 1.3570e+03]),\n",
+ " np.array([0, -3.0021e+04, 3.3980e+03, 8.6900e+03, 5.8310e+03, 1.5858e+04, 2.4000e+01, -1.2100e+02, -2.2160e+03, -9.5290e+03, 1.3570e+03]),\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "circuits_data[\"Zx\"] = {\n",
+ " \"times\": np.array([tmin, tmax]), \n",
+ " \"vals\": [\n",
+ " np.array([0, 3.7670e+03, 2.0328e+04, 1.0562e+04, 4.2560e+03, 2.2600e+03, -2.0450e+03, -5.5680e+03, -5.3160e+03, -9.8440e+03, 3.3750e+03]),\n",
+ " np.array([0, 3.7670e+03, 2.0328e+04, 1.0562e+04, 4.2560e+03, 2.2600e+03, -2.0450e+03, -5.5680e+03, -5.3160e+03, -9.8440e+03, 3.3750e+03]),\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "circuits_data[\"plasma\"] = {\n",
+ " \"times\": np.array([tmin, tmax]), \n",
+ " \"vals\": [\n",
+ " np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
+ " np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "# store the coil order for future reference\n",
+ "circuits_data[\"coil_order\"] = ctrl_coils\n",
+ "\n",
+ "# define any feedforward coil current drives on each control coil (no drives used here)\n",
+ "zeros_dict = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.0, 0.0])}\n",
+ "for coil in ctrl_coils: # linearly interpolated\n",
+ " circuits_data[coil+\"_ref\"] = zeros_dict\n",
+ "\n",
+ "circuits_data.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we define the inputs for the **systems category**.\n",
+ "\n",
+ "This controller specifies operational limits for the coils listed in `ctrl_coils`, including:\n",
+ "\n",
+ "- **`min_coil_curr_lims`**: Minimum allowable coil currents \\[A\\].\n",
+ "- **`max_coil_curr_lims`**: Maximum allowable coil currents \\[A\\].\n",
+ "- **`max_coil_curr_ramp_lims`**: Maximum allowable coil current ramp rates \\[A·s⁻¹\\].\n",
+ "\n",
+ "There is also the option to define feedforward (FF) perturbations to the control coil currents—similar to the FF drives available in the **virtual circuits category**. These are not used here.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# SYSTEMS CATEGORY\n",
+ "\n",
+ "# build the dictionary which will be passed to the virtual PCS (all keys below are required)\n",
+ "systems_data = {}\n",
+ "\n",
+ "# define the min coil current limits and ramp rate limits\n",
+ "systems_data[\"min_coil_curr_lims\"] = {\n",
+ " 'times': [-2.6],\n",
+ " 'vals': [np.array([-10000, -7000, -9000, -9000, -9000, -9000, -9000, -9000, -9000, -10000, -10000])],\n",
+ " }\n",
+ "\n",
+ "# define the max coil current limits\n",
+ "systems_data[\"max_coil_curr_lims\"] = {\n",
+ " 'times': [-2.6],\n",
+ " 'vals': [np.array([10000, 7000, 9000, 9000, 9000, 9000, 9000, 9000, 9000, 0, 0])],\n",
+ " }\n",
+ "\n",
+ "# define the max ramp rate limits for the coils\n",
+ "systems_data[\"max_coil_curr_ramp_lims\"] = {\n",
+ " 'times': [-2.6],\n",
+ " 'vals': [1e12*np.ones(len(ctrl_coils))],\n",
+ " }\n",
+ "\n",
+ "# define the control coil current perturbations\n",
+ "for name in ctrl_coils: # linearly interpolated\n",
+ " systems_data[name+\"_pert\"] = zeros_dict\n",
+ "\n",
+ "systems_data.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we define the inputs for the **PF category**.\n",
+ "\n",
+ "This category specifies the electrical properties, gains, and operational limits of the poloidal field (PF) coils. These inputs are used by the virtual PCS to convert coil current requests into physically consistent voltage commands.\n",
+ "\n",
+ "The following quantities must be provided:\n",
+ "\n",
+ "- **`R_matrix`**: Coil resistance array for the PF coils \\[Ω\\]. \n",
+ " Defined as a time-dependent waveform and restricted to the coils listed in `ctrl_coils`.\n",
+ "\n",
+ "- **`M_FF_matrix`**: Coil mutual inductance matrix used in feedforward (FF) terms \\[Vs/A\\]. \n",
+ " This accounts for inductive coupling between PF coils when computing FF voltage requests.\n",
+ "\n",
+ "- **`M_FB_matrix`**: Coil mutual inductance matrix used in feedback (FB) terms \\[Vs/A\\].\n",
+ " This accounts for inductive coupling between PF coils when computing FF voltage requests. Often it is identical to `M_FF_matrix`, but is provided separately for flexibility.\n",
+ "\n",
+ "- **`coil_gains`**: Feedback gain applied to each PF coil \\[s\\]. \n",
+ " These gains scale the feedback voltage contributions on a per-coil basis.\n",
+ "\n",
+ "- **`coil_voltage_lims`**: Maximum allowable voltages for each PF coil \\[V\\]. \n",
+ " These limits are enforced by the PCS when generating coil voltage commands.\n",
+ "\n",
+ "- **`coil_voltage_slew_lims`**: Maximum allowable voltage ramp rates for each PF coil \\[V·s⁻¹\\]. \n",
+ " In this example, very large values are used, effectively disabling slew-rate limiting.\n",
+ "\n",
+ "All quantities are defined as time-dependent waveforms, even when constant, to allow future extension to time-varying PF system parameters.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# PF CATEGORY\n",
+ "\n",
+ "# build the dictionary which will be passed to the virtual PCS (all keys below are required)\n",
+ "pf_data = {}\n",
+ "\n",
+ "# coil resistances\n",
+ "pf_data[\"R_matrix\"] = {\n",
+ " 'times': [0.0],\n",
+ " 'vals': [tokamak.coil_resist[0:11]],\n",
+ " }\n",
+ "\n",
+ "# coil mutual inductances (on feedforward terms)\n",
+ "pf_data[\"M_FF_matrix\"] = {\n",
+ " 'times': [0.0],\n",
+ " 'vals': [tokamak.coil_self_ind[0:11,0:11]],\n",
+ " }\n",
+ "\n",
+ "# coil mutual inductances (on feedback terms)\n",
+ "pf_data[\"M_FB_matrix\"] = {\n",
+ " 'times': [0.0],\n",
+ " 'vals': [tokamak.coil_self_ind[0:11,0:11]],\n",
+ " }\n",
+ "\n",
+ "\n",
+ "# gains on the coils (for feedback term)\n",
+ "pf_data[\"coil_gains\"] = {\n",
+ " 'times': [0.0],\n",
+ " 'vals': [0.015*np.ones(len(ctrl_coils))],\n",
+ " }\n",
+ "\n",
+ "# limits on voltages in the coils\n",
+ "pf_data[\"coil_voltage_lims\"] = {\n",
+ " \"times\": np.array([-2.6]),\n",
+ " 'vals': [[2000, 750, 750, 750, 750, 750, 750, 750, 750, 750, 750]],\n",
+ " } \n",
+ "\n",
+ "# limits on voltage ramp rates in the coils (we set no limits here)\n",
+ "pf_data[\"coil_voltage_slew_lims\"] = {\n",
+ " 'times': [-2.6],\n",
+ " 'vals': [1e12*np.ones(len(ctrl_coils))],\n",
+ " }\n",
+ "\n",
+ "pf_data.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we define the inputs for the **vertical category**.\n",
+ "\n",
+ "This category configures the feedback controller responsible for stabilising and controlling the vertical position of the plasma column using the vertical control coil (here the `P6` coil). \n",
+ "\n",
+ "The following quantities must be specified:\n",
+ "\n",
+ "- **`z_ref`**: Plasma vertical position reference waveform \\[m\\]. \n",
+ " This defines the desired vertical position of the plasma as a function of time.\n",
+ "\n",
+ "- **`k_prop`**: Proportional gain waveform for the vertical position feedback controller \\[s⁻¹\\]. \n",
+ " This gain determines the strength of the restoring force in response to vertical displacement.\n",
+ "\n",
+ "- **`k_deriv`**: Derivative gain waveform for the vertical position feedback controller \\[s⁻²\\].\n",
+ " This term provides damping by responding to the rate of change of the vertical position.\n",
+ "\n",
+ "All inputs are defined as time-dependent waveforms, allowing for time-varying vertical control behaviour if required.\n",
+ "\n",
+ "In this example, we steadily increase the vertical position to before holding it constant thereafter. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# VERTICAL CATEGORY DATA\n",
+ "\n",
+ "# build the dictionary which will be passed to the virtual PCS (all keys below are required)\n",
+ "vertical_data = {}\n",
+ "\n",
+ "# plasma vertical position reference\n",
+ "vertical_data[\"z_ref\"] = {\"times\": np.array([tmin, tmin+0.25, tmax]), \"vals\": np.array([0.0, 0.01, 0.01])}\n",
+ "\n",
+ "# proportional gain\n",
+ "vertical_data[\"k_prop\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([0.025, 0.025])}\n",
+ "\n",
+ "# derivative gain\n",
+ "vertical_data[\"k_deriv\"] = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([5e-7, 5e-7])}\n",
+ "\n",
+ "vertical_data.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally, we define the inputs for the **coil activations category**.\n",
+ "\n",
+ "This category specifies when each active poloidal field (PF) coil is enabled or disabled during the simulation. Coil activations are represented as time-dependent waveforms and are applied on a per-coil basis.\n",
+ "\n",
+ "For each coil listed in `active_coils`, the following entry is defined:\n",
+ "\n",
+ "- **`_activation`**: Coil activation waveform. \n",
+ " A value of:\n",
+ " - `1.0` indicates that the coil is active and available to the PCS.\n",
+ " - `0.0` indicates that the coil is inactive and excluded from PCS.\n",
+ "\n",
+ "In this example, all coils are activated for the entire simulation interval $[t_\\mathrm{min},\\, t_\\mathrm{max}]$, using constant activation waveforms. This structure allows coils to be selectively enabled or disabled at different times in more advanced scenarios.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# COIL ACTIVATIONS CATEGORY\n",
+ "\n",
+ "# build the dictionary which will be passed to the virtual PCS (all keys below are required)\n",
+ "coil_activation_data = {}\n",
+ "\n",
+ "# coil activation times\n",
+ "default_coil_dict = {\"times\": np.array([tmin, tmax]), \"vals\": np.array([1.0, 1.0])}\n",
+ "for name in active_coils:\n",
+ " coil_activation_data[name+\"_activation\"] = default_coil_dict\n",
+ "\n",
+ "coil_activation_data.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Initialise the virtual PCS class\n",
+ "\n",
+ "Now that the inputs have been prepared we can initialise the class and use in-built functions to view or plot the waveform data in each control category. This will build internal classes for each of the above listed controllers. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# FIRE UP THE PLASMA CONTROL SYSTEM\n",
+ "from freegsnke.control_loop.pcs import PlasmaControlSystem\n",
+ "\n",
+ "PCS = PlasmaControlSystem(\n",
+ " plasma_data=plasma_data,\n",
+ " shape_data=shape_data,\n",
+ " circuits_data=circuits_data,\n",
+ " systems_data=systems_data,\n",
+ " pf_data=pf_data,\n",
+ " vertical_data=vertical_data,\n",
+ " coil_activation_data=coil_activation_data,\n",
+ " active_coils=active_coils,\n",
+ " ctrl_coils=ctrl_coils,\n",
+ " solenoid_coils=solenoid_coils,\n",
+ " vertical_coils=vertical_coils,\n",
+ " ctrl_targets=ctrl_targets,\n",
+ " plasma_target=plasma_targets,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# each controller should contain a copy of the waveform dictionaries internally\n",
+ "PCS.PlasmaController.data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# it will also contain a dictionary of functions that has linearly (or previous value) interpolated the various waveforms\n",
+ "PCS.PlasmaController.interpolants"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# # it should also be possible to plot the data (uncomment following cell)\n",
+ "# # green shading indiciates times that FB control is ON, yellow that FF is ON, a mix that both are ON, and white that there is no control\n",
+ "# PCS.PlasmaController.plot_data(tmin=tmin, tmax=tmax)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# the \"run_control\" method will be used internally later on but can also be called explicitly if required\n",
+ "PCS.PlasmaController.run_control"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally, suppose you wish to modify the shot setup data after initialising the PCS class. \n",
+ "\n",
+ "To do this, simply edit the waveforms required directly in the `.data` attribute of the controller of interest and then call `.update_interpolants()`. This has to be done to ensure that the controller refreshes any stale interpolants from a prior initialisation.\n",
+ "\n",
+ "Uncomment the code below to see how it works for an example in the `ShapeController`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# # update the data entry with your new waveform\n",
+ "# new_waveform = {\"times\": np.array([tmin, 0.05, 0.15, 0.35, 0.45, tmax]), \"vals\": np.array([0.28, 0.28, 0.30, 0.30, 0.28, 0.28])} # we vary Rin\n",
+ "# PCS.ShapeController.data[\"Rin\"][\"ref\"] = new_waveform\n",
+ "\n",
+ "# # call the update function to refresh inteprolants (essential)\n",
+ "# PCS.ShapeController.update_interpolants()\n",
+ "\n",
+ "# # plot to see new waveform\n",
+ "# PCS.ShapeController.plot_data(targ=\"Rin\", tmin=tmin, tmax=tmax)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Initialise the nonlinear solver object\n",
+ "Now choose the desired settings for the nonlinear solver object - recall the prior example notebook on this. \n",
+ "\n",
+ "The simulation timestep will be modified later on but be sure to choose a value that is 5-10x smaller than the vertical instability timescale returned in the output of this cell - this keeps the simulation numerically stable. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from freegsnke import nonlinear_solve\n",
+ "\n",
+ "stepping = nonlinear_solve.nl_solver(\n",
+ " eq=eq, \n",
+ " profiles=profiles, \n",
+ " GSStaticSolver=GSStaticSolver,\n",
+ " full_timestep=5e-4, \n",
+ " plasma_resistivity=1e-7,\n",
+ " fix_n_vessel_modes=30, \n",
+ " plasma_descriptor_function=plasma_descriptors,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Define FPDT simulation parameters\n",
+ "\n",
+ "Next set the key simulation parameters for the FPDT."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# FPDT SETUP\n",
+ "\n",
+ "# number of simulation time steps\n",
+ "n = 500\n",
+ "\n",
+ "# simulation time step size (must be smaller than the vertical instability timescale)\n",
+ "dt = 1e-3\n",
+ "\n",
+ "# PCS time step (this must be whole fraction of the simulation time step)\n",
+ "# it governs the frequency at which the PCS is called\n",
+ "dt_PCS = 1e-4\n",
+ "\n",
+ "# starting time (leave as it is)\n",
+ "tmin = 0.0\n",
+ "\n",
+ "# automatically calculates time array and sets time step in solver object\n",
+ "stepping.dt_step = dt\n",
+ "t_end = tmin + n*dt\n",
+ "times = np.arange(tmin, t_end, dt)\n",
+ "\n",
+ "# (re-)initialise the dynamic solver with the initial eq and profiles\n",
+ "stepping.initialize_from_ICs(eq, profiles)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before moving forward, we note that there are two time-dependent quantities that have not been set explicitly here. In this example, we hold the:\n",
+ "\n",
+ "- plasma resistivity constant.\n",
+ "- plasma current density profiles constant.\n",
+ "\n",
+ "For modelling real plasma discharges, both of these quantities should passed as time-dependent inputs to the time-stepping loop below. For further details, see the earlier example notebooks where time-dependent resistivity and profile evolution are demonstrated.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### FPDT Simulation\n",
+ "\n",
+ "In the following cell, we initialise lists/arrays to store any equilbirium related data we wish to view after the simulation. See the FreeGSNKE example notebook (example03 - extracting_equilibrium_quantites) for a list of these and how to extract them. \n",
+ "\n",
+ "Note that some quantities can be computationally costly to extract and have not been optimised for speed yet!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# equilibrium-related data storage\n",
+ "dynamic_psi = np.zeros((stepping.eq1.psi().shape[0], stepping.eq1.psi().shape[1], len(times))) # total poloidal flux\n",
+ "dynamic_limiter_flag = np.zeros(len(times)) # flag if plasma is limited\n",
+ "dynamic_psi_boundary = np.zeros(len(times)) # poloidal flux on plasma boundary\n",
+ "dynamic_xpts = [] # list of X-point locations and associated poloidal flux\n",
+ "dynamic_opts = [] # list of O-point locations and associated poloidal flux\n",
+ "dynamic_currents = np.zeros((len(stepping.currents_vec[:-1]),len(times))) # PF coil and vessel eigenmode currents \n",
+ "dynamic_ip = np.zeros(len(times)) # total plasma current\n",
+ "dynamic_shape_targets = np.zeros((len(ctrl_targets),len(times))) # shape parameters\n",
+ "dynamic_timings = np.zeros(len(times)) # solver runtime at each time step\n",
+ "dynamic_triangularity = np.zeros(len(times)) # plasma triangularity\n",
+ "\n",
+ "# PCS-related data storage\n",
+ "V_approved = np.zeros((len(ctrl_coils)+1,len(times))) # final PF coil voltages from the PCS class (passed into solver)\n",
+ "ip_hist = np.zeros(len(times)) # integral term feeding into Plasma Category PID FB controller (at prior time step)\n",
+ "T_err = np.zeros((len(ctrl_targets),len(times))) # proportional term feeding into Shape Category PID FB controller (at prior time step)\n",
+ "T_hist = np.zeros((len(ctrl_targets),len(times))) # integral term feeding into Shape Category PID FB controller (at prior time step)\n",
+ "I_approved = np.zeros((len(ctrl_coils),len(times))) # approved PF coil currents from prior time step feeding into System Category\n",
+ "coil_resists = np.zeros((len(active_coils),len(times))) # PF coil resistances (to tell solver if a coil is switched on or off)\n",
+ "z_current = [] # vertical position of the plasma (average jtor position)\n",
+ "dynamic_jtor_norm = [] # norm change in jtor between time steps (used to trigger relinearisation)\n",
+ "threshold = None # relative jtor threshold (since last linearisation) above which relinearisation is triggered (here not used)\n",
+ "\n",
+ "# extract any initial values from the initial equilibrium\n",
+ "dynamic_psi[:,:,0] = stepping.eq1.psi()\n",
+ "dynamic_limiter_flag[0] = stepping.eq1._profiles.flag_limiter\n",
+ "dynamic_psi_boundary[0] = stepping.eq1._profiles.psi_bndry\n",
+ "dynamic_xpts.append(stepping.eq1.xpt)\n",
+ "dynamic_opts.append(stepping.eq1.opt)\n",
+ "dynamic_currents[:,0] = stepping.currents_vec[:-1].copy()\n",
+ "dynamic_ip[0] = stepping.profiles1.Ip\n",
+ "dynamic_shape_targets[:,0] = plasma_descriptors(eq=stepping.eq1)\n",
+ "dynamic_triangularity[0] = eq.triangularity()\n",
+ "z_current.append(stepping.eq1.Zcurrent())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before running the following cell (which takes a few minutes), be sure to familiarise yourself with all of the steps.\n",
+ "\n",
+ "During each time step:\n",
+ "\n",
+ "1. The primary method in the `PCS` class, `calculate_ctrl_voltages`, is called. Using “measurements” from the current equilibrium (plasma current, coil currents, and shape parameters), it computes the voltages to be applied to the evolutive solver, as well as the coil resistances (to determine whether any coils are switched off). These voltages enact the plasma control.\n",
+ "\n",
+ "2. The plasma current density profile parameters are placed into a dictionary for use by the evolutive solver. \n",
+ " *Note:* they are constant here, but they can be made time-dependent at this stage.\n",
+ "\n",
+ "3. The evolutive solver is then called using these parameters (along with additional ones). At this stage, several choices can be made:\n",
+ " - Specify a time-dependent plasma resistivity (usually required for resimulation of prior discharges).\n",
+ " - Select, via `linear_only`, either a linear or fully nonlinear solution of the circuit, plasma, and Grad–Shafranov equations.\n",
+ " - Set a relinearisation threshold when using linear mode (see example notebook 05c).\n",
+ " - Choose not to solve the Grad–Shafranov equation fully (whenusing linear mode), but instead solve only for the shape parameters specified in `plasma_descriptors` (see example notebooks 05b and 05c).\n",
+ "\n",
+ "4. The required data are stored following successful completion of the time step.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# RUN FPDT SIMULATION\n",
+ "for i, t in enumerate(times[0:-1]):\n",
+ " print(\"-----\")\n",
+ " print(f\"t = {np.round(t,5)}s (step {i+1}/{n-1})\")\n",
+ " \n",
+ " # start timer\n",
+ " start_time = time.time()\n",
+ "\n",
+ " # initialise any historical quantities for PCS PID controllers\n",
+ " if i == 0:\n",
+ " ip_hist_prev = 0.0\n",
+ " T_err_prev = np.zeros(len(ctrl_targets))\n",
+ " T_hist_prev = np.zeros(len(ctrl_targets))\n",
+ " I_approved_prev = eq.tokamak.getCurrentsVec()[0:11] # must use starting currents here\n",
+ " V_approved_prev = np.zeros(len(ctrl_coils))\n",
+ " zipv_meas = 0.0\n",
+ " else:\n",
+ " ip_hist_prev=ip_hist[i-1].copy() # integral term from FB PID Plasma controller at prior step\n",
+ " T_err_prev=T_err[:,i-1].copy() # proportional term from FB PID Shape controller at prior step\n",
+ " T_hist_prev=T_hist[:,i-1].copy() # integral term from FB PID Shape controller at prior step\n",
+ " I_approved_prev=I_approved[:,i-1].copy() # approved PF coil currents from prior step (ctrl_coils only)\n",
+ " V_approved_prev=V_approved[0:-1,i-1].copy() # approved PF coil voltages from prior step (ctrl_coils only)\n",
+ " zipv_meas = ((z_current[-1]-z_current[-2])/dt)*dynamic_ip[i] # rate of change of vertical plasma position \n",
+ " \n",
+ " # call the PCS class to attain PF coil voltages (on ctrl_coils and vertical coil)\n",
+ " V_approved[:,i], ip_hist[i], T_err[:,i], T_hist[:,i], I_approved[:,i], coil_resists[:,i] = PCS.calculate_ctrl_voltages(\n",
+ " t=t, # current simulation time\n",
+ " dt=dt_PCS, # PCS time step\n",
+ " dt_simulator=dt, # simulator (solver) time step\n",
+ " ip_meas=dynamic_ip[i], # measured plasma current\n",
+ " ip_hist_prev=ip_hist_prev, # plasma controller integral term history\n",
+ " T_meas=dynamic_shape_targets[:,i].copy(), # measured shape parameters\n",
+ " T_err_prev=T_err_prev, # shape controller proportional term history\n",
+ " T_hist_prev=T_hist_prev, # shape controller integral term history\n",
+ " I_approved_prev=I_approved_prev, # approved PF currents from prior timestep\n",
+ " I_meas=dynamic_currents[0:11,i].copy(), # measured PF coil currents\n",
+ " V_approved_prev=V_approved_prev, # approved PF voltages from prior timestep\n",
+ " zip_meas=z_current[i]*dynamic_ip[i], # measured vertical position x measured plasma current\n",
+ " zipv_meas=zipv_meas, # derivative of above\n",
+ " active_coil_resists=tokamak.coil_resist[0:12].copy(), # PF coil resistances (these are constant)\n",
+ " verbose=False, # print some output?\n",
+ " )\n",
+ "\n",
+ " # extract plasma current density profile parameters (these are constant but can be time-dependent if desired)\n",
+ " profile_params = {\n",
+ " \"alpha\": profiles.alpha[0:2],\n",
+ " \"beta\": profiles.beta[0:2],\n",
+ " }\n",
+ "\n",
+ " # run freegsnke over the time step with the calculated voltages\n",
+ " stepping.nlstepper(\n",
+ " plasma_resistivity=1e-7, # assign plasma resistivity (chosen to be constant here)\n",
+ " active_voltage_vec=V_approved[:,i], # assign approved PF coil voltages from the PCS\n",
+ " profiles_parameters=profile_params, # assign profile parameters\n",
+ " custom_active_coil_resistances=coil_resists[:,i], # assign active coil resistances from PCS (tells solver if any coils are switched off)\n",
+ " linear_only=True, # linear or nonlinear solve?\n",
+ " target_relative_tol_currents=1e-2, # relative tolerance in the currents required for convergence\n",
+ " target_relative_tol_GS=1e-2, # relative tolerance in the plasma flux required for convergence\n",
+ " working_relative_tol_GS=(1e-2)/2, # tolerance used when solving GS equation, expressed in terms of the change in the plasma flux due to one timestep of evolution (must be smaller tolerance above)\n",
+ " verbose=False, # print some output?\n",
+ " relinearise_threshold=threshold, # if the relative jtor norm change from last linearisation is above threshold, relinearise around current equilibrium\n",
+ " no_GS=False, # do not solve GS at each time?\n",
+ " )\n",
+ "\n",
+ " # stop timer\n",
+ " end_time = time.time()\n",
+ "\n",
+ " # extract and store relevant data\n",
+ " dynamic_psi[:,:,i+1] = stepping.eq1.psi()\n",
+ " dynamic_limiter_flag[i+1] = stepping.eq1._profiles.flag_limiter\n",
+ " dynamic_psi_boundary[i+1] = stepping.eq1._profiles.psi_bndry\n",
+ " dynamic_xpts.append(stepping.eq1.xpt)\n",
+ " dynamic_opts.append(stepping.eq1.opt)\n",
+ " dynamic_currents[:,i+1] = stepping.currents_vec[:-1].copy()\n",
+ " dynamic_ip[i+1] = stepping.currents_vec[-1] * stepping.plasma_norm_factor\n",
+ " dynamic_shape_targets[:,i+1] = plasma_descriptors(stepping.eq1)\n",
+ " dynamic_timings[i] = end_time - start_time\n",
+ " dynamic_triangularity[i] = stepping.eq1.triangularity()\n",
+ " z_current.append(stepping.eq1.Zcurrent())\n",
+ " dynamic_jtor_norm.append(stepping.relinearise_criteria)\n",
+ "\n",
+ " # print some stuff to track solve\n",
+ " print(f\" Ip = {np.round(dynamic_ip[i+1]/1000,1)} [kA]\")\n",
+ " print(f\" Shape parameters {ctrl_targets} = {np.round(dynamic_shape_targets[:,i+1],2)}\")\n",
+ " print(f\" Z position = {np.round(z_current[i+1],5)} [m]\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Plot some results\n",
+ "\n",
+ "Let's now plot some of the output. We can compare the evolution of the controlled parameters (plasma current, shape parameters, and vertical position) against their FB reference waveforms. In these plots:\n",
+ "- the black dashed indicated the FB reference waveform. \n",
+ "- the solid blue the simulated quantity.\n",
+ "- green background shading indicates when FB control is ON. \n",
+ "- yellow background shading indicates when FF control is ON (not used). \n",
+ "- white background shading indicates no control is used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# PLASMA CURRENT EVOLUTION\n",
+ "\n",
+ "fig, ax = plt.subplots(\n",
+ " nrows=1,\n",
+ " ncols=1,\n",
+ " figsize=(8, 4),\n",
+ " dpi=80\n",
+ ")\n",
+ "\n",
+ "# --- references and masks ---\n",
+ "FB_reference = PCS.PlasmaController.interpolants['ip_ref'](times)\n",
+ "FF_reference = PCS.PlasmaController.interpolants['vloop_ff'](times)\n",
+ "\n",
+ "blend = PCS.PlasmaController.interpolants['ip_blend'](times)\n",
+ "\n",
+ "FB_mask = (blend > 0) & (np.abs(FB_reference) > 0)\n",
+ "FF_mask = (blend < 1) & (np.abs(FF_reference) > 0)\n",
+ "\n",
+ "# --- shade FB regions (green) ---\n",
+ "on_regions = np.where(np.diff(FB_mask.astype(int)) != 0)[0] + 1\n",
+ "for seg_t, seg_state in zip(np.split(times, on_regions),\n",
+ " np.split(FB_mask, on_regions)):\n",
+ " if np.all(seg_state):\n",
+ " ax.axvspan(seg_t[0], seg_t[-1],\n",
+ " color='green', alpha=0.25,\n",
+ " label=\"FB active\")\n",
+ "\n",
+ "# --- shade FF regions (yellow) ---\n",
+ "on_regions = np.where(np.diff(FF_mask.astype(int)) != 0)[0] + 1\n",
+ "for seg_t, seg_state in zip(np.split(times, on_regions),\n",
+ " np.split(FF_mask, on_regions)):\n",
+ " if np.all(seg_state):\n",
+ " ax.axvspan(seg_t[0], seg_t[-1],\n",
+ " color='gold', alpha=0.25,\n",
+ " label=\"FF active\")\n",
+ "\n",
+ "# --- FreeGSNKE ---\n",
+ "ax.plot(times, dynamic_ip,\n",
+ " color='navy', linewidth=1,\n",
+ " marker='x', markersize=0,\n",
+ " label=\"FreeGSNKE\")\n",
+ "\n",
+ "# --- FB reference ---\n",
+ "ax.plot(times[FB_mask], FB_reference[FB_mask],\n",
+ " color='k', linestyle='--',\n",
+ " linewidth=1.5,\n",
+ " label=\"FB reference\")\n",
+ "\n",
+ "ax.set_xlabel(r\"Shot time [$s$]\")\n",
+ "ax.set_ylabel(\"Plasma current [A]\")\n",
+ "ax.grid()\n",
+ "\n",
+ "# deduplicate legend entries\n",
+ "handles, labels = ax.get_legend_handles_labels()\n",
+ "ax.legend(dict(zip(labels, handles)).values(),\n",
+ " dict(zip(labels, handles)).keys())\n",
+ "\n",
+ "ax.set_ylim([7.54e5, 7.64e5])\n",
+ "fig.suptitle(\"Plasma Current\", y=0.98)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# PLASMA VERTICAL POSITION EVOLUTION\n",
+ "\n",
+ "fig, ax = plt.subplots(\n",
+ " nrows=1,\n",
+ " ncols=1,\n",
+ " figsize=(8, 4),\n",
+ " dpi=80\n",
+ ")\n",
+ "\n",
+ "# --- FB reference ---\n",
+ "FB_reference = PCS.VerticalController.interpolants['z_ref'](times)\n",
+ "\n",
+ "# --- shade FB-active region (always on here) ---\n",
+ "ax.axvspan(times[0], times[-1],\n",
+ " color='green', alpha=0.2,\n",
+ " label=\"FB active\")\n",
+ "\n",
+ "# --- references ---\n",
+ "ax.plot(times, FB_reference,\n",
+ " color='k', linestyle='--',\n",
+ " linewidth=1.5,\n",
+ " label=\"FB reference\")\n",
+ "\n",
+ "# --- FreeGSNKE ---\n",
+ "ax.plot(times, z_current,\n",
+ " color='navy', linewidth=1,\n",
+ " marker='x', markersize=0,\n",
+ " label=\"FreeGSNKE\")\n",
+ "\n",
+ "ax.set_xlabel(r\"Shot time [$s$]\")\n",
+ "ax.set_ylabel(\"Vertical position [m]\")\n",
+ "ax.grid()\n",
+ "ax.legend()\n",
+ "\n",
+ "fig.suptitle(\"Vertical Position\", y=0.98)\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# SHAPE PARAMETER EVOLUTION\n",
+ "# the ramp down in Rx here is designed to show how one might increase triangularity of the plasma (see next plot for this effect)\n",
+ "\n",
+ "ntarg = len(ctrl_targets)\n",
+ "\n",
+ "fig, axes = plt.subplots(\n",
+ " nrows=ntarg,\n",
+ " ncols=1,\n",
+ " figsize=(8, 12),\n",
+ " dpi=80,\n",
+ " sharex=True\n",
+ ")\n",
+ "\n",
+ "for j, targ in enumerate(ctrl_targets):\n",
+ " ax = axes[j]\n",
+ "\n",
+ " # --- references and masks ---\n",
+ " FF_reference = PCS.ShapeController.interpolants[targ]['ff'](times)\n",
+ " FB_reference = PCS.ShapeController.interpolants[targ]['ref'](times)\n",
+ "\n",
+ " FF_mask = (\n",
+ " (PCS.ShapeController.interpolants[targ]['blend'](times) < 1)\n",
+ " & (np.abs(PCS.ShapeController.interpolants[targ]['ff'].derivative()(times)) > 0)\n",
+ " )\n",
+ "\n",
+ " FB_mask = (\n",
+ " (PCS.ShapeController.interpolants[targ]['blend'](times) > 0)\n",
+ " & (np.abs(FB_reference) > 0)\n",
+ " )\n",
+ "\n",
+ " # --- shade FB regions (green) ---\n",
+ " on_regions = np.where(np.diff(FB_mask.astype(int)) != 0)[0] + 1\n",
+ " for seg_t, seg_state in zip(np.split(times, on_regions),\n",
+ " np.split(FB_mask, on_regions)):\n",
+ " if np.all(seg_state):\n",
+ " ax.axvspan(seg_t[0], seg_t[-1], color='green', alpha=0.25,\n",
+ " label=\"FB active\" if j == 0 else None)\n",
+ "\n",
+ " # --- shade FF regions (yellow) ---\n",
+ " on_regions = np.where(np.diff(FF_mask.astype(int)) != 0)[0] + 1\n",
+ " for seg_t, seg_state in zip(np.split(times, on_regions),\n",
+ " np.split(FF_mask, on_regions)):\n",
+ " if np.all(seg_state):\n",
+ " ax.axvspan(seg_t[0], seg_t[-1], color='gold', alpha=0.25,\n",
+ " label=\"FF active\" if j == 0 else None)\n",
+ "\n",
+ " # --- references ---\n",
+ " ax.plot(times[FB_mask], FB_reference[FB_mask],\n",
+ " color='k', linestyle='--', linewidth=1.5,\n",
+ " label=\"FB reference\" if j == 0 else None)\n",
+ "\n",
+ " if np.any(FF_mask):\n",
+ " offset = interpolants[targ + '_meas'](times[FF_mask][0])\n",
+ " ax.plot(times[FF_mask], FF_reference[FF_mask] + offset,\n",
+ " color='r', linestyle='--', linewidth=1.5,\n",
+ " label=\"FF reference\" if j == 0 else None)\n",
+ "\n",
+ " # --- FreeGSNKE ---\n",
+ " ax.plot(times, dynamic_shape_targets[j, :],\n",
+ " color='navy', linewidth=1,\n",
+ " marker='x', markersize=0,\n",
+ " label=\"FreeGSNKE\" if j == 0 else None)\n",
+ "\n",
+ " ax.set_ylabel(f\"{ctrl_targets[j]} [m]\")\n",
+ " ax.grid()\n",
+ " ax.set_ylim([np.min(FB_reference)-0.02, np.max(FB_reference)+0.02])\n",
+ "\n",
+ "# shared x label\n",
+ "axes[-1].set_xlabel(r\"Shot time [$s$]\")\n",
+ "\n",
+ "# legend once\n",
+ "axes[0].legend(ncol=4, loc=\"upper right\")\n",
+ "\n",
+ "fig.suptitle(\"Shape parameters\", y=0.995, fontsize=14)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# PLASMA TRIANGULARITY EVOLUTION \n",
+ "\n",
+ "fig, ax = plt.subplots(\n",
+ " nrows=1,\n",
+ " ncols=1,\n",
+ " figsize=(8, 4),\n",
+ " dpi=80\n",
+ ")\n",
+ "\n",
+ "# --- FreeGSNKE ---\n",
+ "ax.plot(times[0:-1], dynamic_triangularity[0:-1],\n",
+ " color='navy', linewidth=1,\n",
+ " marker='x', markersize=0,\n",
+ " label=\"FreeGSNKE\")\n",
+ "\n",
+ "ax.set_xlabel(r\"Shot time [$s$]\")\n",
+ "ax.set_ylabel(\"Triangularity\")\n",
+ "ax.grid()\n",
+ "ax.legend()\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also show the voltages produced by the PCS and the subsequent evolution of the currents for each of the PF coils. To display the voltage/current limits, uncomment the relevant sections in the next cell."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# PF COIL VOLTAGES AND CURRENTS\n",
+ "\n",
+ "ncoils = len(active_coils)\n",
+ "\n",
+ "fig, axes = plt.subplots(\n",
+ " nrows=ncoils,\n",
+ " ncols=2,\n",
+ " figsize=(18, 45),\n",
+ " dpi=80,\n",
+ " sharex=True\n",
+ ")\n",
+ "\n",
+ "for j, coil in enumerate(active_coils):\n",
+ " axV = axes[j, 0] # voltage column\n",
+ " axI = axes[j, 1] # current column\n",
+ "\n",
+ " # VOLTAGES (left)\n",
+ "\n",
+ " # # voltage limits\n",
+ " # if coil not in [\"P6\"]:\n",
+ " # vlim = PCS.PFController.data['coil_voltage_lims']['vals'][0][j]\n",
+ " # axV.hlines([-vlim, vlim], times[0], times[-1],\n",
+ " # colors='k', linestyles='--', linewidth=1.2,\n",
+ " # label=\"Coil limits\" if j == 0 else None)\n",
+ "\n",
+ " axV.plot(times[0:-1], V_approved[j, 0:-1],\n",
+ " color='navy', linewidth=1,\n",
+ " marker='x', markersize=0,\n",
+ " label=\"FreeGSNKE\")\n",
+ "\n",
+ " axV.set_ylabel(f'{coil} voltage [V]')\n",
+ " axV.grid()\n",
+ " if j == 0:\n",
+ " axV.legend()\n",
+ "\n",
+ " # CURRENTS (right)\n",
+ "\n",
+ " # # current limits\n",
+ " # if coil not in [\"P6\"]:\n",
+ " # imin = PCS.SystemsController.data['min_coil_curr_lims']['vals'][0][j]\n",
+ " # imax = PCS.SystemsController.data['max_coil_curr_lims']['vals'][0][j]\n",
+ " # axI.hlines([imin, imax], times[0], times[-1],\n",
+ " # colors='k', linestyles='--', linewidth=1.2,\n",
+ " # label=\"Coil limits\" if j == 0 else None)\n",
+ "\n",
+ " axI.plot(times, dynamic_currents[j, :],\n",
+ " color='navy', linewidth=1,\n",
+ " marker='x', markersize=0,\n",
+ " label=\"FreeGSNKE\")\n",
+ "\n",
+ " axI.set_ylabel(f'{coil} current [A]')\n",
+ " axI.grid()\n",
+ " if j == 0:\n",
+ " axI.legend()\n",
+ "\n",
+ "# x-labels only on bottom row\n",
+ "axes[-1, 0].set_xlabel(r'Shot time [$s$]')\n",
+ "axes[-1, 1].set_xlabel(r'Shot time [$s$]')\n",
+ "\n",
+ "# column titles\n",
+ "axes[0, 0].set_title(\"PF Coil Voltages\")\n",
+ "axes[0, 1].set_title(\"PF Coil Currents\")\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following cell can be uncommented to make an animation of the equilibrium in the machine over time. Might take a few seconds to run and will save the output .mp4 in the `data` directory. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# # SAVE AN ANIMATION OF THE SIMULATION\n",
+ "# %matplotlib inline\n",
+ "\n",
+ "# import matplotlib.animation as animation\n",
+ "\n",
+ "# plt.rcParams['figure.dpi'] = 90\n",
+ "\n",
+ "# fig1, ax1 = plt.subplots(1, 1, figsize=(8, 8))\n",
+ "\n",
+ "# # Plot static wall / tokamak background\n",
+ "# eq.tokamak.plot(axis=ax1, show=False)\n",
+ "# ax1.plot(tokamak.wall.R, tokamak.wall.Z, color='k', linewidth=1.2)\n",
+ "# ax1.grid(True, which='both', alpha=0.5)\n",
+ "# ax1.set_aspect('equal')\n",
+ "# ax1.set_xlabel(r'Major radius, $R$ $[m]$')\n",
+ "# ax1.set_ylabel(r'Height, $Z$ $[m]$')\n",
+ "# ax1.set_xlim(0.05, 2.15)\n",
+ "# ax1.set_ylim(-2.25, 2.25)\n",
+ "\n",
+ "# # Determine contour levels from entire dynamic_psi\n",
+ "# min_psi = np.min(dynamic_psi)\n",
+ "# max_psi = np.max(dynamic_psi)\n",
+ "# levels = np.linspace(min_psi, max_psi, 40)\n",
+ "\n",
+ "# # --- Storage for dynamic artists ---\n",
+ "# contour_artists = []\n",
+ "# scatter_artists = []\n",
+ "\n",
+ "# # --- Update function ---\n",
+ "# def update(i):\n",
+ " \n",
+ "# global contour_artists, scatter_artists\n",
+ "\n",
+ "# # Remove previous dynamic artists\n",
+ "# for c in contour_artists + scatter_artists:\n",
+ "# if isinstance(c, list):\n",
+ "# for coll in c:\n",
+ "# coll.remove()\n",
+ "# else:\n",
+ "# c.remove()\n",
+ "# contour_artists = []\n",
+ "# scatter_artists = []\n",
+ "\n",
+ "# ax1.set_title(rf\"$t$ = {np.round(times[i],3)}\")\n",
+ "\n",
+ "# # Main psi contours\n",
+ "# c1 = ax1.contour(eq.R, eq.Z, dynamic_psi[:,:,i], levels=levels, alpha=0.8, cmap='viridis')\n",
+ "# contour_artists.append(c1)\n",
+ "\n",
+ "# # plasma boundary\n",
+ "# c2 = ax1.contour(eq.R, eq.Z, dynamic_psi[:,:,i],\n",
+ "# levels=[dynamic_psi_boundary[i]],\n",
+ "# linestyles=\"-\", colors='r', linewidths=1.4)\n",
+ "# contour_artists.append(c2)\n",
+ "\n",
+ "# # adds separatrix of primary X-point if plasma limited\n",
+ "# if dynamic_limiter_flag[i]:\n",
+ "# c3 = ax1.contour(eq.R, eq.Z, dynamic_psi[:,:,i],\n",
+ "# levels=[dynamic_xpts[i][0,2]],\n",
+ "# linestyles=\"--\", colors='k', linewidths=1.4)\n",
+ "# contour_artists.append(c3)\n",
+ "\n",
+ "\n",
+ "# # X-points and O-points\n",
+ "# sc1 = ax1.scatter(dynamic_xpts[i][:,0], dynamic_xpts[i][:,1], color='r', marker='x', s=30)\n",
+ "# sc2 = ax1.scatter(dynamic_opts[i][:,0], dynamic_opts[i][:,1], color='g', marker='2', s=30)\n",
+ "\n",
+ "# # indicate which is the primary X-point more clearly\n",
+ "# sc3 = ax1.scatter(dynamic_xpts[i][0,0], dynamic_xpts[i][0,1], color='r', marker='x', s=60)\n",
+ "\n",
+ "# scatter_artists.extend([sc1, sc2, sc3])\n",
+ "\n",
+ "# return contour_artists + scatter_artists\n",
+ "\n",
+ "# # --- Animation ---\n",
+ "# num_frames = len(times)\n",
+ "# frames_to_plot = range(0, num_frames, 5)\n",
+ "# fps = int(len(frames_to_plot)/10)\n",
+ "# ani = animation.FuncAnimation(fig1, update, frames=frames_to_plot, interval=1, blit=True)\n",
+ "\n",
+ "# # save to video (much faster & smaller than GIF)\n",
+ "# ani.save(f\"data/animated_equilibrium.mp4\", writer=\"ffmpeg\", fps=fps)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "freegsnke_github",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/freegsnke/control_loop/__init__.py b/freegsnke/control_loop/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/freegsnke/control_loop/coil_activation_category.py b/freegsnke/control_loop/coil_activation_category.py
new file mode 100644
index 00000000..9709444e
--- /dev/null
+++ b/freegsnke/control_loop/coil_activation_category.py
@@ -0,0 +1,274 @@
+"""
+Module implementing time‑dependent coil activation scheduling for FreeGSNKE
+control loops.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from freegsnke.control_loop.useful_functions import (
+ check_data_entry,
+ interpolate_spline,
+ interpolate_step,
+)
+
+
+class CoilActivationController:
+ """
+ A controller class for managing time-dependent coil activation times.
+
+ Parameters
+ ----------
+ data : dict
+ A nested dictionary containing coil activation waveforms for the controller.
+ The required keys
+ for both spline-based and step-based waveforms are:
+ - Spline keys: "_activation"
+ - Step keys:
+ Each key should map to a waveform dictionary suitable for interpolation with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length).
+
+ active_coils : list of str
+ The list of active coils being used.
+
+ Attributes
+ ----------
+ active_coils : list of str
+ The list of active coils being used.
+
+ keys_to_spline : list of str
+ Keys corresponding to waveforms that will be interpolated using splines.
+
+ keys_to_step : list of str
+ Keys corresponding to waveforms that will be interpolated using step functions.
+
+ data : dict
+ Internal copy of the input control waveforms.
+
+ interpolants : dict
+ A nested dictionary storing interpolation functions of each input waveform.
+ Structure: {spline/step key: interpolant_function}
+ """
+
+ def __init__(
+ self,
+ data,
+ active_coils,
+ ):
+
+ # coils list
+ self.active_coils = active_coils
+
+ # check correct data is input and in correct format
+ self.keys_to_spline = []
+ self.keys_to_step = [coil + "_activation" for coil in self.active_coils]
+ for key in self.keys_to_spline + self.keys_to_step:
+ check_data_entry(
+ data=data, key=key, controller_name="CoilActivationController"
+ )
+
+ # create an internal copy of the data
+ self.data = data
+
+ # interpolate the input data
+ self.update_interpolants()
+
+ def update_interpolants(self):
+ """
+ Recompute all interpolant functions from the current `self.data`.
+
+ This method clears the existing `self.interpolants` dictionary and
+ rebuilds it by applying either `interpolate_spline` or `interpolate_step`
+ depending on whether each key belongs to `self.keys_to_spline` or
+ `self.keys_to_step`.
+
+ """
+
+ # create a dictionary to store the spline functions
+ self.interpolants = {}
+
+ # interpolate the input data
+ for key in self.data.keys():
+ self.interpolants[key] = {}
+ if key in self.keys_to_spline:
+ self.interpolants[key] = interpolate_spline(self.data[key])
+ elif key in self.keys_to_step:
+ self.interpolants[key] = interpolate_step(self.data[key])
+
+ def run_control(
+ self,
+ t,
+ dt,
+ active_coil_resists,
+ ):
+ """
+ Compute effective coil resistances at a given time step.
+
+ This function extracts coil activation values at time ``t`` and scales the
+ base resistances accordingly. Coils that are inactive (activation ~ 0)
+ are assigned a very large resistance to effectively disable them in the
+ control model.
+
+ Parameters
+ ----------
+ t : float
+ Current time at which to evaluate coil activations.
+ dt : float
+ Time step size (currently unused, but kept for interface consistency).
+ active_coil_resists : numpy.ndarray
+ Array of active coil resistances when coils are switched on [Ohms].
+
+ Returns
+ -------
+ numpy.ndarray
+ Array of effective coil resistances, where inactive coils are set to
+ a large resistance value (``1e12``).
+ """
+
+ # extract data
+ activations = self.extract_values(t=t, targets=self.active_coils, deriv=False)
+
+ # if coil is not active, set very large resistance
+ # final_coil_resists = active_coil_resists + (1.0 - activations) * 1e12
+ mask = activations.astype(bool)
+ final_coil_resists = active_coil_resists.copy()
+ final_coil_resists[~mask] = 1e12
+
+ return final_coil_resists
+
+ def extract_values(
+ self,
+ t,
+ targets,
+ deriv=False,
+ ):
+ """
+ Extracts interpolated values or their derivatives for specified shape targets at a given time.
+
+ This method queries the stored interpolation functions for each target and key, returning either
+ the interpolated value or its first derivative depending on the `deriv` flag.
+
+ Parameters
+ ----------
+ t : float
+ Time at which to evaluate the interpolants [s].
+ targets : list of str
+ List of keys. Each must correspond to a key in `self.interpolants`.
+ deriv : bool, optional
+ If True, returns the first derivative of the interpolant at time `t`. Default is False.
+
+ Returns
+ -------
+ np.ndarray
+ Array of interpolated values (or derivatives) for each target at time `t`.
+
+ Notes
+ -----
+ - Assumes that `self.interpolants[target]` is a valid `scipy.interpolate` object.
+ - If `deriv=True`, the method calls `.derivative()` on the interpolant before evaluation.
+ """
+
+ if deriv:
+ return np.array(
+ [
+ self.interpolants[target + "_activation"].derivative(n=1)(t)
+ for target in targets
+ ]
+ )
+ else:
+ return np.array(
+ [self.interpolants[target + "_activation"](t) for target in targets]
+ )
+
+ def plot_data(self, tmin=-1.0, tmax=1.0, nt=1001):
+ """
+ Visualizes interpolated control waveforms and corresponding raw inputs.
+
+ This method generates subplots for each control waveform (step types),
+ showing the interpolated time series alongside the original data points. It helps verify
+ the quality and behavior of the interpolation.
+
+ Parameters
+ ----------
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax] (default is 1001).
+
+ Notes
+ -----
+ - Each subplot corresponds to a control waveform.
+ - Interpolated curves are plotted in navy; raw data points are shown in red.
+ - Axis labels include units where applicable.
+ - Useful for debugging or validating the interpolation quality.
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_spline + self.keys_to_step) # number of plots
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ for ax, key in zip(axes, self.data.keys()):
+ ax.scatter(
+ self.data[key]["times"],
+ self.data[key]["vals"],
+ s=10,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+ ax.plot(
+ t,
+ self.interpolants[key](t),
+ color="navy",
+ linewidth=1.2,
+ label="interpolated",
+ )
+ ax.grid(True, linestyle="--", alpha=0.6)
+ ax.set_ylabel(key)
+
+ # y-scaling inside the window
+ times = np.array(self.data[key]["times"])
+ mask = (times >= tmin) & (times <= tmax)
+ if np.any(mask):
+ ydata = np.concatenate(
+ [self.interpolants[key](t), np.array(self.data[key]["vals"])[mask]]
+ )
+ ymin, ymax = np.min(ydata), np.max(ydata)
+ yrange = ymax - ymin
+ if yrange == 0:
+ yrange = 1.0
+ ax.set_ylim(ymin - 0.02 * yrange, ymax + 0.02 * yrange)
+
+ fig.suptitle("Coil activation schedule (0 = off, 1 = on)")
+ axes[0].legend(loc="best")
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
diff --git a/freegsnke/control_loop/pcs.py b/freegsnke/control_loop/pcs.py
new file mode 100644
index 00000000..7464dfc5
--- /dev/null
+++ b/freegsnke/control_loop/pcs.py
@@ -0,0 +1,384 @@
+"""
+Module to implement a Plasma Control System (PCS) in FreeGSNKE.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+# imports
+import numpy as np
+
+from .coil_activation_category import CoilActivationController
+from .pf_category import PFController
+from .plasma_category import PlasmaController
+from .shape_category import ShapeController
+from .systems_category import SystemsController
+from .vertical_category import VerticalController
+from .virtual_circuits_category import VirtualCircuitsController
+
+
+class PlasmaControlSystem:
+ """
+ A high-level class for managing multiple controllers in a plasma
+ control system.
+
+ This class integrates several subsystem controllers responsible for different aspects
+ of plasma control, including shape control, vertical stabilization, coil current
+ regulation, and virtual circuit modelling. It provides a unified interface for
+ coordinating these controllers based on time-dependent input waveforms.
+
+ Attributes
+ ----------
+ active_coils : list of str
+ List of all active coils used.
+
+ ctrl_coils : list of str
+ List of all active coils being used for shape control.
+
+ solenoid_coils : list of str
+ List of all active coils being used for plasma current control.
+
+ vertical_coils : list of str
+ List of all active coils being used for vertical control.
+
+ ctrl_targets : list of str
+ List of all shape control targets (e.g., X-points, strike points).
+
+ plasma_target : list of str
+ List of all plasma control targets.
+
+ shape_control_mode : str
+ Select which shape control algorithm to use (see shape_category.py).
+
+ PlasmaController : PlasmaController
+ Handles plasma current control.
+
+ ShapeController : ShapeController
+ Handles shape target control.
+
+ VirtualCircuitsController : VirtualCircuitsController
+ Computes unapproved coil currents in ctrl coils using virtual circuits.
+
+ SystemsController : SystemsController
+ Applies perturbations and enforces coil current and ramp rate limits to
+ find approved coil currents.
+
+ PFController : PFController
+ Transforms approved coil currents into ctrl coil voltages.
+
+ VerticalControl : VerticalController
+ Controls vertical plasma position via vertical control coil.
+
+ CoilActivationControl : CoilActivationController
+ Controls coil resistances, depending on if they're switched on or off.
+
+ vc_generator : object, optional
+ An optional class object for applying emulated virtual circuits. If not
+ provided, deafult waveform-defined VCs will be used.
+
+ vc_update_rate : float, optional
+ Optional argument to specify how ofte, in seconds, new VCs are computed with vc_generator.
+ If None provided, defaults to zero and new VC computed at every iterration.
+ """
+
+ def __init__(
+ self,
+ plasma_data,
+ shape_data,
+ circuits_data,
+ systems_data,
+ pf_data,
+ vertical_data,
+ coil_activation_data,
+ active_coils,
+ ctrl_coils,
+ solenoid_coils,
+ vertical_coils,
+ ctrl_targets,
+ plasma_target,
+ shape_control_mode=None,
+ vc_generator=None,
+ vc_update_rate=None,
+ ):
+
+ # coil ordering
+ self.active_coils = active_coils
+ self.ctrl_coils = ctrl_coils
+ self.solenoid_coils = solenoid_coils
+ self.vertical_coils = vertical_coils
+
+ # shape targets
+ self.ctrl_targets = ctrl_targets
+ self.plasma_target = plasma_target
+
+ # initialise controllers and assign data to each
+ self.PlasmaController = PlasmaController(
+ data=plasma_data,
+ )
+
+ self.ShapeController = ShapeController(
+ data=shape_data,
+ ctrl_targets=self.ctrl_targets,
+ mode=shape_control_mode,
+ )
+
+ self.VirtualCircuitsController = VirtualCircuitsController(
+ data=circuits_data,
+ ctrl_coils=self.ctrl_coils,
+ ctrl_targets=self.ctrl_targets,
+ plasma_target=self.plasma_target,
+ vc_generator=vc_generator,
+ vc_update_rate=vc_update_rate,
+ )
+
+ self.SystemsController = SystemsController(
+ data=systems_data,
+ ctrl_coils=self.ctrl_coils,
+ )
+
+ self.PFController = PFController(
+ data=pf_data,
+ )
+
+ self.VerticalController = VerticalController(
+ data=vertical_data,
+ )
+
+ self.CoilActivationController = CoilActivationController(
+ data=coil_activation_data,
+ active_coils=self.active_coils,
+ )
+
+ def calculate_ctrl_voltages(
+ self,
+ t,
+ dt,
+ ip_meas,
+ ip_hist_prev,
+ T_meas,
+ T_err_prev,
+ T_hist_prev,
+ I_approved_prev,
+ I_meas,
+ V_approved_prev,
+ zip_meas,
+ zipv_meas,
+ active_coil_resists,
+ dt_simulator=None,
+ emulated_VC_targets=None,
+ emulated_VC_targets_calc=None,
+ emulator_coils_calc=None,
+ emu_inputs=None,
+ vc_update_rate=None,
+ verbose=False,
+ ):
+ """
+ Run the full control pipeline to compute approved coil voltage commands.
+
+ This method coordinates all subsystem controllers (plasma current, shape control,
+ virtual circuits, systems constraints, PF, and vertical) to compute the final
+ voltage commands for the coils. It also returns updated histories and error signals
+ for use in the next control cycle.
+
+ Parameters
+ ----------
+ t : float
+ Current time [s].
+
+ dt : float
+ Time step to run the controllers at (must have 'dt = dt_simulator/n' where n is a natural number) [s].
+
+ ip_meas : float
+ Measured plasma current [A].
+
+ ip_hist_prev : float
+ Previous value of the integrated plasma current error [A.s].
+
+ T_meas : np.ndarray
+ Measured values of the shape targets at the current time [m].
+
+ T_err_prev : np.ndarray
+ Previously shape target filtered error signal (used for damping) [m].
+
+ T_hist_prev : np.ndarray
+ Previous shape target integral term (used for PI control) [m.s].
+
+ I_approved_prev : numpy.ndarray
+ Previously approved coil currents [A].
+
+ I_meas : numpy.ndarray
+ Measured coil currents at the current time step [A].
+
+ V_approved_prev : numpy.ndarray
+ Previously approved control coil voltages from the last control step [V].
+
+ zip_meas : float
+ Measured vertical position of the plasma multiplied by measured Ip [A.m].
+
+ zipv_meas : float
+ Measured vertical velocity of the plasma multiplied by measured Ip [A.m/s].
+
+ active_coil_resists : numpy.ndarray
+ Array of active coil resistances when coils are switched on [Ohms].
+
+ dt_simulator : float
+ Time step of the simulator (must have 'dt_simulator = dt*n' where n is a natural number) [s].
+
+ emulated_VC_targets : list of str , optional
+ List of targets to be controlled using the emulated VC's. Must be subset of
+ ctrl_targets, and subset/equal to emulated_VC_targets_calc. Those not defined in this list will be taken from waveform-defined
+ VCs.
+
+ emulated_VC_targets_calc : list of str , optional
+ List of targets to be used when performing pseudoinverse of jacobian when calculating the emulated VC.
+
+ emulator_coils_calc : list of str, optional
+ List of coils to use in emulated VC compuation. These are coils to use in computing shape sensitivity matrix.
+
+ verbose : bool, optional
+ If True, prints diagnostic information from subsystem controllers.
+
+ Returns
+ -------
+ V_active : numpy.ndarray
+ Final (all active) coil voltage demands after applying all constraints [V].
+
+ ip_hist : list of float
+ Updated integrated plasma current error [A.s].
+
+ T_err : numpy.ndarray
+ Updated shape target filtered error signal (used for damping) [m].
+
+ T_hist : list of numpy.ndarray
+ Updated shape target integral term (used for PI control) [m.s].
+
+ I_approved : numpy.ndarray
+ Approved coil currents after applying perturbations and clipping [A].
+
+ coil_resists : numpy.ndarray
+ Active coil resistances to be used (some coils may be on or off at time t) [Ohms].
+ """
+
+ # check timesteps align correctly (default: dt_simulator = dt)
+ if dt_simulator is None:
+ dt_simulator = dt
+ n = 1
+ else:
+ n = round(dt_simulator / dt) # nearest integer
+ if n < 1:
+ n = 1 # enforce natural number >= 1
+
+ # call the PCS class (n times per simulator step if requested)
+ V_actives = []
+ for i in range(0, n):
+
+ # plasma category
+ self.dip_dt, ip_hist = self.PlasmaController.run_control(
+ t=t + (i * dt),
+ dt=dt,
+ ip_meas=ip_meas,
+ ip_hist_prev=ip_hist_prev,
+ )
+
+ # update "history" terms
+ ip_hist_prev = ip_hist.copy()
+
+ # shape category
+ self.dT_dt, T_err, T_hist = self.ShapeController.run_control(
+ t=t + (i * dt),
+ dt=dt,
+ T_meas=T_meas,
+ T_err_prev=T_err_prev,
+ T_hist_prev=T_hist_prev,
+ )
+
+ # update "history" terms
+ T_err_prev = T_err.copy()
+ T_hist_prev = T_hist.copy()
+
+ # virtual circuits category
+ self.I_unapproved, self.dI_dt_unapproved = (
+ self.VirtualCircuitsController.run_control(
+ t=t + (i * dt),
+ dt=dt,
+ dip_dt=self.dip_dt,
+ dT_dt=self.dT_dt,
+ I_approved_prev=I_approved_prev,
+ emulated_VC_targets=emulated_VC_targets,
+ emulated_VC_targets_calc=emulated_VC_targets_calc,
+ emulator_coils_calc=emulator_coils_calc,
+ emu_inputs=emu_inputs,
+ )
+ )
+
+ # systems category
+ self.I_approved, self.dI_dt_approved = self.SystemsController.run_control(
+ t=t + (i * dt),
+ dt=dt,
+ I_unapproved=self.I_unapproved,
+ dI_dt_unapproved=self.dI_dt_unapproved,
+ verbose=verbose,
+ )
+
+ # update "history" terms
+ I_approved_prev = self.I_approved.copy()
+
+ # PF category
+ self.V_ctrl = self.PFController.run_control(
+ t=t + (i * dt),
+ dt=dt,
+ I_meas=I_meas,
+ I_approved=self.I_approved,
+ dI_dt_approved=self.dI_dt_approved,
+ V_approved_prev=V_approved_prev,
+ verbose=verbose,
+ )
+
+ # update "history" terms
+ V_approved_prev = self.V_ctrl.copy()
+
+ # vertical category
+ self.V_vertical = self.VerticalController.run_control(
+ t=t + (i * dt),
+ dt=dt,
+ ip_meas=ip_meas,
+ zip_meas=zip_meas,
+ zipv_meas=zipv_meas,
+ )
+
+ # coil activations category
+ coil_resists = self.CoilActivationController.run_control(
+ t=t + (i * dt),
+ dt=dt,
+ active_coil_resists=active_coil_resists,
+ )
+
+ # lookup dictionaries
+ ctrl_dict = dict(zip(self.ctrl_coils, self.V_ctrl))
+ vert_dict = dict(zip(self.vertical_coils, np.array([self.V_vertical])))
+
+ # build active coil voltages vector
+ V_actives.append(
+ np.array(
+ [ctrl_dict.get(c, vert_dict.get(c, 0.0)) for c in self.active_coils]
+ )
+ )
+
+ # average the requested voltages for use in simulator
+ V_active = np.mean(V_actives, axis=0)
+
+ return V_active, ip_hist, T_err, T_hist, self.I_approved, coil_resists
diff --git a/freegsnke/control_loop/pf_category.py b/freegsnke/control_loop/pf_category.py
new file mode 100644
index 00000000..4faa9ef9
--- /dev/null
+++ b/freegsnke/control_loop/pf_category.py
@@ -0,0 +1,314 @@
+"""
+Module to implement PF control in FreeGSNKE control loops.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from freegsnke.control_loop.useful_functions import (
+ check_data_entry,
+ interpolate_spline,
+ interpolate_step,
+)
+
+
+class PFController:
+ """
+ A controller class for managing coil resistances, inductances, gains, voltage limits, and
+ voltage ramp rate limits.
+
+ Parameters
+ ----------
+ data : dict
+ A nested dictionary containing control waveforms for the PF controller.
+ The required keys
+ for both spline-based and step-based waveforms are:
+ - Spline keys:
+ - Step keys: "R_matrix", "M_FF_matrix", "M_FB_matrix", "coil_gains",
+ "coil_voltage_lims", "coil_voltage_slew_lims"
+ Each key should map to a waveform dictionary suitable for interpolation with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length).
+
+ Attributes
+ ----------
+ keys_to_spline : list of str
+ Keys corresponding to waveforms that will be interpolated using splines.
+
+ keys_to_step : list of str
+ Keys corresponding to waveforms that will be interpolated using step functions.
+
+ data : dict
+ Internal copy of the input control waveforms.
+
+ interpolants : dict
+ A nested dictionary storing interpolation functions of each input waveform.
+ Structure: {spline/step key: interpolant_function}
+
+ """
+
+ def __init__(
+ self,
+ data,
+ ):
+
+ # check correct data is input and in correct format
+ self.keys_to_spline = []
+ self.keys_to_step = [
+ "R_matrix",
+ "M_FF_matrix",
+ "M_FB_matrix",
+ "coil_gains",
+ "coil_voltage_lims",
+ "coil_voltage_slew_lims",
+ ]
+ for key in self.keys_to_spline + self.keys_to_step:
+ check_data_entry(data=data, key=key, controller_name="PFController")
+
+ # create an internal copy of the data
+ self.data = data
+
+ # interpolate the input data
+ self.update_interpolants()
+
+ def update_interpolants(self):
+ """
+ Recompute all interpolant functions from the current `self.data`.
+
+ This method clears the existing `self.interpolants` dictionary and
+ rebuilds it by applying either `interpolate_spline` or `interpolate_step`
+ depending on whether each key belongs to `self.keys_to_spline` or
+ `self.keys_to_step`.
+
+ """
+
+ # create a dictionary to store the spline functions
+ self.interpolants = {}
+
+ # interpolate the input data
+ for key in self.keys_to_step:
+ self.interpolants[key] = interpolate_step(self.data[key])
+
+ def run_control(
+ self,
+ t,
+ dt,
+ I_meas,
+ I_approved,
+ dI_dt_approved,
+ V_approved_prev,
+ verbose=False,
+ ):
+ """
+ Computes the approved coil voltage commands based on measured and approved currents,
+ while enforcing voltage and slew rate constraints.
+
+ This method calculates the total voltage demand using resistive, feedforward, and
+ feedback components. It then clips the voltage according to hardware limits and
+ applies slew rate constraints to ensure smooth transitions between time steps.
+
+ Parameters
+ ----------
+ t : float
+ Current time [s].
+
+ dt : float
+ Time step [s].
+
+ I_meas : numpy.ndarray
+ Measured coil currents at the current time step [A].
+
+ I_approved : numpy.ndarray
+ Approved coil currents after applying perturbations and clipping [A].
+
+ dI_dt_approved : numpy.ndarray
+ Approved coil current derivatives after clipping [A/s].
+
+ V_approved_prev : numpy.ndarray
+ Previously approved coil voltages from the last control step [V].
+
+ verbose : bool, optional
+ If True, prints diagnostic information about voltage clipping and slew rate limiting.
+
+ Returns
+ -------
+ V_approved : numpy.ndarray
+ Final coil voltage demands after applying all constraints [V].
+ """
+
+ # extract interpolated data
+ R = self.interpolants["R_matrix"](t)
+ M_FF = self.interpolants["M_FF_matrix"](t)
+ M_FB = self.interpolants["M_FB_matrix"](t)
+ coil_gains = self.interpolants["coil_gains"](t)
+ voltage_clips = self.interpolants["coil_voltage_lims"](t)
+ slew_rates = self.interpolants["coil_voltage_slew_lims"](t)
+
+ # resistive voltages
+ v_res = R * I_meas
+
+ # FF voltages
+ v_FF = M_FF @ dI_dt_approved
+
+ # FB voltages
+ delta_I = I_approved - I_meas
+ v_FB = M_FB @ (delta_I / coil_gains)
+
+ # initial voltage demands (pre-clipping)
+ v_init = v_res + v_FF + v_FB
+
+ # clip voltage to max/min allowed
+ v_clipped = np.clip(v_init, -voltage_clips, voltage_clips)
+
+ # apply slew rate constraints
+ delta_voltages = v_clipped - V_approved_prev
+ max_delta = slew_rates * dt
+ delta_clipped = np.clip(delta_voltages, -max_delta, max_delta)
+ V_approved = V_approved_prev + delta_clipped
+
+ return V_approved.squeeze()
+
+ def extract_values(
+ self,
+ t,
+ targets,
+ ):
+ """
+ Extracts interpolated values for specified shape targets at a given time.
+
+ Parameters
+ ----------
+ t : float
+ Time at which to evaluate the interpolants [s].
+ targets : list of str
+ List of keys. Each must correspond to a key in `self.interpolants`.
+
+ Returns
+ -------
+ np.ndarray
+ Array of interpolated values (or derivatives) for each target at time `t`.
+
+ Notes
+ -----
+ - Assumes that `self.interpolants[target]` is a valid `scipy.interpolate` object.
+ """
+
+ return np.array([self.interpolants[target](t) for target in targets])
+
+ def plot_data(self, tmin=-1.0, tmax=1.0, nt=1001):
+ """
+ Visualizes interpolated control waveforms and corresponding raw inputs.
+
+ This method generates subplots for each control waveform (step types),
+ showing the interpolated time series alongside the original data points. It helps verify
+ the quality and behavior of the interpolation.
+
+ Parameters
+ ----------
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax] (default is 1001).
+
+ Notes
+ -----
+ - Each subplot corresponds to a control waveform.
+ - Interpolated curves are plotted in navy; raw data points are shown in red.
+ - Axis labels include units where applicable.
+ - Useful for debugging or validating the interpolation quality.
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_step[3:6]) # number of plots
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ for ax, key in zip(axes, self.keys_to_step[3:6]):
+ times = np.asarray(self.data[key]["times"])
+ vals_list = self.data[key]["vals"]
+
+ if np.isscalar(vals_list[0]):
+ ax.scatter(
+ self.data[key]["times"],
+ self.data[key]["vals"],
+ s=10,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+ else:
+ m = len(vals_list[0])
+ times_repeated = np.repeat(times, m)
+ vals_flat = np.concatenate(vals_list)
+ ax.scatter(
+ times_repeated,
+ vals_flat,
+ s=10,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+
+ ax.plot(
+ t,
+ self.interpolants[key](t),
+ color="navy",
+ linewidth=1.2,
+ label="interpolated",
+ )
+
+ ax.grid(True, linestyle="--", alpha=0.6)
+
+ if key == "coil_gains":
+ ax.set_ylabel(rf"{key} [$s$]")
+ elif key == "coil_voltage_lims":
+ ax.set_ylabel(rf"{key} [$V$]")
+ elif key == "coil_voltage_slew_lims":
+ ax.set_ylabel(rf"{key} [$V/s$]")
+ else:
+ ax.set_ylabel(key)
+
+ # y-scaling inside the window
+ times = np.array(self.data[key]["times"])
+ mask = (times >= tmin) & (times <= tmax)
+ if np.any(mask):
+ ydata = np.concatenate(
+ [self.interpolants[key](t), np.array(self.data[key]["vals"])[mask]]
+ )
+ ymin, ymax = np.min(ydata), np.max(ydata)
+ yrange = ymax - ymin
+ if yrange == 0:
+ yrange = 1.0
+ ax.set_ylim(ymin - 0.02 * yrange, ymax + 0.02 * yrange)
+
+ # axes[0].legend(loc='best')
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
diff --git a/freegsnke/control_loop/plasma_category.py b/freegsnke/control_loop/plasma_category.py
new file mode 100644
index 00000000..d4eec11c
--- /dev/null
+++ b/freegsnke/control_loop/plasma_category.py
@@ -0,0 +1,290 @@
+"""
+Module to implement plasma control in FreeGSNKE control loops.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from freegsnke.control_loop.useful_functions import (
+ PID,
+ check_data_entry,
+ interpolate_spline,
+ interpolate_step,
+)
+
+
+class PlasmaController:
+ """
+ A controller class for managing plasma control waveforms.
+
+ Parameters
+ ----------
+ data : dict
+ A dictionary containing waveforms for the plasma current controller. The required keys
+ for both spline-based and step-based waveforms are:
+ - Spline keys: "ip_ref", "ip_blend", "vloop_ff"
+ - Step keys: "k_prop", "k_int", "M_solenoid"
+ Each key should map to a waveform dictionary suitable for interpolation with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length).
+
+ Attributes
+ ----------
+ keys_to_spline : list of str
+ Keys corresponding to waveforms that will be interpolated using splines.
+
+ keys_to_step : list of str
+ Keys corresponding to waveforms that will be interpolated using step functions.
+
+ data : dict
+ Internal copy of the input control waveforms.
+
+ interpolants : dict
+ A nested dictionary storing interpolation functions of each input waveform.
+ Structure: {spline/step key: interpolant_function}
+
+ """
+
+ def __init__(
+ self,
+ data,
+ ):
+
+ # check correct data is input and in correct format
+ self.keys_to_spline = ["ip_ref", "ip_blend", "vloop_ff"]
+ self.keys_to_step = ["k_prop", "k_int", "M_solenoid"]
+ for key in self.keys_to_spline + self.keys_to_step:
+ check_data_entry(data=data, key=key, controller_name="PlasmaController")
+
+ # create an internal copy of the data
+ self.data = data
+
+ # interpolate the input data
+ self.update_interpolants()
+
+ def update_interpolants(self):
+ """
+ Recompute all interpolant functions from the current `self.data`.
+
+ This method clears the existing `self.interpolants` dictionary and
+ rebuilds it by applying either `interpolate_spline` or `interpolate_step`
+ depending on whether each key belongs to `self.keys_to_spline` or
+ `self.keys_to_step`.
+
+ """
+
+ # create a dictionary to store the spline functions
+ self.interpolants = {}
+
+ # interpolate the input data
+ for key in self.data.keys():
+ self.interpolants[key] = {}
+ if key in self.keys_to_spline:
+ self.interpolants[key] = interpolate_spline(self.data[key])
+ elif key in self.keys_to_step:
+ self.interpolants[key] = interpolate_step(self.data[key])
+
+ def run_control(
+ self,
+ t,
+ dt,
+ ip_meas,
+ ip_hist_prev,
+ ):
+ """
+ Computes the time derivative of the plasma current request (`dip_dt`) and updates the
+ integral history of the plasma current error (`ip_hist`) using a blended feedback and
+ feedforward control strategy.
+
+ Parameters:
+ ----------
+ t : float
+ Current time [s].
+ dt : float
+ Time step [s].
+ ip_meas : float
+ Measured plasma current at time `t` [A].
+ ip_hist_prev : float
+ Previous value of the integrated plasma current error [A.s].
+
+ Returns:
+ -------
+ dip_dt : float
+ Time derivative of the requested plasma current [A/s].
+ ip_hist : float
+ Updated integral of the plasma current error [A.s].
+
+ Notes:
+ ------
+ - The control law uses time-dependent interpolants for reference current (`ip_ref`),
+ proportional gain (`k_prop`), integral gain (`k_int`), blend factor (`ip_blend`),
+ feedforward voltage (`vloop_ff`), and solenoid inductance (`M_solenoid`).
+ - The blend factor determines the weighting between feedback and feedforward control.
+ - The integral term is computed using the trapezoidal rule for numerical integration.
+ """
+
+ # extract data
+ ip_ref = self.interpolants["ip_ref"](t)
+ k_prop = self.interpolants["k_prop"](t)
+ k_int = self.interpolants["k_int"](t)
+ blend = self.interpolants["ip_blend"](t)
+ vloop_ff = self.interpolants["vloop_ff"](t)
+ M_solenoid = self.interpolants["M_solenoid"](t)
+
+ # proportional term
+ ip_err = ip_ref - ip_meas
+
+ # integral term
+ ip_int = ip_hist_prev + (0.5 * ip_err * dt)
+
+ # FB term
+ dip_dt_FB = PID(
+ error_prop=ip_err,
+ error_int=ip_int,
+ error_deriv=None,
+ k_prop=k_prop,
+ k_int=k_int,
+ k_deriv=0.0,
+ )
+
+ # FF term
+ dip_dt_FF = vloop_ff / M_solenoid
+
+ # time deriv of plasma current request
+ dip_dt = (blend * dip_dt_FB) + ((1 - blend) * dip_dt_FF)
+
+ # update ip_hist
+ ip_hist = ip_hist_prev + (ip_err * dt)
+
+ return dip_dt, ip_hist
+
+ def plot_data(self, tmin=-1.0, tmax=1.0, nt=1001):
+ """
+ Visualizes interpolated control waveforms and corresponding raw inputs.
+
+ This method generates subplots for each control waveform (both spline and step types),
+ showing the interpolated time series alongside the original data points. It helps verify
+ the quality and behavior of the interpolation.
+
+ Parameters
+ ----------
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax] (default is 1001).
+
+ Notes
+ -----
+ - Each subplot corresponds to a control waveform (e.g., 'ip_ref', 'ip_blend', 'vloop_ff', 'k_prop', etc.).
+ - Interpolated curves are plotted in navy; raw data points are shown in red.
+ - Axis labels include units where applicable.
+ - Useful for debugging or validating the interpolation quality.
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_spline + self.keys_to_step) # number of plots
+
+ # find out which control is ON and when
+ FB_reference = self.interpolants["ip_ref"](t)
+ FF_reference = self.interpolants["vloop_ff"](t)
+ FB_mask = (self.interpolants["ip_blend"](t) > 0) & (np.abs(FB_reference) > 0)
+ FF_mask = (self.interpolants["ip_blend"](t) < 1) & (np.abs(FF_reference) > 0)
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ for ax, key in zip(axes, self.data.keys()):
+
+ # shade region of FB control
+ on_regions = np.where(np.diff(FB_mask.astype(int)) != 0)[0] + 1
+ segments = np.split(t, on_regions)
+ states = np.split(FB_mask, on_regions)
+
+ for seg_t, seg_state in zip(segments, states):
+ if np.all(seg_state): # region fully "on"
+ ax.axvspan(seg_t[0], seg_t[-1], color="green", alpha=0.25)
+
+ # shade region of FF control
+ on_regions = np.where(np.diff(FF_mask.astype(int)) != 0)[0] + 1
+ segments = np.split(t, on_regions)
+ states = np.split(FF_mask, on_regions)
+
+ for seg_t, seg_state in zip(segments, states):
+ if np.all(seg_state): # region fully "on"
+ ax.axvspan(seg_t[0], seg_t[-1], color="yellow", alpha=0.25)
+
+ # raw data
+ ax.scatter(
+ self.data[key]["times"],
+ self.data[key]["vals"],
+ s=12,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+ # interpolated data
+ ax.plot(
+ t,
+ self.interpolants[key](t),
+ color="navy",
+ linewidth=1.5,
+ label="interpolated",
+ )
+
+ ax.grid(True, linestyle="--", alpha=0.6)
+
+ if key == "ip_ref":
+ ax.set_ylabel(rf"{key} [$A$]")
+ elif key == "vloop_ff":
+ ax.set_ylabel(rf"{key} [$V$]")
+ elif key == "k_prop":
+ ax.set_ylabel(rf"{key} [$1/s$]")
+ elif key == "k_int":
+ ax.set_ylabel(rf"{key} [$1/s^2$]")
+ elif key == "M_solenoid":
+ ax.set_ylabel(rf"{key} [$V.s/A$]")
+ else:
+ ax.set_ylabel(key)
+
+ # y-scaling inside the window
+ times = np.array(self.data[key]["times"])
+ mask = (times >= tmin) & (times <= tmax)
+ if np.any(mask):
+ ydata = np.concatenate(
+ [self.interpolants[key](t), np.array(self.data[key]["vals"])[mask]]
+ )
+ ymin, ymax = np.min(ydata), np.max(ydata)
+ yrange = ymax - ymin
+ if yrange == 0:
+ yrange = 1.0
+ ax.set_ylim(ymin - 0.02 * yrange, ymax + 0.02 * yrange)
+
+ axes[0].legend(loc="best")
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
diff --git a/freegsnke/control_loop/shape_category.py b/freegsnke/control_loop/shape_category.py
new file mode 100644
index 00000000..4f41ed56
--- /dev/null
+++ b/freegsnke/control_loop/shape_category.py
@@ -0,0 +1,559 @@
+"""
+Module to implement shape control in FreeGSNKE control loops.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from freegsnke.control_loop.useful_functions import (
+ PID,
+ check_data_entry,
+ interpolate_spline,
+ interpolate_step,
+)
+
+
+class ShapeController:
+ """
+ A controller class for managing shape control waveforms.
+
+ Parameters
+ ----------
+ data : dict
+ A nested dictionary containing waveforms for each target to be controlled. Each target's
+ dictionary must include keys for both spline-based and step-based parameters:
+ - Spline keys: "ff", "ref", "blend"
+ - Step keys: "k_prop", "k_int", "damping"
+ Each key should map to a waveform dictionary suitable for interpolation with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length).
+
+ ctrl_targets : list of str
+ A list of shape target names (keys in `data`) that the controller will manage.
+
+ mode : str
+ Choose the type of controller to use, here the default is an "PI_with_P_damping"
+ controller, see "run_control" method for more information.
+
+ Attributes
+ ----------
+ ctrl_targets : list of str
+ The list of shape targets being managed.
+
+ keys_to_spline : list of str
+ Keys corresponding to waveforms that will be interpolated using splines.
+
+ keys_to_step : list of str
+ Keys corresponding to waveforms that will be interpolated using step functions.
+
+ data : dict
+ Internal copy of the input control waveforms.
+
+ interpolants : dict
+ A nested dictionary storing interpolation functions of each input waveform for each
+ shape target.
+ Structure: {target: {spline/step key: interpolant_function}}
+ """
+
+ def __init__(
+ self,
+ data,
+ ctrl_targets,
+ mode=None,
+ ):
+
+ # targets list
+ self.ctrl_targets = ctrl_targets
+
+ # create an internal copy of the data
+ self.data = data
+
+ # choose controller to use (more can be added)
+ if mode is None:
+ mode = "PI_with_P_damping"
+
+ if mode == "PI_with_P_damping":
+ # select control algorithm
+ self.run_control = self.run_control_PI_with_P_damping
+
+ # inputs required for this algorithm
+ self.keys_to_spline = ["ff", "ref", "blend"]
+ self.keys_to_step = ["k_prop", "k_int", "damping"]
+
+ elif mode == "PID_with_scaled_out_damping":
+ # select control algorithm
+ self.run_control = self.run_control_PID_with_scaled_out_damping
+
+ # inputs required for this algorithm
+ self.keys_to_spline = ["ff", "ref", "blend"]
+ self.keys_to_step = ["k_prop", "damping"]
+
+ elif mode == "PID":
+ # select control algorithm
+ self.run_control = self.run_control_PID
+
+ # inputs required for this algorithm
+ self.keys_to_spline = ["ff", "ref", "blend"]
+ self.keys_to_step = ["k_prop", "k_int", "k_deriv"]
+
+ # check correct data is input and in correct format
+ for targ in self.ctrl_targets:
+ for key in self.keys_to_spline + self.keys_to_step:
+ check_data_entry(
+ data=data[targ], key=key, controller_name="ShapeController"
+ )
+
+ # interpolate the input data
+ self.update_interpolants()
+
+ def update_interpolants(self):
+ """
+ Recompute all interpolant functions from the current `self.data`.
+
+ This method clears the existing `self.interpolants` dictionary and
+ rebuilds it by applying either `interpolate_spline` or `interpolate_step`
+ depending on whether each key belongs to `self.keys_to_spline` or
+ `self.keys_to_step`.
+
+ """
+
+ # create a dictionary to store the spline funcions
+ self.interpolants = {}
+
+ # interpolate the input data
+ for targ in self.ctrl_targets:
+ self.interpolants[targ] = {}
+ for key in self.keys_to_spline:
+ self.interpolants[targ][key] = interpolate_spline(self.data[targ][key])
+ for key in self.keys_to_step:
+ self.interpolants[targ][key] = interpolate_step(self.data[targ][key])
+
+ def run_control_PI_with_P_damping(
+ self,
+ t,
+ dt,
+ T_meas,
+ T_err_prev,
+ T_hist_prev,
+ ):
+ """
+ Computes the time derivative of shape target requests based on measured values,
+ reference trajectories, and control gains. It blends feedforward and feedback
+ contributions using a time-varying blend factor, and applies damping to the error
+ signal.
+
+ Parameters
+ ----------
+ t : float
+ Current time [s].
+ dt : float
+ Time step [s].
+ T_meas : np.ndarray
+ Measured values of the shape targets at the current time [m].
+ T_err_prev : np.ndarray
+ Previously filtered error signal (used for damping) [m].
+ T_hist_prev : np.ndarray
+ Previous integral term (used for PI control) [m.s].
+
+ Returns
+ -------
+ dT_dt : np.ndarray
+ Time derivative of the shape target requests [m/s].
+ T_err : np.ndarray
+ Filtered error signal at the current time [m].
+ T_hist : np.ndarray
+ Updated integral term for use in the next control step [m.s].
+
+ Notes
+ -----
+ - The error signal is filtered using a damping factor to smooth transitions.
+ - The integral term is updated using trapezoidal integration.
+ - The final output blends feedforward and feedback derivatives based on a dynamic blend factor.
+ """
+
+ # extract data
+ T_ref = self.extract_values(t=t, targets=self.ctrl_targets, key="ref")
+ T_ff_deriv = self.extract_values(
+ t=t, targets=self.ctrl_targets, key="ff", deriv=True
+ )
+ T_blend = self.extract_values(t=t, targets=self.ctrl_targets, key="blend")
+ k_prop = self.extract_values(t=t, targets=self.ctrl_targets, key="k_prop")
+ k_int = self.extract_values(t=t, targets=self.ctrl_targets, key="k_int")
+ alpha_inv = 1.0 / self.extract_values(
+ t=t, targets=self.ctrl_targets, key="damping"
+ )
+
+ # proportional term
+ T_err = ((1 - alpha_inv) * T_err_prev) + (alpha_inv * (T_ref - T_meas))
+
+ # integral term
+ T_int = T_hist_prev + (0.5 * T_err * dt)
+
+ # update hist
+ T_hist = T_hist_prev + (T_err * dt)
+
+ # FB term
+ T_fb_deriv = PID(
+ error_prop=T_err,
+ error_int=T_int,
+ error_deriv=None,
+ k_prop=k_prop,
+ k_int=k_int,
+ k_deriv=0.0,
+ )
+
+ # time deriv of shape target requests
+ dT_dt = ((T_blend * T_fb_deriv) + ((1.0 - T_blend) * T_ff_deriv)).squeeze()
+
+ return dT_dt.squeeze(), T_err.squeeze(), T_hist.squeeze()
+
+ def run_control_PID_with_scaled_out_damping(
+ self,
+ t,
+ dt,
+ T_meas,
+ T_err_prev,
+ T_hist_prev,
+ ):
+ """
+ Computes the time derivative of shape target requests based on measured values,
+ reference trajectories, and control gains. It blends feedforward and feedback
+ contributions using a time-varying blend factor, and applies damping to the error
+ signal.
+
+ This function re-formulates "run_control_PI_with_scaled_out_damping" to not include a
+ damping term.
+
+ Parameters
+ ----------
+ t : float
+ Current time [s].
+ dt : float
+ Time step [s].
+ T_meas : np.ndarray
+ Measured values of the shape targets at the current time [m].
+ T_err_prev : np.ndarray
+ Previously filtered error signal [m].
+ T_hist_prev : np.ndarray
+ Previous integral term (used for PI control) [m.s].
+
+ Returns
+ -------
+ dT_dt : np.ndarray
+ Time derivative of the shape target requests [m/s].
+ T_err : np.ndarray
+ Filtered error signal at the current time [m].
+ T_hist : np.ndarray
+ Updated integral term for use in the next control step [m.s].
+
+ """
+
+ # extract data
+ T_ref = self.extract_values(t=t, targets=self.ctrl_targets, key="ref")
+ T_ff_deriv = self.extract_values(
+ t=t, targets=self.ctrl_targets, key="ff", deriv=True
+ )
+ T_blend = self.extract_values(t=t, targets=self.ctrl_targets, key="blend")
+ k_prop = self.extract_values(t=t, targets=self.ctrl_targets, key="k_prop")
+ alpha_inv = 1.0 / self.extract_values(
+ t=t, targets=self.ctrl_targets, key="damping"
+ )
+
+ # build PID gains to match damping
+ beta = 1 - alpha_inv
+ abs_beta = np.abs(beta)
+
+ k_int = alpha_inv * (1 + beta) / (1e-4)
+ k_deriv = (abs_beta * k_int * dt - beta) * dt
+ k_prop_new = 1 - k_int * dt / 2 - k_deriv / dt
+
+ # rescale
+ k_int *= k_prop * alpha_inv
+ k_deriv *= k_prop * alpha_inv
+ k_prop = k_prop_new * k_prop * alpha_inv
+
+ # proportional term
+ T_err = T_ref - T_meas
+
+ # integral term
+ T_int = abs_beta ** (dt / 1e-4) * T_hist_prev + (0.5 * T_err * dt)
+
+ # derivative term
+ T_deriv = (T_err - T_err_prev) / dt
+
+ # FB term
+ T_fb_deriv = PID(
+ error_prop=T_err,
+ error_int=T_int,
+ error_deriv=T_deriv,
+ k_prop=k_prop,
+ k_int=k_int,
+ k_deriv=k_deriv,
+ )
+
+ # time deriv of shape target requests
+ dT_dt = ((T_blend * T_fb_deriv) + ((1.0 - T_blend) * T_ff_deriv)).squeeze()
+
+ # update hist
+ T_hist = T_int + (0.5 * T_err * dt)
+
+ return dT_dt.squeeze(), T_err.squeeze(), T_hist.squeeze()
+
+ def run_control_PID(
+ self,
+ t,
+ dt,
+ T_meas,
+ T_err_prev,
+ T_hist_prev,
+ ):
+ """
+ Computes the time derivative of shape target requests based on measured values,
+ reference trajectories, and control gains. It blends feedforward and feedback
+ contributions using a time-varying blend factor.
+
+ Parameters
+ ----------
+ t : float
+ Current time [s].
+ dt : float
+ Time step [s].
+ T_meas : np.ndarray
+ Measured values of the shape targets at the current time [m].
+ T_err_prev : np.ndarray
+ Previously filtered error signal [m].
+ T_hist_prev : np.ndarray
+ Previous integral term (used for PI control) [m.s].
+
+ Returns
+ -------
+ dT_dt : np.ndarray
+ Time derivative of the shape target requests [m/s].
+ T_err : np.ndarray
+ Filtered error signal at the current time [m].
+ T_hist : np.ndarray
+ Updated integral term for use in the next control step [m.s].
+
+ Notes
+ -----
+ - The integral term is updated using trapezoidal integration.
+ - The final output blends feedforward and feedback derivatives based on a dynamic blend factor.
+ - THIS FUNCTION IS UNTESTED.
+ """
+
+ # extract data
+ T_ref = self.extract_values(t=t, targets=self.ctrl_targets, key="ref")
+ T_ff_deriv = self.extract_values(
+ t=t, targets=self.ctrl_targets, key="ff", deriv=True
+ )
+ T_blend = self.extract_values(t=t, targets=self.ctrl_targets, key="blend")
+ k_prop = self.extract_values(t=t, targets=self.ctrl_targets, key="k_prop")
+ k_int = self.extract_values(t=t, targets=self.ctrl_targets, key="k_int")
+ k_deriv = self.extract_values(t=t, targets=self.ctrl_targets, key="k_deriv")
+
+ # proportional term
+ T_err = T_ref - T_meas
+
+ # integral term
+ T_int = T_hist_prev + (0.5 * T_err * dt)
+
+ # derivative term
+ T_deriv = (T_err - T_err_prev) / dt
+
+ # FB term
+ T_fb_deriv = PID(
+ error_prop=T_err,
+ error_int=T_int,
+ error_deriv=T_deriv,
+ k_prop=k_prop,
+ k_int=k_int,
+ k_deriv=k_deriv,
+ )
+
+ # time deriv of shape target requests
+ dT_dt = ((T_blend * T_fb_deriv) + ((1.0 - T_blend) * T_ff_deriv)).squeeze()
+
+ # update hist
+ T_hist = T_hist_prev + (T_err * dt)
+
+ return dT_dt.squeeze(), T_err.squeeze(), T_hist.squeeze()
+
+ def extract_values(
+ self,
+ t,
+ targets,
+ key,
+ deriv=False,
+ ):
+ """
+ Extracts interpolated values or their derivatives for specified shape targets at a given time.
+
+ This method queries the stored interpolation functions for each target and key, returning either
+ the interpolated value or its first derivative depending on the `deriv` flag.
+
+ Parameters
+ ----------
+ t : float
+ Time at which to evaluate the interpolants [s].
+ targets : list of str
+ List of shape target names. Each must correspond to a key in `self.interpolants`.
+ key : str
+ The waveform name (e.g., 'ff', 'ref', 'blend', 'k_prop', etc.) used to select the interpolant.
+ deriv : bool, optional
+ If True, returns the first derivative of the interpolant at time `t`. Default is False.
+
+ Returns
+ -------
+ np.ndarray
+ Array of interpolated values (or derivatives) for each target at time `t`.
+
+ Notes
+ -----
+ - Assumes that `self.interpolants[target][key]` is a valid `scipy.interpolate` object.
+ - If `deriv=True`, the method calls `.derivative()` on the interpolant before evaluation.
+ """
+
+ if deriv:
+ return np.array(
+ [self.interpolants[target][key].derivative()(t) for target in targets]
+ )
+ else:
+ return np.array([self.interpolants[target][key](t) for target in targets])
+
+ def plot_data(self, targ, tmin=-1.0, tmax=1.0, nt=1001):
+ """
+ Visualizes interpolated control waveforms and corresponding raw inputs for a specified
+ shape target.
+
+ This method generates subplots for each control waveform (both spline and step types),
+ showing the interpolated time series alongside the original data points. It helps verify
+ the quality and behavior of the interpolation.
+
+ Parameters
+ ----------
+ targ : str
+ The name of the shape target waveforms to plot. Must be a key in `self.interpolants`
+ and `self.data`.
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax]
+ (default is 10001).
+
+ Notes
+ -----
+ - Each subplot corresponds to a control parameter (e.g., 'ff', 'ref', 'blend', 'k_prop', etc.).
+ - Interpolated curves are plotted in navy; raw data points are shown in red.
+ - Axis labels include units where applicable.
+ - Useful for debugging or validating the interpolation quality.
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_spline + self.keys_to_step) # number of plots
+
+ # find out which control is ON and when
+ FF_reference = self.interpolants[targ]["ff"](t)
+ FF_mask = (self.interpolants[targ]["blend"](t) < 1) * (
+ np.abs(self.interpolants[targ]["ff"].derivative()(t)) > 0
+ )
+ FB_reference = self.interpolants[targ]["ref"](t)
+ FB_mask = (self.interpolants[targ]["blend"](t) > 0) * (
+ np.abs(self.interpolants[targ]["ref"](t)) > 0
+ )
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ for ax, key in zip(axes, self.keys_to_spline + self.keys_to_step):
+
+ # shade region of FB control
+ on_regions = np.where(np.diff(FB_mask.astype(int)) != 0)[0] + 1
+ segments = np.split(t, on_regions)
+ states = np.split(FB_mask, on_regions)
+
+ for seg_t, seg_state in zip(segments, states):
+ if np.all(seg_state): # region fully "on"
+ ax.axvspan(seg_t[0], seg_t[-1], color="green", alpha=0.25)
+
+ # shade region of FF control
+ on_regions = np.where(np.diff(FF_mask.astype(int)) != 0)[0] + 1
+ segments = np.split(t, on_regions)
+ states = np.split(FF_mask, on_regions)
+
+ for seg_t, seg_state in zip(segments, states):
+ if np.all(seg_state): # region fully "on"
+ ax.axvspan(seg_t[0], seg_t[-1], color="yellow", alpha=0.25)
+
+ # raw data
+ ax.scatter(
+ self.data[targ][key]["times"],
+ self.data[targ][key]["vals"],
+ s=10,
+ marker="x",
+ color="tab:orange",
+ label=f"raw data",
+ )
+ # interpolated data
+ ax.plot(
+ t,
+ self.interpolants[targ][key](t),
+ color="navy",
+ linewidth=1.2,
+ label="interpolated",
+ )
+ ax.grid(True, linestyle="--", alpha=0.6)
+
+ if key in ["ref", "ff"]:
+ ax.set_ylabel(rf"{key} [$m$]")
+ elif key == "k_prop":
+ ax.set_ylabel(rf"{key} [$1/s$]")
+ elif key == "k_int":
+ ax.set_ylabel(rf"{key} [$1/s^2$]")
+ else:
+ ax.set_ylabel(key)
+
+ # y-scaling inside the window
+ times = np.array(self.data[targ][key]["times"])
+ mask = (times >= tmin) & (times <= tmax)
+ if np.any(mask):
+ ydata = np.concatenate(
+ [
+ self.interpolants[targ][key](t),
+ np.array(self.data[targ][key]["vals"])[mask],
+ ]
+ )
+ ymin, ymax = np.min(ydata), np.max(ydata)
+ yrange = ymax - ymin
+ if yrange == 0:
+ yrange = 1.0
+ ax.set_ylim(ymin - 0.02 * yrange, ymax + 0.02 * yrange)
+
+ fig.suptitle(targ)
+ axes[0].legend(loc="best")
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
diff --git a/freegsnke/control_loop/systems_category.py b/freegsnke/control_loop/systems_category.py
new file mode 100644
index 00000000..42ee1078
--- /dev/null
+++ b/freegsnke/control_loop/systems_category.py
@@ -0,0 +1,344 @@
+"""
+Module to implement systems control in FreeGSNKE control loops.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from freegsnke.control_loop.useful_functions import (
+ check_data_entry,
+ interpolate_spline,
+ interpolate_step,
+)
+
+
+class SystemsController:
+ """
+ A controller class for managing coil current perturbations, coil current limits, and
+ coil current ramp rate limits.
+
+ Parameters
+ ----------
+ data : dict
+ A nested dictionary containing control waveforms for the systems controller.
+ The required keys for both spline-based and step-based waveforms are:
+ - Spline keys: "_pert"
+ - Step keys: "min_coil_curr_lims", "max_coil_curr_lims", "max_coil_curr_ramp_lims"
+ Each key should map to a waveform dictionary suitable for interpolation with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length).
+
+ ctrl_coils : list of str
+ The list of active coils being controlled.
+
+ Attributes
+ ----------
+ ctrl_coils : list of str
+ The list of active coils being controlled.
+
+ keys_to_spline : list of str
+ Keys corresponding to waveforms that will be interpolated using splines.
+
+ keys_to_step : list of str
+ Keys corresponding to waveforms that will be interpolated using step functions.
+
+ data : dict
+ Internal copy of the input control waveforms.
+
+ interpolants : dict
+ A nested dictionary storing interpolation functions of each input waveform.
+ Structure: {spline/step key: interpolant_function}
+
+ """
+
+ def __init__(
+ self,
+ data,
+ ctrl_coils,
+ ):
+
+ # coils list
+ self.ctrl_coils = ctrl_coils
+
+ # check correct data is input and in correct format
+ self.keys_to_spline = [coil + "_pert" for coil in self.ctrl_coils]
+ self.keys_to_step = [
+ "min_coil_curr_lims",
+ "max_coil_curr_lims",
+ "max_coil_curr_ramp_lims",
+ ]
+ for key in self.keys_to_spline + self.keys_to_step:
+ check_data_entry(data=data, key=key, controller_name="SystemsController")
+
+ # create an internal copy of the data
+ self.data = data
+
+ # interpolate the input data
+ self.update_interpolants()
+
+ def update_interpolants(self):
+ """
+ Recompute all interpolant functions from the current `self.data`.
+
+ This method clears the existing `self.interpolants` dictionary and
+ rebuilds it by applying either `interpolate_spline` or `interpolate_step`
+ depending on whether each key belongs to `self.keys_to_spline` or
+ `self.keys_to_step`.
+
+ """
+
+ # create a dictionary to store the spline functions
+ self.interpolants = {}
+
+ # interpolate the input data
+ for key in self.keys_to_spline:
+ self.interpolants[key] = interpolate_spline(self.data[key])
+ for key in self.keys_to_step:
+ self.interpolants[key] = interpolate_step(self.data[key])
+
+ def run_control(self, t, dt, I_unapproved, dI_dt_unapproved, verbose=False):
+ """
+ Applies coil current perturbations to unapproved coil currents and enforce coil current
+ constraints to produce approved control signals.
+
+ This method adjusts the unapproved coil currents and their rates of change by applying
+ time-dependent perturbations, then clips the results according to current and ramp rate
+ limits. It returns the final approved coil currents and their derivatives.
+
+ Parameters
+ ----------
+ t : float
+ Current time [s].
+
+ dt : float
+ Time step [s].
+
+ I_unapproved : numpy.ndarray
+ Coil currents (not yet approved), computed via Euler integration [A].
+
+ dI_dt_unapproved : numpy.ndarray
+ Rate of change of coil currents (not yet approved) [A/s].
+
+ verbose : bool, optional
+ If True, prints diagnostic messages about clipping and approved values.
+
+ Returns
+ -------
+ I_approved : numpy.ndarray
+ Coil currents (approved) [A].
+
+ dI_dt_approved : numpy.ndarray
+ Rate of change of coil currents (approved) [A/s].
+
+ """
+
+ # extract coil current perturbations
+ dI_pert_dt = self.extract_values(t=t, targets=self.ctrl_coils, deriv=True)
+
+ # add perturbations
+ I_perturbed = I_unapproved + dI_pert_dt * dt
+ dI_dt_perturbed = dI_dt_unapproved + dI_pert_dt
+
+ # extract coil current limits and ramp rate limits
+ min_coil_curr_lims = self.interpolants["min_coil_curr_lims"](t)
+ max_coil_curr_lims = self.interpolants["max_coil_curr_lims"](t)
+ max_coil_curr_ramp_lims = self.interpolants["max_coil_curr_ramp_lims"](t)
+
+ # apply the clipping
+ I_approved = np.clip(I_perturbed, min_coil_curr_lims, max_coil_curr_lims)
+ dI_dt_approved = np.clip(
+ dI_dt_perturbed, -max_coil_curr_ramp_lims, max_coil_curr_ramp_lims
+ )
+
+ # print if required
+ if verbose:
+ print("---")
+
+ if not np.allclose(I_approved, I_perturbed):
+ print(" Coil currents clipped (according to `min/max_coil_limits`).")
+
+ if not np.allclose(dI_dt_approved, dI_dt_perturbed):
+ print(
+ " Coil current deltas clipped (according to `max_coil_delta_limits`)."
+ )
+
+ print(f" Approved coil currents = {I_approved}")
+ print(f" Approved delta coil currents = {dI_dt_approved}")
+
+ return I_approved.squeeze(), dI_dt_approved.squeeze()
+
+ def extract_values(
+ self,
+ t,
+ targets,
+ deriv=False,
+ ):
+ """
+ Extracts interpolated values or their derivatives for specified shape targets at a given time.
+
+ This method queries the stored interpolation functions for each target and key, returning either
+ the interpolated value or its first derivative depending on the `deriv` flag.
+
+ Parameters
+ ----------
+ t : float
+ Time at which to evaluate the interpolants [s].
+ targets : list of str
+ List of keys. Each must correspond to a key in `self.interpolants`.
+ deriv : bool, optional
+ If True, returns the first derivative of the interpolant at time `t`. Default is False.
+
+ Returns
+ -------
+ np.ndarray
+ Array of interpolated values (or derivatives) for each target at time `t`.
+
+ Notes
+ -----
+ - Assumes that `self.interpolants[target]` is a valid `scipy.interpolate` object.
+ - If `deriv=True`, the method calls `.derivative()` on the interpolant before evaluation.
+ """
+
+ if deriv:
+ return np.array(
+ [
+ self.interpolants[target + "_pert"].derivative(n=1)(t)
+ for target in targets
+ ]
+ )
+ else:
+ return np.array(
+ [self.interpolants[target + "_pert"](t) for target in targets]
+ )
+
+ def plot_data(self, tmin=-1.0, tmax=1.0, nt=1001):
+ """
+ Visualizes interpolated control waveforms and corresponding raw inputs.
+
+ This method generates subplots for each control waveform (spline types),
+ showing the interpolated time series alongside the original data points. It helps verify
+ the quality and behavior of the interpolation.
+
+ Parameters
+ ----------
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax] (default is 1001).
+
+ Notes
+ -----
+ - Each subplot corresponds to a control waveform (e.g., '_pert').
+ - Interpolated curves are plotted in navy; raw data points are shown in red.
+ - Axis labels include units where applicable.
+ - Useful for debugging or validating the interpolation quality.
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_spline + self.keys_to_step) # number of plots
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ for ax, key in zip(axes, self.data.keys()):
+ times = np.asarray(self.data[key]["times"])
+ vals_list = self.data[key]["vals"]
+
+ # find out which control is ON and when
+ if key in self.keys_to_spline:
+ FF_reference = self.interpolants[key].derivative()(t)
+ FF_mask = np.abs(FF_reference) > 0
+
+ # shade region of FF control
+ on_regions = np.where(np.diff(FF_mask.astype(int)) != 0)[0] + 1
+ segments = np.split(t, on_regions)
+ states = np.split(FF_mask, on_regions)
+
+ for seg_t, seg_state in zip(segments, states):
+ if np.all(seg_state): # region fully "on"
+ if len(seg_t) > 0:
+ ax.axvspan(seg_t[0], seg_t[-1], color="yellow", alpha=0.25)
+
+ if np.isscalar(vals_list[0]):
+ ax.scatter(
+ self.data[key]["times"],
+ self.data[key]["vals"],
+ s=10,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+ else:
+ m = len(vals_list[0])
+ times_repeated = np.repeat(times, m)
+ vals_flat = np.concatenate(vals_list)
+ ax.scatter(
+ times_repeated,
+ vals_flat,
+ s=10,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+
+ ax.plot(
+ t,
+ self.interpolants[key](t),
+ color="navy",
+ linewidth=1.2,
+ label="interpolated",
+ )
+ ax.grid(True, linestyle="--", alpha=0.6)
+
+ if key[-4:] == "pert":
+ ax.set_ylabel(rf"{key} [$A$]")
+ elif key in ["min_coil_curr_lims", "min_coil_curr_lims"]:
+ ax.set_ylabel(rf"{key} [$A$]")
+ elif key == "max_coil_curr_ramp_lims":
+ ax.set_ylabel(rf"{key} [$A/s$]")
+ else:
+ ax.set_ylabel(key)
+
+ # y-scaling inside the window
+ times = np.array(self.data[key]["times"])
+ mask = (times >= tmin) & (times <= tmax)
+ if np.any(mask):
+ ydata = np.concatenate(
+ [self.interpolants[key](t), np.array(self.data[key]["vals"])[mask]]
+ )
+ ymin, ymax = np.min(ydata), np.max(ydata)
+ yrange = ymax - ymin
+ if yrange == 0:
+ yrange = 1.0
+ ax.set_ylim(ymin - 0.02 * yrange, ymax + 0.02 * yrange)
+
+ axes[0].legend(loc="best")
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
diff --git a/freegsnke/control_loop/useful_functions.py b/freegsnke/control_loop/useful_functions.py
new file mode 100644
index 00000000..c051813a
--- /dev/null
+++ b/freegsnke/control_loop/useful_functions.py
@@ -0,0 +1,209 @@
+"""
+Module of functions required by the PCS in FreeGSNKE.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import numpy as np
+from scipy.interpolate import UnivariateSpline, interp1d
+
+
+def interpolate_step(
+ data,
+):
+ """
+ Creates a step-wise interpolator for time-series data using 'previous' value interpolation.
+
+ Parameters
+ ----------
+ data : dict
+ Dictionary with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length)
+
+ Returns
+ -------
+ f_interp : function
+ Callable function f(t) that returns the step-wise interpolated value at time t.
+ For t < min(times), returns the first value.
+ For t > max(times), returns the last value.
+ """
+
+ times = np.array(data["times"])
+ vals = np.stack(data["vals"])
+
+ # build interpolator
+ f_interp = interp1d(
+ times,
+ vals,
+ kind="previous",
+ axis=0,
+ bounds_error=False,
+ fill_value=(0.0, vals[-1]), # extrapolate for first and last values
+ )
+
+ return f_interp
+
+
+def interpolate_spline(data):
+ """
+ Creates a spline interpolator for time-series data in 'data'.
+
+ Parameters
+ ----------
+ data : dict
+ Dictionary with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length)
+
+ Returns
+ -------
+ f_interp : function
+ Callable function f(t) that returns the spline interpolated value at time t.
+ For t < min(times), returns the first value.
+ For t > max(times), returns the last value.
+ """
+
+ times = np.array(data["times"])
+ vals = np.array(data["vals"])
+
+ # build interpolator
+ f_interp = UnivariateSpline(
+ times,
+ vals,
+ k=1, # order (linear)
+ s=0, # interpolates points exactly
+ ext="zeros", # extrapolate to zeros outside of boundary points
+ )
+
+ return f_interp
+
+
+def check_data_entry(
+ data: dict,
+ key: str,
+ controller_name: str,
+) -> bool:
+ """
+ Validate that a specified sub-dictionary contains 'times' and 'vals' keys
+ of equal length.
+
+ Parameters
+ ----------
+ data : dict
+ A dictionary where each value is expected to be a sub-dictionary
+ containing at least 'times' and 'vals'.
+ key : str
+ The key in `data` corresponding to the sub-dictionary to validate.
+ controller_name : str
+ A string corresponding to which controller is being checked.
+
+ Returns
+ -------
+ bool
+ True if the checks pass.
+
+ Raises
+ ------
+ ValueError
+ If the specified key is missing from `data`, if 'times' or 'vals'
+ is missing from the sub-dictionary, or if 'times' and 'vals'
+ are not the same length.
+ """
+
+ # key not found
+ if key not in data:
+ raise ValueError(
+ f"{controller_name}: Key '{key}' not found in 'data'. "
+ f"Please include {{'times': [], 'vals': []}} for '{key}'."
+ )
+
+ subdict = data[key]
+
+ # key found, check for times and values
+ for required_key in ["times", "vals"]:
+ if required_key not in subdict:
+ raise ValueError(
+ f"{controller_name}: Missing '{required_key}' in data['{key}']."
+ )
+
+ # times and vals found, check equal lengths
+ times_len = len(subdict["times"])
+ vals_len = len(subdict["vals"])
+ if times_len != vals_len:
+ raise ValueError(
+ f"{controller_name}: Length mismatch in data['{key}']: "
+ f"'times' has length {times_len}, 'vals' has length {vals_len}. "
+ )
+
+
+def PID(
+ error_prop=None,
+ error_int=None,
+ error_deriv=None,
+ k_prop=0.0,
+ k_int=0.0,
+ k_deriv=0.0,
+):
+ """
+ Compute a flexible PID controller output.
+
+ Any of the P, I, or D components may be omitted. If a gain or the
+ corresponding error term is not provided (None), that component
+ contributes zero to the output.
+
+ Parameters
+ ----------
+ error_prop : float or array_like, optional
+ Proportional error term. If None, the P contribution is zero.
+ error_int : float or array_like, optional
+ Integral error term. If None, the I contribution is zero.
+ error_deriv : float or array_like, optional
+ Derivative error term. If None, the D contribution is zero.
+ k_prop : float or array_like, optional
+ Proportional gain. Default is 0.
+ k_int : float or array_like, optional
+ Integral gain. Default is 0.
+ k_deriv : float or array_like, optional
+ Derivative gain. Default is 0.
+
+ Returns
+ -------
+ float or ndarray
+ The PID (or PI, PD, P, I, D, or ID) controller output. Arrays must be
+ broadcast-compatible if array inputs are used.
+
+ Notes
+ -----
+ - A component contributes only if both its gain and error term are provided.
+ - This function performs no time integration or differentiation; the caller
+ must compute error_int and error_deriv externally.
+ """
+
+ out = 0
+
+ if error_prop is not None:
+ out += k_prop * error_prop
+
+ if error_int is not None:
+ out += k_int * error_int
+
+ if error_deriv is not None:
+ out += k_deriv * error_deriv
+
+ return out
diff --git a/freegsnke/control_loop/vc_provider.py b/freegsnke/control_loop/vc_provider.py
new file mode 100644
index 00000000..8d483613
--- /dev/null
+++ b/freegsnke/control_loop/vc_provider.py
@@ -0,0 +1,391 @@
+"""
+Defines the base class, `VirtualCircuitProvider`, for a Virtual Circuit provider. Such
+a provider promises to provide a Virtual Circuit given a timestamp and a means to
+extract observables regarding the equilibrium. The mechanism by which the Virtual
+Circuit is produced, and the observables that are or are not requested for the purpose
+of Virtual Circuit construction is not constrained.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import abc
+import time
+from copy import deepcopy
+
+import numpy as np
+
+from freegsnke.observable_registry import ObservableRegistry
+from freegsnke.virtual_circuits import VirtualCircuit, VirtualCircuitHandling
+
+
+class VirtualCircuitProvider(abc.ABC):
+ """
+ Defines the interface for a Virtual Circuit provider.
+
+ TODO(Matthew): We here have get_vc require the ObservableRegistry, but as this can
+ be in-principle stateless (and if stateful it can be keyed on
+ timestamp).
+ """
+
+ def __init__(self, observable_registry: ObservableRegistry | None = None):
+ """
+ Initialise the virtual circuit provider.
+
+ Parameters
+ ----------
+ observable_registry : ObservableRegistry | None (default: None)
+ The observable registry to set the provider to use.
+ """
+ self._observable_registry = None
+ if observable_registry is None or self._validate_observable_registry(
+ observable_registry
+ ):
+ self._observable_registry = observable_registry
+
+ @abc.abstractmethod
+ def get_vc(
+ self,
+ targets: list[str],
+ coils: list[str],
+ coils_calc: list[str],
+ input_data,
+ ) -> np.ndarray | None:
+ """
+ Gets a Virtual Circuit for the given timestamp and observables requested from
+ the registry.
+
+ Parameters
+ ----------
+ targets : list[str]
+ list of targets to get a virtual circuit for
+ coils : list[str]
+ list of coils to return VC matrix for.
+ coils_calc : list[str]
+ list of coils to use for computing VC/jacobians.
+ Must be subset or equal to coils above
+ input_data :
+ data needed to compute/retrieve VC.
+ For example array of inputs for NN emulators, eq/profile for Freegsnke.
+
+ Returns
+ -------
+ vc : VirtualCircuit | None
+ virtual circuit object to be used by the control voltages class or None if
+ no virtual circuit could be obtained or constructed.
+ """
+ pass
+
+ def set_observable_registry(self, observable_registry: ObservableRegistry) -> bool:
+ """
+ Sets observable registry to provided registry if it provides the necessary
+ observables for the provider to execute get_vc correctly.
+
+ Parameters
+ ----------
+ observable_registry : ObservableRegistry | None (default: None)
+ The observable registry to set the provider to use.
+ """
+ if not self._validate_observable_registry(observable_registry):
+ return False
+
+ self._observable_registry = observable_registry
+ return True
+
+ @abc.abstractmethod
+ def _validate_observable_registry(
+ self, observable_registry: ObservableRegistry
+ ) -> bool:
+ """
+ Determine if the provided observable registry satisfies the necessary
+ requirements for get_vc to be executed correctly. E.g. does it provide access to
+ all the physical parameters of an equilibrium needed by a model.
+
+ Parameters
+ ----------
+ observable_registry : ObservableRegistry
+ The observable registry to validate.
+ """
+ pass
+
+
+class VCGenerator(VirtualCircuitProvider):
+ """
+ Virtual Circuit (VC) generator based on FreeGSNKE's
+ ``VirtualCircuitHandling`` infrastructure.
+
+ This class acts as an adapter between a control or optimisation framework
+ and FreeGSNKE's internal VC computation routines. It allows:
+
+ - Mapping between user-facing and internal target names
+ - Computation of virtual circuit matrices for a selected subset of coils
+ - Optional access to sensitivity (shape/derivative) matrices
+
+ The class assumes that equilibrium and profile objects are provided
+ externally (e.g. via an observable registry).
+ """
+
+ def __init__(self, solver):
+ """
+ Initialise the VC generator and bind it to a FreeGSNKE solver.
+
+ This sets up a ``VirtualCircuitHandling`` instance and registers the
+ solver object required for VC computations. It also defines the default
+ set of supported targets and their internal naming conventions.
+
+ Default available targets are:
+
+ - ``"R_in"`` : Inner midplane radius
+ - ``"R_out"`` : Outer midplane radius
+ - ``"Rx_lower"`` : Lower X-point radial position
+ - ``"Zx_lower"`` : Lower X-point vertical position
+ - ``"Rx_upper"`` : Upper X-point radial position
+ - ``"Zx_upper"`` : Upper X-point vertical position
+ - ``"Rs_lower_outer"`` : Lower outer strike-point radius
+ - ``"Rs_upper_outer"`` : Upper outer strike-point radius
+
+ Parameters
+ ----------
+ solver : object
+ A FreeGSNKE solver instance used internally by
+ ``VirtualCircuitHandling`` to compute virtual circuits.
+ """
+
+ self.VCH = VirtualCircuitHandling()
+ self.VCH.define_solver(solver)
+
+ # intenrnal names of targets, as prescribed in freegsnke.
+ self.target_names_internal = [
+ "R_in",
+ "R_out",
+ "Rx_lower",
+ "Zx_lower",
+ "Rx_upper",
+ "Zx_upper",
+ "Rs_lower_outer",
+ "Rs_upper_outer",
+ ]
+ self.target_names_user = deepcopy(self.target_names_internal)
+ self.targets_user_to_internal = dict(
+ zip(self.target_names_user, self.target_names_internal)
+ )
+
+ def rename_targets(self, names_user: list[str], names_internal: list[str]):
+ """
+ Rename target labels exposed to the user or control code.
+
+ This method allows user-facing target names to differ from the
+ internal FreeGSNKE naming scheme. The mapping is order-dependent:
+ each entry in ``names_user`` corresponds to the entry at the same
+ index in ``names_internal``.
+
+ Parameters
+ ----------
+ names_user : list[str]
+ New target names to be exposed to the user.
+ names_internal : list[str]
+ Existing internal target names to be replaced.
+
+ Returns
+ -------
+ None
+ Updates ``target_names_user`` and ``targets_user_to_internal`` in-place.
+ """
+
+ internal_to_user = dict(zip(names_internal, names_user))
+ user_to_internal = dict(zip(names_user, names_internal))
+
+ new_labels = []
+ for label in self.target_names_internal:
+ if label in internal_to_user.keys():
+ new_labels.append(internal_to_user[label])
+ else:
+ new_labels.append(label)
+ self.target_names_user = new_labels
+
+ for key, item in user_to_internal.items():
+ self.targets_user_to_internal[key] = item
+
+ print("Targets renamed")
+ print(self.target_names_user)
+ print("user to internal", self.targets_user_to_internal)
+
+ def get_vc(
+ self,
+ targets: list[str],
+ coils: list[str],
+ coils_calc: list[str],
+ input_data: tuple,
+ sensitivity=False,
+ ):
+ """
+ Compute the virtual circuit (VC) matrix for a given set of targets and coils.
+
+ The VC matrix maps coil current perturbations to changes in the selected
+ plasma shape or position targets. Only a subset of coils may be used
+ for the VC computation, but the returned matrix is expanded to include
+ all coils provided in ``coils``.
+
+ Parameters
+ ----------
+ targets : list[str]
+ User-facing names of targets to include (order is preserved).
+ coils : list[str]
+ Full list of coils defining the output matrix row ordering.
+ coils_calc : list[str]
+ Subset of coils actually used in the VC calculation.
+ input_data : tuple
+ Tuple of inputs required for VC computation.
+ Expected to be ``(equilibrium, profiles)``.
+ sensitivity : bool, optional
+ If ``True``, return the sensitivity (shape/derivative) matrix instead
+ of the VC matrix. Default is ``False``.
+
+ Returns
+ -------
+ vc_matrix : np.ndarray
+ Expanded virtual circuit matrix of shape
+ ``(len(coils), len(targets))`` if ``sensitivity=False``.
+ derivative_matrix : np.ndarray
+ Sensitivity (shape) matrix if ``sensitivity=True``.
+ """
+
+ # get inputs
+ # print(input_data)
+ eq = input_data[0]
+ profiles = input_data[1]
+ # print(eq)
+ # print(profiles)
+
+ t1 = time.time()
+ # print(f"Computing VCs for {targets}")
+ # print("targets user", targets)
+ # convert back to internal names
+ targets_internal = [self.targets_user_to_internal[t] for t in targets]
+ # print("targets internal ", targets_internal)
+
+ # print("coils for calc", coils_calc)
+
+ # compute vc's
+ self.VCH.calculate_VC(
+ eq=eq,
+ profiles=profiles,
+ coils=coils_calc,
+ targets=targets_internal,
+ targets_options=None,
+ )
+ vc_matrix = self.VCH.latest_VC.VCs_matrix
+ derivative_matrix = self.VCH.latest_VC.shape_matrix
+ # print("small matrix shape")
+ # print(np.shape(vc_matrix))
+
+ # larger full matrix, including zeros
+ vc_matrix_big = np.zeros((len(coils), len(targets)))
+ # print("big matrix shape")
+ # print(np.shape(vc_matrix_big))
+
+ # index dict
+ index_coils = {coil: i for i, coil in enumerate(coils)}
+ # fill out rows, keeping target order
+ for i, coil in enumerate(coils_calc):
+ ind = index_coils[coil]
+ vc_matrix_big[ind, :] = 1.0 * vc_matrix[i, :]
+
+ t2 = time.time()
+ # print("VC compute time", t2 - t1)
+ if sensitivity == False:
+ return vc_matrix_big
+ elif sensitivity == True:
+ return derivative_matrix
+
+ def get_inputs(self, eq, profiles):
+ """
+ Package equilibrium and profile data into the input format expected
+ by ``get_vc``.
+
+ This method exists for compatibility with higher-level infrastructure
+ (e.g. observable registries).
+
+ Parameters
+ ----------
+ eq : object
+ Equilibrium object.
+ profiles : object
+ Plasma profile data.
+
+ Returns
+ -------
+ tuple
+ ``(eq, profiles)``
+ """
+ return eq, profiles
+
+ def set_observable_registry(self, observable_registry: ObservableRegistry) -> bool:
+ """
+ Set the observable registry used by this provider.
+
+ The registry is only accepted if it satisfies the requirements checked
+ by ``_validate_observable_registry``.
+
+ Parameters
+ ----------
+ observable_registry : ObservableRegistry
+ Registry providing access to equilibrium and profile observables.
+
+ Returns
+ -------
+ bool
+ ``True`` if the registry was accepted, ``False`` otherwise.
+ """
+ if not self._validate_observable_registry(observable_registry):
+ return False
+
+ self._observable_registry = observable_registry
+ return True
+
+ def _validate_observable_registry(
+ self, observable_registry: ObservableRegistry
+ ) -> bool:
+ """
+ Validate that an observable registry provides all data required
+ for VC computation.
+
+ This method should check that the registry can supply, at a minimum,
+ the equilibrium and profile information needed by ``get_vc``.
+ Typical checks may include:
+
+ - Availability of equilibrium objects
+ - Availability of plasma profiles
+ - Consistent update semantics
+
+ Parameters
+ ----------
+ observable_registry : ObservableRegistry
+ The observable registry to validate.
+
+ Returns
+ -------
+ bool
+ ``True`` if the registry satisfies all requirements,
+ ``False`` otherwise.
+
+ Notes
+ -----
+ This method is currently unimplemented and should be extended
+ as the observable interface is finalised.
+ """
+ pass
diff --git a/freegsnke/control_loop/vertical_category.py b/freegsnke/control_loop/vertical_category.py
new file mode 100644
index 00000000..11b3b8c0
--- /dev/null
+++ b/freegsnke/control_loop/vertical_category.py
@@ -0,0 +1,243 @@
+"""
+Module to implement vertical plasma control in FreeGSNKE control loops.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from freegsnke.control_loop.useful_functions import (
+ PID,
+ check_data_entry,
+ interpolate_spline,
+ interpolate_step,
+)
+
+
+class VerticalController:
+ """
+ A controller class for managing vertical plasma control.
+
+ Parameters
+ ----------
+ data : dict
+ A nested dictionary containing control waveforms for the vertical controller.
+ The required keys for both spline-based and step-based waveforms are:
+ - Spline keys: "z_ref", "k_prop", "k_deriv"
+ - Step keys:
+ Each key should map to a waveform dictionary suitable for interpolation with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length).
+
+ Attributes
+ ----------
+ keys_to_spline : list of str
+ Keys corresponding to waveforms that will be interpolated using splines.
+
+ keys_to_step : list of str
+ Keys corresponding to waveforms that will be interpolated using step functions.
+
+ data : dict
+ Internal copy of the input control waveforms.
+
+ interpolants : dict
+ A nested dictionary storing interpolation functions of each input waveform.
+ Structure: {spline/step key: interpolant_function}
+
+ """
+
+ def __init__(
+ self,
+ data,
+ ):
+
+ # check correct data is input and in correct format
+ self.keys_to_spline = ["z_ref", "k_prop", "k_deriv"]
+ self.keys_to_step = []
+ for key in self.keys_to_spline + self.keys_to_step:
+ check_data_entry(data=data, key=key, controller_name="VerticalController")
+
+ # create an internal copy of the data
+ self.data = data
+
+ # interpolate the input data
+ self.update_interpolants()
+
+ def update_interpolants(self):
+ """
+ Recompute all interpolant functions from the current `self.data`.
+
+ This method clears the existing `self.interpolants` dictionary and
+ rebuilds it by applying either `interpolate_spline` or `interpolate_step`
+ depending on whether each key belongs to `self.keys_to_spline` or
+ `self.keys_to_step`.
+
+ """
+
+ # create a dictionary to store the spline functions
+ self.interpolants = {}
+
+ # interpolate the input data
+ for key in self.data.keys():
+ self.interpolants[key] = {}
+ if key in self.keys_to_spline:
+ self.interpolants[key] = interpolate_spline(self.data[key])
+ elif key in self.keys_to_step:
+ self.interpolants[key] = interpolate_step(self.data[key])
+
+ def run_control(
+ self,
+ t,
+ dt,
+ ip_meas,
+ zip_meas,
+ zipv_meas,
+ ):
+ """
+ Compute the control signal for plasma vertical position regulation using a
+ proportional-derivative (PD) control law.
+
+ This method uses interpolated reference and gain values to calculate the control
+ output based on the measured plasma current, vertical position, and vertical velocity.
+
+ Parameters
+ ----------
+ t : float
+ Current time [s].
+
+ dt : float
+ Time step [s].
+
+ ip_meas : float
+ Measured plasma current [A].
+
+ zip_meas : float
+ Measured vertical position of the plasma multiplied by measured Ip [A.m].
+
+ zipv_meas : float
+ Measured vertical velocity of the plasma multiplied by measured Ip [A.m/s].
+
+ Returns
+ -------
+ control_signal : float
+ Output of the PD controller, representing the voltage command
+ for vertical position regulation.
+ """
+
+ # extract data
+ z_ref = self.interpolants["z_ref"](t)
+ k_prop = self.interpolants["k_prop"](t)
+ k_deriv = self.interpolants["k_deriv"](t)
+
+ # proportional error
+ err_prop = (z_ref * ip_meas) - zip_meas
+
+ # FB term
+ output = PID(
+ error_prop=err_prop,
+ error_int=None,
+ error_deriv=zipv_meas,
+ k_prop=k_prop,
+ k_int=0.0,
+ k_deriv=k_deriv,
+ )
+
+ return output
+
+ def plot_data(self, tmin=-1.0, tmax=1.0, nt=10001):
+ """
+ Visualizes interpolated control waveforms and corresponding raw inputs.
+
+ This method generates subplots for each control waveform (step types),
+ showing the interpolated time series alongside the original data points. It helps verify
+ the quality and behavior of the interpolation.
+
+ Parameters
+ ----------
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax] (default is 10001).
+
+ Notes
+ -----
+ - Each subplot corresponds to a control waveform.
+ - Interpolated curves are plotted in navy; raw data points are shown in red.
+ - Axis labels include units where applicable.
+ - Useful for debugging or validating the interpolation quality.
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_spline + self.keys_to_step) # number of plots
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ for ax, key in zip(axes, self.data.keys()):
+ ax.scatter(
+ self.data[key]["times"],
+ self.data[key]["vals"],
+ s=10,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+ ax.plot(
+ t,
+ self.interpolants[key](t),
+ color="navy",
+ linewidth=1.2,
+ label="interpolated",
+ )
+ ax.grid(True, linestyle="--", alpha=0.6)
+
+ if key == "z_ref":
+ ax.set_ylabel(rf"{key} [$m$]")
+ # elif key == "k_prop":
+ # ax.set_ylabel(rf"{key} [$1/s$]")
+ # elif key == "k_deriv":
+ # ax.set_ylabel(rf"{key} [$1/s^2$]")
+ else:
+ ax.set_ylabel(key)
+
+ # y-scaling inside the window
+ times = np.array(self.data[key]["times"])
+ mask = (times >= tmin) & (times <= tmax)
+ if np.any(mask):
+ ydata = np.concatenate(
+ [self.interpolants[key](t), np.array(self.data[key]["vals"])[mask]]
+ )
+ ymin, ymax = np.min(ydata), np.max(ydata)
+ yrange = ymax - ymin
+ if yrange == 0:
+ yrange = 1.0
+ ax.set_ylim(ymin - 0.02 * yrange, ymax + 0.02 * yrange)
+
+ axes[0].legend(loc="best")
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
diff --git a/freegsnke/control_loop/virtual_circuits_category.py b/freegsnke/control_loop/virtual_circuits_category.py
new file mode 100644
index 00000000..5289fd99
--- /dev/null
+++ b/freegsnke/control_loop/virtual_circuits_category.py
@@ -0,0 +1,501 @@
+"""
+Module to implement virtual circuits control in FreeGSNKE control loops.
+
+Copyright 2025 UKAEA, UKRI-STFC, and The Authors, as per the COPYRIGHT and README files.
+
+This file is part of FreeGSNKE.
+
+FreeGSNKE is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Lesser General Public License for more details.
+
+FreeGSNKE is free software: you can redistribute it and/or modify
+it under the terms of the GNU Lesser General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+You should have received a copy of the GNU Lesser General Public License
+along with FreeGSNKE. If not, see .
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from freegsnke.control_loop.useful_functions import (
+ check_data_entry,
+ interpolate_spline,
+ interpolate_step,
+)
+
+
+class VirtualCircuitsController:
+ """
+ A controller class for managing virtual circuit control matrices and coil current reference
+ waveforms.
+
+ This class supports both spline-based (linear) and step-based interpolation of control signals
+ for coils and plasma shaping parameters. It optionally integrates with an emulated virtual
+ circuit provider for enhanced control capabilities.
+
+ Parameters
+ ----------
+ data : dict
+ A nested dictionary containing control waveforms for each shape parameter to be controlled.
+ Each shape parameter's dictionary must include keys for both spline-based and step-based parameters:
+ - Spline keys: typically of the form '_ref'
+ - Step keys: typically shape target and plasma target names
+ Each key should map to a dictionary suitable for interpolation, with keys:
+ - 'times': 1D array of time points
+ - 'vals': 1D array of values at those time points (same length).
+
+ ctrl_coils : list of str
+ The list of active coils being controlled.
+
+ ctrl_targets : list of str
+ The list of shape parameters being managed.
+
+ plasma_target : list of str
+ The list of plasma parameters being managed.
+
+ vc_generator : object, optional
+ An optional class object for applying emulated virtual circuits. If not
+ provided, deafult waveform-defined VCs will be used.
+
+ vc_update_rate : float, optional
+ Optional argument to specify how often, in seconds, new VCs are computed with vc_generator.
+ If None provided, defaults to zero and new VC computed at every time step.
+
+ """
+
+ def __init__(
+ self,
+ data,
+ ctrl_coils,
+ ctrl_targets,
+ plasma_target,
+ vc_generator=None,
+ vc_update_rate=None,
+ ):
+
+ # active coils list (used for shape control)
+ self.ctrl_coils = ctrl_coils
+
+ # ordering of the ctrl coils in the virtual circuit matrices
+ self.vc_coil_order = data["coil_order"]
+ self.vc_coil_order_index = {
+ coil: i for i, coil in enumerate(self.vc_coil_order)
+ }
+
+ # shape parameter list to be controlled
+ self.ctrl_targets = ctrl_targets
+
+ # name of plasma parameter to be controlled
+ self.plasma_target = plasma_target
+
+ # check correct data is input and in correct format
+ self.keys_to_spline = [coil + "_ref" for coil in self.ctrl_coils]
+ self.keys_to_step = self.ctrl_targets + self.plasma_target
+ for key in self.keys_to_spline + self.keys_to_step:
+ check_data_entry(
+ data=data, key=key, controller_name="VirtualCircuitsController"
+ )
+
+ # create an internal copy of the data
+ self.data = data
+
+ # interpolate the input data
+ self.update_interpolants()
+
+ # store emulated VCs class if present
+ self.vc_generator = vc_generator
+
+ # how often to update emulated VCs (in seconds)
+ if vc_update_rate is None:
+ vc_update_rate = 0.0
+ self.vc_update_rate = vc_update_rate
+
+ # set placeholders for most recent VCs
+ self.latest_vc_time = None
+ self.latest_vc = None
+
+ # store emulated VCs that were used
+ self.emulated_jacobian_list = []
+ self.emulated_vc_list = []
+ self.emulated_vc_times = []
+ self.full_vc_matrix = []
+
+ def update_interpolants(self):
+ """
+ Recompute all interpolant functions from the current `self.data`.
+
+ This method clears the existing `self.interpolants` dictionary and
+ rebuilds it by applying either `interpolate_spline` or `interpolate_step`
+ depending on whether each key belongs to `self.keys_to_spline` or
+ `self.keys_to_step`.
+
+ """
+
+ # create a dictionary to store the spline functions
+ self.interpolants = {}
+
+ # interpolate the input data
+ for key in self.keys_to_spline:
+ self.interpolants[key] = interpolate_spline(self.data[key])
+ for key in self.keys_to_step:
+ self.interpolants[key] = interpolate_step(self.data[key])
+
+ def run_control(
+ self,
+ t,
+ dt,
+ dip_dt,
+ dT_dt,
+ I_approved_prev,
+ emulated_VC_targets=None,
+ emulated_VC_targets_calc=None,
+ emulator_coils_calc=None,
+ emu_inputs=None,
+ verbose=False,
+ ):
+ """
+ Computes the unapproved coil currents and their rates of change based on feedforward
+ coil current references and virtual circuit transformations.
+
+ This method extracts coil current reference derivatives, applies virtual circuit matrices
+ (either from an emulator or interpolated data), and computes the unapproved coil
+ current updates using Euler integration.
+
+ There is also the option to provide VCs from an emulator class object.
+
+ Parameters
+ ----------
+ t : float
+ Current time at which control values are evaluated [s].
+
+ dt : float
+ Time step for Euler integration [s].
+
+ dip_dt : float
+ Time derivative of the requested plasma current [A/s].
+
+ dT_dt : np.ndarray
+ Time derivative of the shape target requests [m/s].
+
+ I_approved_prev : numpy.ndarray
+ Previously approved coil currents [A].
+
+ emulated_VC_targets : list of str , optional
+ List of targets to be controlled using the emulated VC's. Must be subset of
+ ctrl_targets, and subset/equal to emulated_VC_targets_calc. Those not defined in this list will be taken from waveform-defined
+ VCs.
+
+ emulated_VC_targets_calc : list of str , optional
+ List of targets to be used when performing pseudoinverse of jacobian when calculating the emulated VC.
+
+ emulator_coils_calc : list of str, optional
+ List of coils to use in emulated VC compuation. These are coils to use in computing shape sensitivity matrix.
+
+ emu_inputs : np.ndarray , optional
+ Array of input values for all input parameters (currents and other plasma parameters) of the Neural Network emulator.
+
+ verbose : bool
+ Print some output if True.
+
+ Returns
+ -------
+ I_unapproved : numpy.ndarray
+ Coil currents (not yet approved), computed via Euler integration [A].
+
+ dI_dt_unapproved : numpy.ndarray
+ Rate of change of coil currents (not yet approved) [A/s].
+ """
+
+ # extract (feedforward) current references
+ dI_dt_ref = self.extract_values(
+ t=t, targets=[coil + "_ref" for coil in self.ctrl_coils], deriv=True
+ )
+
+ # extract shape target VCs from waveform data (targets x coils)
+ VC_shape = self.extract_values(t=t, targets=self.ctrl_targets)
+ if verbose:
+ print("VC's from file", VC_shape)
+
+ # extract plasma target VC from waveform data (targets x coils)
+ VC_plasma = self.extract_values(t=t, targets=self.plasma_target)
+
+ # if emulated VCs to be used, extract the data and overwrite relevant VC
+ # matrix columns
+ if (
+ (self.vc_generator is not None)
+ and (emulated_VC_targets is not None)
+ and (emulated_VC_targets_calc is not None)
+ and (emulator_coils_calc is not None)
+ ):
+ # error checks
+ assert (
+ self.vc_generator is not None
+ ), "Need to provide a VC emulator class to `VirtualCircuitsController`."
+ assert (
+ emulated_VC_targets is not None
+ ), "Need to provide targets for the VC emulator."
+ assert (
+ emulated_VC_targets_calc is not None
+ ), "Need to provide targets for calculation in the VC emulator."
+
+ if self.latest_vc is None:
+ # compute first emulated VC
+ if verbose:
+ print("...first emulated VCs being used.")
+ VC_shape_emu = self.vc_generator.get_vc(
+ targets=emulated_VC_targets,
+ targets_calc=emulated_VC_targets_calc,
+ coils=self.ctrl_coils,
+ coils_calc=emulator_coils_calc,
+ input_data=emu_inputs,
+ )
+ # update latest vcs/times
+ self.latest_vc_time = 1.0 * t
+ self.latest_vc = VC_shape_emu
+
+ # calculate time since last VC update
+ delta_t_vc = t - self.latest_vc_time
+
+ # update with new VCs if required
+ if delta_t_vc >= self.vc_update_rate:
+
+ if verbose:
+ print("...updating the emulated VCs being used.")
+ VC_shape_emu = self.vc_generator.get_vc(
+ targets=emulated_VC_targets,
+ targets_calc=emulated_VC_targets_calc,
+ coils=self.ctrl_coils,
+ coils_calc=emulator_coils_calc,
+ input_data=emu_inputs,
+ )
+
+ # update latest VCs and times
+ self.latest_vc_time = 1.0 * t
+ self.latest_vc = VC_shape_emu
+
+ # store sensitivity matrix (Jacobian)
+ # self.emulated_jacobian_list.append(self.vc_generator.jacobian_matrix)
+ # self.emulated_vc_list.append(self.vc_generator.vc_matrix)
+ self.emulated_vc_times.append(t)
+
+ else:
+ # use the existing emulated VC
+ VC_shape_emu = self.latest_vc
+
+ # fill appropriate columns from emulated vcs
+ ctrl_target_order = {
+ target: i for i, target in enumerate(self.ctrl_targets)
+ }
+ for j, emu_targ in enumerate(emulated_VC_targets):
+ # expand array as apropriate
+ VC_shape[ctrl_target_order[emu_targ], :] = 1.0 * VC_shape_emu[:, j]
+
+ # unapproved coil currents rates of change
+ dI_dt_unapproved = dI_dt_ref + (dT_dt @ VC_shape) + (dip_dt * VC_plasma)
+ self.full_vc_matrix.append(np.concatenate((VC_shape, VC_plasma), axis=0))
+
+ # unapproved coil currents (by simple Euler integration)
+ I_unapproved = I_approved_prev + (dI_dt_unapproved * dt)
+
+ return I_unapproved.squeeze(), dI_dt_unapproved.squeeze()
+
+ def extract_values(
+ self,
+ t,
+ targets,
+ deriv=False,
+ ):
+ """
+ Extracts interpolated values or their derivatives for specified shape targets at a given time.
+
+ This method queries the stored interpolation functions for each target and key, returning either
+ the interpolated value or its first derivative depending on the `deriv` flag.
+
+ Parameters
+ ----------
+ t : float
+ Time at which to evaluate the interpolants [s].
+ targets : list of str
+ List of keys. Each must correspond to a key in `self.interpolants`.
+ deriv : bool, optional
+ If True, returns the first derivative of the interpolant at time `t`. Default is False.
+
+ Returns
+ -------
+ np.ndarray
+ Array of interpolated values (or derivatives) for each target at time `t`.
+
+ Notes
+ -----
+ - Assumes that `self.interpolants[target]` is a valid `scipy.interpolate` object.
+ - If `deriv=True`, the method calls `.derivative()` on the interpolant before evaluation.
+ """
+
+ if deriv:
+ return np.array(
+ [self.interpolants[target].derivative(n=1)(t) for target in targets]
+ )
+ else:
+ return np.array([self.interpolants[target](t) for target in targets])
+
+ def plot_data_FF_currents(self, tmin=-1.0, tmax=1.0, nt=1001):
+ """
+ Visualizes interpolated control waveforms and corresponding raw inputs.
+
+ This method generates subplots for each control waveform (spline types),
+ showing the interpolated time series alongside the original data points. It helps verify
+ the quality and behavior of the interpolation.
+
+ Parameters
+ ----------
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax] (default is 1001).
+
+ Notes
+ -----
+ - Each subplot corresponds to a control waveform (e.g., '_ref').
+ - Interpolated curves are plotted in navy; raw data points are shown in red.
+ - Axis labels include units where applicable.
+ - Useful for debugging or validating the interpolation quality.
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_spline) # number of plots
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ for ax, key in zip(axes, self.keys_to_spline):
+
+ # find out which control is ON and when
+ FF_reference = self.interpolants[key](t)
+ FF_mask = np.abs(FF_reference) > 0
+
+ # shade region of FF control
+ on_regions = np.where(np.diff(FF_mask.astype(int)) != 0)[0] + 1
+ segments = np.split(t, on_regions)
+ states = np.split(FF_mask, on_regions)
+
+ for seg_t, seg_state in zip(segments, states):
+ if np.all(seg_state): # region fully "on"
+ ax.axvspan(seg_t[0], seg_t[-1], color="yellow", alpha=0.25)
+
+ # raw data
+ ax.scatter(
+ self.data[key]["times"],
+ self.data[key]["vals"],
+ s=10,
+ marker="x",
+ color="tab:orange",
+ alpha=0.9,
+ label=f"raw data",
+ )
+ # interpolated data
+ ax.plot(
+ t,
+ self.interpolants[key](t),
+ color="navy",
+ linewidth=1.2,
+ label="interpolated",
+ )
+ ax.grid(True, linestyle="--", alpha=0.6)
+
+ if key[-3:] == "ref":
+ ax.set_ylabel(rf"{key} [$A$]")
+ else:
+ ax.set_ylabel(key)
+
+ # y-scaling inside the window
+ times = np.array(self.data[key]["times"])
+ mask = (times >= tmin) & (times <= tmax)
+ if np.any(mask):
+ ydata = np.concatenate(
+ [self.interpolants[key](t), np.array(self.data[key]["vals"])[mask]]
+ )
+ ymin, ymax = np.min(ydata), np.max(ydata)
+ yrange = ymax - ymin
+ if yrange == 0:
+ yrange = 1.0
+ ax.set_ylim(ymin - 0.02 * yrange, ymax + 0.02 * yrange)
+
+ axes[0].legend(loc="best")
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
+
+ def plot_data_VCs(self, tmin=-1.0, tmax=1.0, nt=1001):
+ """
+ Visualizes virtual circuits times and corresponding raw inputs.
+
+ Parameters
+ ----------
+ tmin : float, optional
+ Start time for the evaluation grid (default is -1.0 seconds).
+ tmax : float, optional
+ End time for the evaluation grid (default is 1.0 seconds).
+ nt : int, optional
+ Number of time points to evaluate the interpolants over the interval [tmin, tmax] (default is 1001).
+
+ """
+
+ # times to plot at
+ t = np.linspace(tmin, tmax, nt)
+ nplots = len(self.keys_to_step) # number of plots
+
+ # start plotting
+ fig, axes = plt.subplots(nplots, 1, figsize=(6, 2.5 * nplots), sharex=True)
+
+ if nplots == 1:
+ axes = [axes]
+
+ # Convert each array to a hashable form
+ def make_key(arr):
+ return tuple(arr.tolist())
+
+ for ax, key in zip(axes, self.keys_to_step):
+
+ # Assign a unique ID to each unique array
+ state_ids = []
+
+ next_id = 1
+ for arr in self.data[key]["vals"]:
+
+ if np.all(np.abs(arr) < 1e-12):
+ state_ids.append(0)
+ else:
+ state_ids.append(next_id)
+ next_id += 1
+
+ state_ids = np.array(state_ids)
+
+ # plot different VC times
+ ax.step(
+ self.data[key]["times"],
+ state_ids,
+ where="post",
+ color="navy",
+ label=key,
+ )
+ ax.set_yticks(sorted(set(state_ids)))
+ ax.grid(True, linestyle="--", alpha=0.6)
+ ax.set_ylabel(f"Unique VC ID")
+ ax.legend(loc="best")
+
+ axes[-1].set_xlabel(r"Time [$s$]")
+ axes[-1].set_xlim([tmin, tmax])
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
+ plt.show()
diff --git a/freegsnke/equilibrium_update.py b/freegsnke/equilibrium_update.py
index 4f21fc19..f5f65ba8 100644
--- a/freegsnke/equilibrium_update.py
+++ b/freegsnke/equilibrium_update.py
@@ -109,10 +109,20 @@ def create_auxiliary_equilibrium(self):
equilibrium.Z = np.copy(self.Z)
equilibrium.tokamak_psi = np.copy(self.tokamak_psi)
equilibrium.plasma_psi = np.copy(self.plasma_psi)
+ equilibrium.psi_axis = np.copy(self.psi_axis)
+ equilibrium.psi_bndry = np.copy(self.psi_bndry)
equilibrium.mask_inside_limiter = np.copy(self.mask_inside_limiter)
equilibrium.mask_outside_limiter = np.copy(self.mask_outside_limiter)
equilibrium._pgreen = self._pgreen.copy()
equilibrium._vgreen = self._vgreen.copy()
+ copy_into(
+ self,
+ equilibrium,
+ "flag_limiter",
+ mutable=True,
+ strict=False,
+ allow_deepcopy=True,
+ )
copy_into(self, equilibrium, "current_vec", mutable=True, strict=False)
copy_into(
@@ -121,7 +131,7 @@ def create_auxiliary_equilibrium(self):
copy_into(
self, equilibrium, "xpt", mutable=True, strict=False, allow_deepcopy=True
)
- copy_into(self, equilibrium, "psi_bndry", strict=False)
+ # copy_into(self, equilibrium, "psi_bndry", strict=False)
if hasattr(self, "_profiles"):
equilibrium._profiles = self._profiles.copy()
diff --git a/freegsnke/mastu_tools.py b/freegsnke/mastu_tools.py
index a100eb57..ff23ae49 100644
--- a/freegsnke/mastu_tools.py
+++ b/freegsnke/mastu_tools.py
@@ -2191,10 +2191,6 @@ def plasma_resistivity_controller(
Plasma resistivity required to maintain the target plasma current.
"""
- # if not history, no control action
- if not history:
- return 0
-
# proportional term
error = history[-1] - target
output = (
diff --git a/freegsnke/nk_solver_H.py b/freegsnke/nk_solver_H.py
index 79d5e3dd..fcde8cfa 100644
--- a/freegsnke/nk_solver_H.py
+++ b/freegsnke/nk_solver_H.py
@@ -293,11 +293,12 @@ def Arnoldi_iteration(
self.collinear_aware_regulariz = collinear_aware_regulariz * nR0**2
# solve the regularised least sq problem
+ A = (
+ self.G[:, : self.n_it + 1].T @ self.G[:, : self.n_it + 1]
+ + self.collinear_aware_regulariz
+ )
coeffs = np.dot(
- np.linalg.inv(
- self.G[:, : self.n_it + 1].T @ self.G[:, : self.n_it + 1]
- + self.collinear_aware_regulariz
- ),
+ np.linalg.solve(A, np.eye(A.shape[0])),
np.dot(self.G[:, : self.n_it + 1].T, -R0),
)
coeffs = np.clip(coeffs, -clip, clip)
diff --git a/freegsnke/nonlinear_solve.py b/freegsnke/nonlinear_solve.py
index ced71962..5d2feaae 100644
--- a/freegsnke/nonlinear_solve.py
+++ b/freegsnke/nonlinear_solve.py
@@ -14,9 +14,9 @@
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
-
+
You should have received a copy of the GNU Lesser General Public License
-along with FreeGSNKE. If not, see .
+along with FreeGSNKE. If not, see .
"""
import warnings
@@ -82,6 +82,7 @@ def __init__(
l2_reg=1e-6,
collinearity_reg=1e-6,
verbose=False,
+ plasma_descriptor_function=None,
):
"""
Initialize the nonlinear solver.
@@ -519,6 +520,9 @@ def __init__(
)
automatic_timestep_flag = True
+ self.plasma_descriptor_function = plasma_descriptor_function or (
+ lambda _: np.array([0.0])
+ )
if automatic_timestep_flag + mode_removal + linearize:
# builds the linearization and sets everything up for the stepper
self.initialize_from_ICs(
@@ -529,6 +533,7 @@ def __init__(
dIydtheta=dIydtheta,
verbose=verbose,
force_core_mask_linearization=force_core_mask_linearization,
+ plasma_descriptor_function=plasma_descriptor_function,
)
print("-----")
@@ -554,6 +559,16 @@ def __init__(
self.dIydI_ICs = np.copy(self.dIydI)
self.dRZdI = self.dRZdI[:, self.retained_modes_mask]
self.final_dI_record = self.final_dI_record[self.retained_modes_mask]
+ self.dvdId = self.dvdId[:, self.retained_modes_mask]
+ self.initial_currents_plasma_descriptor = (
+ self.initial_currents_plasma_descriptor[self.retained_modes_mask]
+ )
+
+ # apply mask in case of re-linearisation
+ self.approved_target_dIy = self.approved_target_dIy[
+ self.retained_modes_mask
+ ]
+ self.starting_dI = self.starting_dI[self.retained_modes_mask]
self.remove_modes(eq, self.retained_modes_mask[:-1])
@@ -841,11 +856,7 @@ def remove_modes(self, eq, selected_modes_mask):
self.set_solvers()
- def set_linear_solution(
- self,
- active_voltage_vec,
- dtheta_dt,
- ):
+ def set_linear_solution(self, active_voltage_vec, dtheta_dt, no_GS=False):
"""
Compute an initial nonlinear solve guess using the linearised dynamics.
@@ -877,7 +888,8 @@ def set_linear_solution(
active_voltage_vec=active_voltage_vec,
dtheta_dt=dtheta_dt,
)
- self.assign_currents_solve_GS(self.trial_currents, self.rtol_NK)
+ if not no_GS:
+ self.assign_currents_solve_GS(self.trial_currents, self.rtol_NK)
self.trial_plasma_psi = np.copy(self.eq2.plasma_psi)
def prepare_build_dIydtheta(
@@ -886,6 +898,7 @@ def prepare_build_dIydtheta(
rtol_NK,
target_dIy,
starting_dtheta,
+ plasma_descriptor_function,
verbose=False,
):
"""
@@ -932,8 +945,19 @@ def prepare_build_dIydtheta(
dIy_0 = np.zeros((len(self.Iy), self.n_profiles_parameters))
rel_ndIy_0 = np.zeros(self.n_profiles_parameters)
+ dv = np.zeros(
+ (self.initial_plasma_descriptors.shape[0], self.n_profiles_parameters)
+ )
+
# carry out the initial perturbations
if self.profiles_param is not None:
+ self.initial_profiles_plasma_descriptor = np.array(
+ [
+ profiles.alpha_m,
+ profiles.alpha_n,
+ getattr(profiles, self.profiles_param),
+ ]
+ )
# vary alpha_m
self.check_and_change_profiles(
@@ -950,6 +974,9 @@ def prepare_build_dIydtheta(
dIy_0[:, 0] = (
self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
)
+ dv[:, 0] = (
+ plasma_descriptor_function(self.eq2) - self.initial_plasma_descriptors
+ )
rel_ndIy_0[0] = np.linalg.norm(dIy_0[:, 0]) / self.nIy
self.final_dtheta_record[0] = (
starting_dtheta[0] * target_dIy[0] / rel_ndIy_0[0]
@@ -977,6 +1004,9 @@ def prepare_build_dIydtheta(
dIy_0[:, 1] = (
self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
)
+ dv[:, 1] = (
+ plasma_descriptor_function(self.eq2) - self.initial_plasma_descriptors
+ )
rel_ndIy_0[1] = np.linalg.norm(dIy_0[:, 1]) / self.nIy
self.final_dtheta_record[1] = (
starting_dtheta[1] * target_dIy[1] / rel_ndIy_0[1]
@@ -1005,6 +1035,9 @@ def prepare_build_dIydtheta(
dIy_0[:, 2] = (
self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
)
+ dv[:, 2] = (
+ plasma_descriptor_function(self.eq2) - self.initial_plasma_descriptors
+ )
rel_ndIy_0[2] = np.linalg.norm(dIy_0[:, 2]) / self.nIy
self.final_dtheta_record[2] = (
starting_dtheta[2] * target_dIy[2] / rel_ndIy_0[2]
@@ -1027,6 +1060,31 @@ def prepare_build_dIydtheta(
)
else: # this is particular to the Lao profile coefficients (which there may be few or many of)
+ # extract parameters for the plasma descriptor function
+ self.initial_profiles_plasma_descriptor = np.concatenate(
+ (
+ self.profiles_parameters_vec[self.profiles_alpha_indices],
+ self.profiles_parameters_vec[self.profiles_beta_indices],
+ )
+ )
+
+ self.profiles_alpha_indices = slice(0, self.n_profiles_parameters_alpha)
+ alpha_shift = 0
+ if profiles.alpha_logic:
+ alpha_shift += 1
+
+ self.profiles_beta_indices = slice(
+ self.n_profiles_parameters_alpha + alpha_shift,
+ self.n_profiles_parameters_alpha
+ + alpha_shift
+ + self.n_profiles_parameters_beta,
+ )
+
+ self.profiles_parameters = {"alpha": profiles.alpha, "beta": profiles.beta}
+ self.profiles_param = None
+ self.profiles_parameters_vec = np.concatenate(
+ (profiles.alpha, profiles.beta)
+ )
# for each alpha coefficient
alpha_base = profiles.alpha.copy()
@@ -1051,6 +1109,10 @@ def prepare_build_dIydtheta(
dIy_0[:, i] = (
self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
)
+ dv[:, i] = (
+ plasma_descriptor_function(self.eq2)
+ - self.initial_plasma_descriptors
+ )
rel_ndIy_0[i] = np.linalg.norm(dIy_0[:, i]) / self.nIy
self.final_dtheta_record[i] = (
starting_dtheta[i] * target_dIy[i] / rel_ndIy_0[i]
@@ -1087,6 +1149,10 @@ def prepare_build_dIydtheta(
dIy_0[:, i + self.n_profiles_parameters_alpha] = (
self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
)
+ dv[:, i + self.n_profiles_parameters_alpha] = (
+ plasma_descriptor_function(self.eq2)
+ - self.initial_plasma_descriptors
+ )
rel_ndIy_0[i + self.n_profiles_parameters_alpha] = (
np.linalg.norm(dIy_0[:, i + self.n_profiles_parameters_alpha])
/ self.nIy
@@ -1101,13 +1167,13 @@ def prepare_build_dIydtheta(
print("")
print(f"Profile parameter: beta_{i}:")
print(
- f" Initial delta parameter = {starting_dtheta[i+self.n_profiles_parameters_alpha]}"
+ f" Initial delta parameter = {starting_dtheta[i + self.n_profiles_parameters_alpha]}"
)
print(
- f" Initial relative Iy change = {rel_ndIy_0[i+self.n_profiles_parameters_alpha]}"
+ f" Initial relative Iy change = {rel_ndIy_0[i + self.n_profiles_parameters_alpha]}"
)
print(
- f" Final delta parameter = {self.final_dtheta_record[i+self.n_profiles_parameters_alpha]}"
+ f" Final delta parameter = {self.final_dtheta_record[i + self.n_profiles_parameters_alpha]}"
)
# reset profiles in profiles1 and profiles2 objects
@@ -1118,9 +1184,11 @@ def prepare_build_dIydtheta(
}
)
- return dIy_0 / starting_dtheta, rel_ndIy_0
+ return dIy_0 / starting_dtheta, rel_ndIy_0, dv / starting_dtheta
- def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
+ def build_dIydtheta(
+ self, profiles, rtol_NK, plasma_descriptor_function, verbose=False
+ ):
"""
Compute the finite-difference Jacobian d(Iy)/dθ using pre-scaled perturbations.
@@ -1164,9 +1232,12 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
dIydtheta = np.zeros((len(self.Iy), self.n_profiles_parameters))
rel_ndIy = np.zeros(self.n_profiles_parameters)
+ dvdtheta = np.zeros(
+ (self.initial_plasma_descriptors.shape[0], self.n_profiles_parameters)
+ )
+
# carry out the initial perturbations
if self.profiles_param is not None:
-
# vary alpha_m
self.check_and_change_profiles(
profiles_parameters={
@@ -1181,6 +1252,8 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
dIy_1 = self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
rel_ndIy[0] = np.linalg.norm(dIy_1) / self.nIy
dIydtheta[:, 0] = dIy_1 / final_theta[0]
+ dv = plasma_descriptor_function(self.eq2) - self.initial_plasma_descriptors
+ dvdtheta[:, 0] = dv / final_theta[0]
if verbose:
print("")
print(f"Profile parameter: alpha_m:")
@@ -1203,6 +1276,8 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
dIy_1 = self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
rel_ndIy[1] = np.linalg.norm(dIy_1) / self.nIy
dIydtheta[:, 1] = dIy_1 / final_theta[1]
+ dv = plasma_descriptor_function(self.eq2) - self.initial_plasma_descriptors
+ dvdtheta[:, 1] = dv / final_theta[1]
if verbose:
print("")
print(f"Profile parameter: alpha_n:")
@@ -1226,6 +1301,8 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
dIy_1 = self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
rel_ndIy[2] = np.linalg.norm(dIy_1) / self.nIy
dIydtheta[:, 2] = dIy_1 / final_theta[2]
+ dv = plasma_descriptor_function(self.eq2) - self.initial_plasma_descriptors
+ dvdtheta[:, 2] = dv / final_theta[2]
if verbose:
print("")
print(f"Profile parameter: {self.profiles_param}:")
@@ -1244,7 +1321,6 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
)
else: # this is particular to the Lao profile coefficients (which there may be few or many of)
-
# for each alpha coefficient
alpha_base = profiles.alpha.copy()
@@ -1268,6 +1344,11 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
dIy_1 = self.limiter_handler.Iy_from_jtor(self.profiles2.jtor) - self.Iy
rel_ndIy[i] = np.linalg.norm(dIy_1) / self.nIy
dIydtheta[:, i] = dIy_1 / final_theta[i]
+ dv = (
+ plasma_descriptor_function(self.eq2)
+ - self.initial_plasma_descriptors
+ )
+ dvdtheta[:, i] = dv / final_theta[i]
if verbose:
print("")
print(f"Profile parameter: alpha_{i}:")
@@ -1303,12 +1384,19 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
dIydtheta[:, i + self.n_profiles_parameters_alpha] = (
dIy_1 / final_theta[i + self.n_profiles_parameters_alpha]
)
+ dv = (
+ plasma_descriptor_function(self.eq2)
+ - self.initial_plasma_descriptors
+ )
+ dvdtheta[:, i + self.n_profiles_parameters_alpha] = (
+ dv / final_theta[i + self.n_profiles_parameters_alpha]
+ )
if verbose:
print("")
print(f"Profile parameter: beta_{i}:")
print(
- f" Final relative Iy change = {rel_ndIy[i+self.n_profiles_parameters_alpha]}"
+ f" Final relative Iy change = {rel_ndIy[i + self.n_profiles_parameters_alpha]}"
)
print(
f" Initial vs. Final GS residual: {self.NK.initial_rel_residual} vs. {self.NK.relative_change}"
@@ -1322,7 +1410,7 @@ def build_dIydtheta(self, profiles, rtol_NK, verbose=False):
}
)
- return dIydtheta, rel_ndIy
+ return dIydtheta, rel_ndIy, dvdtheta
def prepare_build_dIydI_j(
self,
@@ -1429,6 +1517,26 @@ def build_dIydI_j(self, j, rtol_NK):
return dIydIj, rel_ndIy
+ def new_plasma_descriptors(
+ self, new_currents: np.ndarray, new_profiles: np.ndarray
+ ):
+ """Calculates the estimate plasma descriptors vector `v` from the linearisation."""
+
+ current_contribution = (
+ self.dvdId
+ @ (new_currents - self.initial_currents_plasma_descriptor)[:, np.newaxis]
+ ).squeeze()
+ profile_contribution = (
+ self.dvdtheta
+ @ (new_profiles - self.initial_profiles_plasma_descriptor)[:, np.newaxis]
+ ).squeeze()
+
+ return (
+ self.initial_plasma_descriptors
+ + current_contribution
+ + profile_contribution
+ )
+
def build_linearization(
self,
eq,
@@ -1438,6 +1546,7 @@ def build_linearization(
target_relative_tolerance_linearization,
force_core_mask_linearization,
verbose,
+ plasma_descriptor_function,
):
"""
Builds the Jacobians d(Iy)/dI and d(Iy)/dtheta for linearizing the plasma-current
@@ -1487,6 +1596,9 @@ def build_linearization(
self.Iy = self.limiter_handler.Iy_from_jtor(profiles.jtor).copy()
self.nIy = np.linalg.norm(self.Iy)
+ self.initial_plasma_descriptors = np.array(plasma_descriptor_function(eq))
+ self.plasma_descriptors_vec = np.copy(self.initial_plasma_descriptors)
+
self.R0 = eq.Rcurrent()
self.Z0 = eq.Zcurrent()
self.dRZdI = np.zeros((2, self.n_metal_modes + 1))
@@ -1509,8 +1621,12 @@ def build_linearization(
self.ddIyddI = np.zeros(self.n_metal_modes + 1)
self.final_dI_record = np.zeros(self.n_metal_modes + 1)
- for j in self.arange_currents:
+ self.dvdId = np.zeros(
+ (self.initial_plasma_descriptors.shape[0], self.n_metal_modes + 1)
+ )
+ self.initial_currents_plasma_descriptor = np.copy(self.currents_vec)
+ for j in self.arange_currents:
this_target_dIy = 1.0 * self.approved_target_dIy[j]
dIydIj, ndIy = self.prepare_build_dIydI_j(
j,
@@ -1597,7 +1713,10 @@ def build_linearization(
print(f" Initial relative Iy change = {ndIy}")
print(f" Final delta_current = {self.final_dI_record[j]}")
print("")
- print(f" Final relative Iy change = {rel_ndIy}")
+ if "rel_ndIy" in locals():
+ print(f" Final relative Iy change = {rel_ndIy}")
+ else:
+ print(f" Final relative Iy change = {ndIy}")
print(
f" Initial vs. Final GS residual: {self.NK.initial_rel_residual} vs. {self.NK.relative_change}"
)
@@ -1609,6 +1728,11 @@ def build_linearization(
self.dRZdI[0, j] = (R0 - self.R0) / self.final_dI_record[j]
self.dRZdI[1, j] = (Z0 - self.Z0) / self.final_dI_record[j]
+ v0 = plasma_descriptor_function(self.eq2)
+ self.dvdId[:, j] = (
+ v0 - self.initial_plasma_descriptors
+ ) / self.final_dI_record[j]
+
self.dIydI_ICs = np.copy(self.dIydI)
else:
self.dIydI = np.copy(self.dIydI_ICs)
@@ -1631,15 +1755,22 @@ def build_linearization(
self.dIydtheta = np.zeros(
(self.plasma_domain_size, self.n_profiles_parameters)
)
+ self.dvdtheta = np.zeros(
+ (
+ self.initial_plasma_descriptors.shape[0],
+ self.n_profiles_parameters,
+ )
+ )
profiles_copy = profiles.copy()
# prepare to build the Jacobian by finding appropriate step size
- dIydtheta, ndIy = self.prepare_build_dIydtheta(
+ dIydtheta, ndIy, dvdtheta = self.prepare_build_dIydtheta(
profiles=profiles_copy,
rtol_NK=target_relative_tolerance_linearization,
target_dIy=self.approved_target_dtheta,
starting_dtheta=self.starting_dtheta,
+ plasma_descriptor_function=plasma_descriptor_function,
verbose=verbose,
)
@@ -1647,10 +1778,10 @@ def build_linearization(
np.abs(np.log10(self.final_dtheta_record / self.starting_dtheta))
> 0.5
).any():
-
- dIydtheta, rel_ndIy = self.build_dIydtheta(
+ dIydtheta, rel_ndIy, dvdtheta = self.build_dIydtheta(
profiles=profiles_copy,
rtol_NK=target_relative_tolerance_linearization,
+ plasma_descriptor_function=plasma_descriptor_function,
verbose=verbose,
)
@@ -1659,6 +1790,20 @@ def build_linearization(
self.dIydtheta = np.copy(dIydtheta)
self.dIydtheta_ICs = np.copy(self.dIydtheta)
+ self.dvdtheta = np.copy(dvdtheta)
+
+ if plasma_descriptor_function is not None:
+ print(
+ f"Built the {len(self.initial_plasma_descriptors)} x {self.n_metal_modes + 1} Jacobian (ds/dI)",
+ "of plasma descriptors",
+ "with respect to all metal currents and the total plasma current.",
+ )
+ print(
+ f"Built the {len(self.initial_plasma_descriptors)} x {self.n_profiles_parameters} Jacobian (ds/dtheta)",
+ "of plasma descriptors",
+ "with respect to all plasma current density profile parameters within Jtor.",
+ )
+
else:
self.dIydtheta = np.copy(self.dIydtheta_ICs)
else:
@@ -1726,7 +1871,7 @@ def reset_plasma_resistivity(self, plasma_resistivity):
self.simplified_solver_J1.reset_plasma_resistivity(self.plasma_resistance_1d)
def check_and_change_plasma_resistivity(
- self, plasma_resistivity, relative_threshold_difference=0.01
+ self, plasma_resistivity, relative_threshold_difference=1e-5
):
"""
Check if the plasma resistivity differs from the current value and update it if necessary.
@@ -1755,7 +1900,7 @@ def check_and_change_plasma_resistivity(
# check how different
check = (
np.abs(plasma_resistivity - self.plasma_resistivity)
- / self.plasma_resistivity
+ / np.abs(self.plasma_resistivity)
) > relative_threshold_difference
if check:
self.reset_plasma_resistivity(plasma_resistivity=plasma_resistivity)
@@ -1961,6 +2106,7 @@ def initialize_from_ICs(
dIydtheta=None,
force_core_mask_linearization=False,
verbose=False,
+ plasma_descriptor_function=None,
):
"""
Initialize the dynamics solver from a given equilibrium and plasma profiles.
@@ -2056,10 +2202,17 @@ def initialize_from_ICs(
self.hatIy1 = np.copy(self.hatIy)
self.make_blended_hatIy(self.hatIy1)
+ # store norm of jtor for use if relinearising in future time steps
+ self.jtor0 = self.profiles1.jtor
+
self.time = 0
self.step_no = -1
# build the linearization if not provided
+ self._force_core_mask_linearization = force_core_mask_linearization
+ self._target_relative_tolerance_linearization = (
+ target_relative_tolerance_linearization
+ )
self.build_linearization(
self.eq1,
self.profiles1,
@@ -2068,6 +2221,8 @@ def initialize_from_ICs(
target_relative_tolerance_linearization=target_relative_tolerance_linearization,
force_core_mask_linearization=force_core_mask_linearization,
verbose=verbose,
+ plasma_descriptor_function=plasma_descriptor_function
+ or self.plasma_descriptor_function,
)
# set Myy matrix in place throught the handling object
@@ -2082,6 +2237,71 @@ def initialize_from_ICs(
Myy_hatIy0=self.Myy_hatIy0,
)
+ def relinearise(self, *, verbose=False):
+ """
+ Recompute (relinearise) the linearisation of the equilibrium
+ problem around the current plasma state.
+
+ This method temporarily constructs auxiliary equilibria and profile
+ objects in order to rebuild all linearised operators and sensitivity
+ matrices used by the solver. After the linearisation operators are
+ rebuilt, the original equilibria and profiles are restored so that the
+ nonlinear state of the main solver remains unchanged.
+
+ Parameters
+ ----------
+ verbose : bool, optional
+ If True, print progress messages during the relinearisation
+ process. Default is False.
+
+ """
+
+ if verbose:
+ print("Relinearising around the current plasma")
+
+ # create and store auxiliary copies of eq and profiles
+ original_eq1 = self.eq1.create_auxiliary_equilibrium()
+ original_eq2 = self.eq2.create_auxiliary_equilibrium()
+ original_profiles1 = self.profiles1.copy()
+ original_profiles2 = self.profiles2.copy()
+
+ # reset previous linearisations
+ self.dIydI_ICs = None
+ self.dIydtheta_ICs = None
+
+ # recompute the linearisations
+ self.build_linearization(
+ self.eq1.create_auxiliary_equilibrium(),
+ self.profiles1.copy(),
+ dIydI=None,
+ dIydtheta=None,
+ target_relative_tolerance_linearization=self._target_relative_tolerance_linearization,
+ force_core_mask_linearization=self._force_core_mask_linearization,
+ verbose=verbose,
+ plasma_descriptor_function=self.plasma_descriptor_function,
+ )
+
+ # rebuild operators based on new linearisations
+ self.handleMyy.force_build_Myy(self.hatIy)
+ self.Myy_hatIy0 = self.handleMyy.dot(self.hatIy)
+
+ # update internal linearisation point object
+ self.linearised_sol.set_linearization_point(
+ dIydI=self.dIydI_ICs,
+ dIydtheta=self.dIydtheta_ICs,
+ hatIy0=self.blended_hatIy,
+ Myy_hatIy0=self.Myy_hatIy0,
+ )
+
+ # reset internal objects incase modifed by above
+ self.eq1 = original_eq1
+ self.eq2 = original_eq2
+ self.profiles1 = original_profiles1
+ self.profiles2 = original_profiles2
+
+ # store norm of jtor at relinearisation point (used for next relinearisation triggering)
+ self.jtor0 = self.profiles1.jtor
+
def step_complete_assign(self, working_relative_tol_GS, from_linear=False):
"""
Finalize the timestep advancement and update the equilibrium and current state.
@@ -2590,8 +2810,11 @@ def check_and_change_active_coil_resistances(self, active_coil_resistances):
if active_coil_resistances is None:
return
else:
- if np.array_equal(
- active_coil_resistances, self.evol_metal_curr.active_coil_resistances
+ if np.allclose(
+ active_coil_resistances,
+ self.evol_metal_curr.active_coil_resistances,
+ atol=1e-5,
+ rtol=1e-3,
):
return
else:
@@ -2599,7 +2822,6 @@ def check_and_change_active_coil_resistances(self, active_coil_resistances):
active_coil_resistances
)
self.set_solvers()
- print(self.evol_metal_curr.coil_resist)
def nlstepper(
self,
@@ -2622,107 +2844,176 @@ def nlstepper(
linear_only=False,
max_solving_iterations=50,
custom_active_coil_resistances=None,
+ no_GS=False,
+ relinearise_threshold=None,
):
"""
- Advance the system by one timestep using a nonlinear Newton-Krylov (NK) stepper.
+ Advance the system by one timestep using a nonlinear Newton–Krylov solver.
- If ``linear_only=True``, only the linearised dynamic problem is advanced.
- Otherwise, a full nonlinear solution is sought using an iterative NK-based algorithm.
- On convergence, the timestep is advanced by ``self.dt_step`` and the updated
- currents, equilibrium, and profile objects are assigned to ``self.currents_vec``,
- ``self.eq1``, and ``self.profiles1``.
+ This method advances the plasma and coil current state from time ``t`` to
+ ``t + dt_step`` by solving the coupled dynamic–static inverse equilibrium
+ problem. The solver uses a Newton–Krylov (NK) strategy alternating between:
+
+ - A static Grad–Shafranov (GS) solve at fixed currents.
+ - A dynamic current solve at fixed plasma flux.
+
+ If ``linear_only=True``, only the linearised dynamic problem is evolved and
+ no nonlinear fixed-point or NK iterations are executed.
+
+ On successful convergence, the method updates:
+
+ - ``self.currents_vec``: advanced coil currents,
+ - ``self.eq1``: updated GS equilibrium,
+ - ``self.profiles1``: updated plasma profiles.
Algorithm overview
------------------
- The solver proceeds as follows:
-
- 1. Solve the linearised problem to obtain an initial guess for the currents and
- solve the associated static Grad–Shafranov (GS) problem, yielding
- ``trial_plasma_psi`` and ``trial_currents`` (including ``tokamak_psi``).
- 2. If the pair [``trial_plasma_psi``, ``tokamak_psi``] fails the GS tolerance check,
- update ``trial_plasma_psi`` toward the GS solution.
- 3. At fixed currents, update ``trial_plasma_psi`` via NK iterations on the
- root problem in plasma flux.
- 4. At fixed plasma flux, update currents via NK iterations on the root problem
- in currents.
- 5. If either the current residuals or the GS tolerance check fail, return to step 2.
- 6. On convergence, record the solution into ``self.currents_vec``, ``self.eq1``,
- and ``self.profiles1``.
+ The nonlinear step proceeds as:
+
+ 1. Solve the linearised dynamic problem to produce initial guesses for
+ currents and plasma flux (``trial_currents``, ``trial_plasma_psi``).
+ 2. Optionally (if ``no_GS=False``), solve the GS equation and blend the
+ result with the trial flux using ``blend_GS``.
+ 3. At fixed currents, perform NK iterations on the plasma flux until the
+ GS residual meets the working tolerance.
+ 4. At fixed plasma flux, perform NK iterations on the currents until the
+ current residual meets the required tolerance.
+ 5. Repeat steps 2–4 until both the current convergence criterion and the
+ GS convergence criterion are satisfied.
+ 6. If ``relinearise_threshold`` is set and the baseline toroidal current
+ changes sufficiently, trigger a full relinearisation before continuing.
+ 7. On convergence, commit the solution to the object's internal state.
Parameters
----------
- active_voltage_vec : np.ndarray
- Vector of applied voltages on the active coils between ``t`` and ``t+dt``.
+ active_voltage_vec : array-like
+ Voltages applied to the active coils between ``t`` and ``t + dt_step``.
profiles_parameters : dict or None, optional
- If None, profile parameters remain unchanged. Otherwise, dictionary specifying
- updated parameters for the profiles object. See ``get_profiles_values`` for
- dictionary structure. This enables time-dependent profile parameters.
+ Parameters to update the profile model. If None, profiles are unchanged.
plasma_resistivity : float or array-like, optional
- Updated plasma resistivity. If None, resistivity is left unchanged. Enables time-
- dependent resistivity.
- target_relative_tol_currents : float, optional, default=0.005
- Required relative tolerance on currents for convergence of the dynamic problem.
- Computed as ``residual / (currents(t+dt) - currents(t))``.
- target_relative_tol_GS : float, optional, default=0.003
- Required relative tolerance on plasma flux for convergence of the static GS problem.
- Computed as ``residual / Δψ`` where Δψ is the flux change between timesteps.
- working_relative_tol_GS : float, optional, default=0.001
- Tolerance used when solving intermediate GS problems during the step.
- Must be stricter than ``target_relative_tol_GS``.
- target_relative_unexplained_residual : float, optional, default=0.5
- NK solver stopping criterion: inclusion of additional Krylov basis vectors
- stops once more than ``1 - target_relative_unexplained_residual`` of the residual
- is canceled.
- max_n_directions : int, optional, default=3
- Maximum number of Krylov basis vectors used in NK updates.
- step_size_psi : float, optional, default=2.0
- Step size for finite difference calculations in the NK solver applied to ψ,
- measured in units of the residual norm.
- step_size_curr : float, optional, default=0.8
- Step size for finite difference calculations in the NK solver applied to currents,
- measured in units of the residual norm.
- scaling_with_n : int, optional, default=0
- Exponent controlling step scaling in NK updates:
- candidate step is scaled by ``(1 + n_iterations)**scaling_with_n``.
- blend_GS : float, optional, default=0.5
- Blending coefficient used when updating ``trial_plasma_psi`` toward the GS solution.
- Must be in [0, 1].
- curr_eps : float, optional, default=1e-5
- Regularisation parameter for relative current convergence checks,
- preventing division by small current changes.
- max_no_NK_psi : float, optional, default=5.0
- Threshold for triggering NK updates on ψ. Activated if
- ``relative_psi_residual > max_no_NK_psi * target_relative_tol_GS``.
- clip : float, optional, default=5
- Maximum allowed step size for each accepted Krylov basis vector, in units
- of the exploratory step.
- verbose : int, optional, default=0
- Verbosity level.
- * 0: silent
- * 1: report convergence progress per NK cycle
- * 2: include detailed intermediate output
- linear_only : bool, optional, default=False
- If True, only the linearised solution is used (skipping nonlinear solves).
- max_solving_iterations : int, optional, default=50
- Maximum number of nonlinear NK cycles before the solve is terminated.
+ New plasma resistivity. If None, resistivity remains unchanged.
+ target_relative_tol_currents : float, default=0.005
+ Convergence tolerance for the dynamic current solve.
+ target_relative_tol_GS : float, default=0.003
+ Convergence tolerance for the final GS solve.
+ working_relative_tol_GS : float, default=0.001
+ Tighter intermediate tolerance for GS updates inside NK cycles.
+ target_relative_unexplained_residual : float, default=0.5
+ NK termination threshold based on unexplained residual fraction.
+ max_n_directions : int, default=3
+ Maximum Krylov basis size for each NK update.
+ step_size_psi : float, default=2.0
+ Exploratory finite-difference step size for NK updates in flux space.
+ step_size_curr : float, default=0.8
+ Exploratory step size for NK updates in current space.
+ scaling_with_n : int, default=0
+ Exponent for scaling candidate steps by ``(1 + n_iter)**scaling_with_n``.
+ blend_GS : float, default=0.5
+ Blending factor for updating trial GS solutions.
+ curr_eps : float, default=1e-5
+ Regularisation term to avoid division by small current differences.
+ max_no_NK_psi : float, default=5.0
+ Threshold for enabling NK flux updates.
+ clip : float, default=5
+ Maximum allowed NK update magnitude (per direction).
+ verbose : int, default=0
+ Verbosity level: 0 = silent, 1 = coarse, 2 = detailed.
+ linear_only : bool, default=False
+ If True, perform only the linearised update and skip nonlinear solves.
+ max_solving_iterations : int, default=50
+ Maximum allowed NK cycles.
custom_active_coil_resistances : array-like or None, optional
- If provided, overrides default active coil resistances with those specifed.
- Enables time-dependent coil resistances (can be used for switching coils "on"
- and "off").
+ Custom resistances for active coils. If provided, overrides defaults.
+ no_GS : bool, default=False
+ If True, skip all GS solves (current-only evolution). Intended mainly for
+ specialised debugging or reduced models. Works with linear_only=True only.
+ relinearise_threshold : float or list[float] or None, optional
+ If None or linear_only=False, no relinearisation occurs.
+ If no_GS=False, triggers a relinearisation when the change in toroidal current
+ since the last linearisation exceeds this relative threshold (float).
+ If no_GS=True, triggers a relinearisation when the absolute change in the descriptors
+ exceeds this threshold(s). If this threshold is scalar then the maximum absolute
+ change is considered otherwise an elementwise comparison occurs.
Notes
-----
- On convergence, the method updates internal state:
- - ``self.currents_vec`` stores the evolved currents.
- - ``self.eq1`` stores the new Grad–Shafranov equilibrium.
- - ``self.profiles1`` stores the updated profile object.
+ The solver simultaneously advances plasma equilibrium and circuit dynamics,
+ using the linearised operators built by :meth:`relinearise`. If convergence
+ is not achieved, the method stops after ``max_solving_iterations`` iterations.
Raises
------
RuntimeError
- If the nonlinear solve does not converge within ``max_solving_iterations``.
+ If the nonlinear solver fails to converge.
"""
+ # GS can only be disabled in linear-only mode
+ if no_GS and not linear_only:
+ raise ValueError(
+ "The flag 'no_GS' can only be True when 'linear_only=True'."
+ )
+
+ # evaluate whether relinearisation criterion met or not
+ relinearise = False
+ self.relinearise_criteria = 0.0
+ if relinearise_threshold is not None:
+
+ # compare relative change in plasma descriptors since last linearisation for noGS
+ if no_GS:
+ if isinstance(relinearise_threshold, list):
+ relinearise_threshold = [
+ np.inf if r is None else r for r in relinearise_threshold
+ ]
+ relinearise_threshold = np.atleast_1d(np.array(relinearise_threshold))
+
+ if len(relinearise_threshold) == 1:
+ self.relinearise_criteria = np.max(
+ np.abs(
+ self.plasma_descriptors_vec
+ - self.initial_plasma_descriptors
+ )
+ # / (np.abs(self.initial_plasma_descriptors) + 1e-16)
+ )
+ relinearise = (
+ self.relinearise_criteria >= relinearise_threshold.item()
+ )
+ else:
+ self.relinearise_criteria = np.abs(
+ self.plasma_descriptors_vec - self.initial_plasma_descriptors
+ )
+ # / (np.abs(self.initial_plasma_descriptors) + 1e-16)
+ relinearise = (
+ self.relinearise_criteria >= relinearise_threshold
+ ).any()
+
+ # compare relative change in jtor since last linearisation otherwise
+ else:
+ self.relinearise_criteria = np.linalg.norm(
+ self.profiles1.jtor - self.jtor0
+ ) / np.linalg.norm(self.jtor0)
+ relinearise = self.relinearise_criteria >= relinearise_threshold
+
+ if linear_only and relinearise:
+ print("Re-linearising around current equilibrium!")
+ # before relinearisation we need to solve GS to update the eq object and obtain new plasma descriptors
+ if no_GS:
+ self.assign_currents_solve_GS(self.trial_currents, 1e-7)
+ self.step_complete_assign(working_relative_tol_GS, from_linear=True)
+ print(
+ f" Absolute relinearisation criteria change = {np.round(self.relinearise_criteria, 3)} "
+ f"(threshold = {np.round(relinearise_threshold, 3)}) "
+ )
+ else:
+ print(
+ f" Relative relinearisation criteria change = {np.round(self.relinearise_criteria * 100, 3)}% "
+ f"(threshold = {np.round(relinearise_threshold * 100, 3)}%) "
+ )
+ self.relinearise(verbose=verbose)
+
+ # we ned to update the initial descriptors to the values at the relinearisation time
+ if no_GS:
+ self.initial_plasma_descriptors = self.plasma_descriptors_vec
+
# retrieve the old profile parameter values
self.get_profiles_values(self.profiles1)
old_params = self.profiles_parameters_vec
@@ -2747,12 +3038,12 @@ def nlstepper(
old_betas = old_params[self.profiles_beta_indices]
new_alphas = new_params[self.profiles_alpha_indices]
new_betas = new_params[self.profiles_beta_indices]
- dtheta_dt = (
+ self.dtheta_dt = (
np.concatenate((new_alphas, new_betas))
- np.concatenate((old_alphas, old_betas))
) / self.dt_step
else:
- dtheta_dt = (new_params - old_params) / self.dt_step
+ self.dtheta_dt = (new_params - old_params) / self.dt_step
# check if plasma resistivity is being evolved
# and action the change where necessary
@@ -2764,8 +3055,7 @@ def nlstepper(
# results in preparation for the nonlinear calculations
# Solution and GS equilibrium are assigned to self.trial_currents and self.trial_plasma_psi
self.set_linear_solution(
- active_voltage_vec=active_voltage_vec,
- dtheta_dt=dtheta_dt,
+ active_voltage_vec=active_voltage_vec, dtheta_dt=self.dtheta_dt, no_GS=no_GS
)
# check Matrix is still applicable
@@ -2781,6 +3071,22 @@ def nlstepper(
"domain pixels. The linearization may not be accurate.",
)
+ # when not solving GS, evolve the plasma descriptors
+ if no_GS:
+ # extract correct profile parameters
+ if self.profiles_type == "Lao85":
+ new_alphas = new_params[self.profiles_alpha_indices]
+ new_betas = new_params[self.profiles_beta_indices]
+ self.profiles_vec = np.concatenate((new_alphas, new_betas))
+ else:
+ self.profiles_vec = new_params
+
+ # solve for new descriptors
+ self.plasma_descriptors_vec = self.new_plasma_descriptors(
+ new_currents=self.currents_vec,
+ new_profiles=self.profiles_vec,
+ )
+
else:
# seek solution of the full nonlinear problem
diff --git a/freegsnke/observable_registry.py b/freegsnke/observable_registry.py
new file mode 100644
index 00000000..2642675b
--- /dev/null
+++ b/freegsnke/observable_registry.py
@@ -0,0 +1,58 @@
+from typing import Any, Callable, TypeAlias
+
+ObservableFunc: TypeAlias = Callable[[float], Any]
+
+
+class ObservableRegistry:
+ """
+ An observable registry provides an abstracted means of obtaining observables
+ regarding an equilibrium and associated parameters. The idea is that any diagnostic
+ or postprocessing logic that uses this class to obtain parameters it needs does not
+ need to know how those parameters were obtained. For example, `betap` may be
+ obtained in a fully precise way from a FreeGNSKE equilibrium, but it may also be
+ calculated from noisy diagnositcs, or predicted using an emulator.
+
+ TODO(Matthew): I don't love how this is implemented, we need to be able to make this
+ useful generally and so it should support getting across the history
+ of an equilibrium and its evolution. As such, we maybe shouldn't
+ take timestamp in `get` but rather something that better reflects
+ other uses. This would then change the Callable definition.
+ """
+
+ def __init__(self):
+ # The registry is really just a dictionary of named functions called via the
+ # ObservableRegistry.get method.
+ self._observables: dict[str, ObservableFunc] = {}
+
+ def register(self, name: str, fn: ObservableFunc):
+ """
+ Register the named observable to be computed using the provided function.
+
+ Parameters
+ ----------
+ name : str
+ The name of the observable to register.
+ fn : ObservableFunc
+ The function with which the named observable is computed.
+ """
+ self._observables[name] = fn
+
+ def has(self, name: str) -> bool:
+ """
+ Determines if an observable with the given name has been registered.
+
+ Parameters
+ ----------
+ name : str
+ The name of the observable to determine registration of.
+ """
+ return name in self._observables
+
+ def get(self, name: str, timestamp: float) -> Any | None:
+ """
+ Obtains the value of the given observable at the given timestamp. Note that
+ """
+ if name not in self._observables:
+ return None
+
+ return self._observables[name](timestamp)