diff --git a/Examples/AD_equinox_optax.ipynb b/Examples/AD_equinox_optax.ipynb new file mode 100644 index 0000000..f747335 --- /dev/null +++ b/Examples/AD_equinox_optax.ipynb @@ -0,0 +1,615 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Notebook with Optax, Equinox, and JAX" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import equinox as eqx\n", + "import functools\n", + "import jaxopt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "@functools.partial(jax.vmap, in_axes=(None, 0))\n", + "def network(params, x):\n", + " return jnp.dot(params, x)\n", + "\n", + "@jax.jit\n", + "def compute_loss(params, x, y):\n", + " y_pred = network(params, x)\n", + " loss = jnp.mean(optax.l2_loss(y_pred, y))\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" + ] + } + ], + "source": [ + "key = jax.random.PRNGKey(42)\n", + "target_params = 0.5\n", + "\n", + "# Generate some data.\n", + "xs = jax.random.normal(key, (16, 2))\n", + "ys = jnp.sum(xs * target_params, axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "start_learning_rate = 1e-1\n", + "optimizer = optax.adam(start_learning_rate)\n", + "\n", + "# Initialize parameters of the model + optimizer.\n", + "init_params = jnp.array([0.0, 0.0])\n", + "opt_state = optimizer.init(init_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss: 0.3705595135688782\n", + "Loss: 0.237158864736557\n", + "Loss: 0.13446101546287537\n" + ] + } + ], + "source": [ + "# A simple update loop.\n", + "params = init_params\n", + "for i in range(3):\n", + " loss_value, grads = jax.value_and_grad(compute_loss)(params, xs, ys)\n", + " print(f\"Loss: {loss_value}\")\n", + " if loss_value < 1e-15:\n", + " break\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.29512683, 0.2951268 ], dtype=float32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduce Equinox and a Model class" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class Model(eqx.Module):\n", + " xs: jnp.ndarray\n", + " ys: jnp.ndarray\n", + "\n", + " def __init__(self, seed: int=42, target_params: float=0.5):\n", + " key = jax.random.PRNGKey(seed)\n", + " # target_params is taken from the constructor argument\n", + "\n", + " # Generate some 
data.\n", + " self.xs = jax.random.normal(key, (16, 2))\n", + " self.ys = jnp.sum(self.xs * target_params, axis=-1)\n", + "\n", + " def model_network(self, params):\n", + " return network(params, self.xs)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def loss_function(params, model):\n", + " y_pred = model.model_network(params)\n", + " y_target = model.ys\n", + " loss = jnp.mean(optax.l2_loss(y_pred, y_target))\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "model = Model()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(16,)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.ys.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## JAXOPT LBFGS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:1.0 Decrease Error:0.0 Curvature Error:0.0 \n", + "INFO: jaxopt.LBFGS: Iter: 1 Gradient Norm (stop. crit.): 0.7238786816596985 Objective Value:0.14937657117843628 Stepsize:1.0 Number Linesearch Iterations:1 \n", + "INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:1.0 Decrease Error:0.0 Curvature Error:0.0 \n", + "INFO: jaxopt.LBFGS: Iter: 2 Gradient Norm (stop. crit.): 0.052312254905700684 Objective Value:0.0011559441918507218 Stepsize:1.0 Number Linesearch Iterations:1 \n", + "INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. 
crit.): 0.0 Stepsize:1.0 Decrease Error:0.0 Curvature Error:0.0 \n", + "INFO: jaxopt.LBFGS: Iter: 3 Gradient Norm (stop. crit.): 0.014985358342528343 Objective Value:0.00010153277253266424 Stepsize:1.0 Number Linesearch Iterations:1 \n", + "INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:1.0 Decrease Error:0.0 Curvature Error:0.0 \n", + "INFO: jaxopt.LBFGS: Iter: 4 Gradient Norm (stop. crit.): 2.0062932890141383e-05 Objective Value:1.2456524700610316e-10 Stepsize:1.0 Number Linesearch Iterations:1 \n" + ] + }, + { + "data": { + "text/plain": [ + "OptStep(params=Array([0.4999901 , 0.49999213], dtype=float32), state=LbfgsState(iter_num=Array(4, dtype=int32, weak_type=True), value=Array(1.2456525e-10, dtype=float32), grad=Array([-1.7826915e-05, -9.2044775e-06], dtype=float32), stepsize=Array(1., dtype=float32), error=Array(2.0062933e-05, dtype=float32), s_history=Array([[ 0.9058708 , 0.5763672 ],\n", + " [-0.38845462, -0.11766708],\n", + " [-0.01448965, 0.02802739],\n", + " [-0.00293642, 0.01326463],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ]], dtype=float32), y_history=Array([[ 1.6217207 , 0.68388146],\n", + " [-0.6877681 , -0.15165024],\n", + " [-0.02373748, 0.02979416],\n", + " [-0.00436213, 0.01433262],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. ],\n", + " [ 0. , 0. 
]], dtype=float32), rho_history=Array([5.3670061e-01, 3.5086374e+00, 8.4817627e+02, 4.9279062e+03,\n", + " 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n", + " 0.0000000e+00, 0.0000000e+00], dtype=float32), gamma=Array(0.90409404, dtype=float32), aux=None, failed_linesearch=Array(False, dtype=bool), num_fun_eval=Array(9, dtype=int32), num_grad_eval=Array(9, dtype=int32), num_linesearch_iter=Array(4, dtype=int32)))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params = init_params\n", + "solver_2 = jaxopt.LBFGS(fun=loss_function, maxiter = 100, verbose=True)\n", + "solver_2.run(init_params=init_params, model=model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss 0: 0.3705595135688782\n", + "Loss 1: 0.12533925473690033\n", + "Loss 2: 0.0013298687990754843\n", + "Loss 3: 0.06173867732286453\n", + "Loss 4: 0.17377790808677673\n", + "Loss 5: 0.19612154364585876\n", + "Loss 6: 0.1262005716562271\n", + "Loss 7: 0.03880932182073593\n", + "Loss 8: 0.00013180097448639572\n", + "Loss 9: 0.02495267614722252\n", + "Loss 10: 0.07474566251039505\n", + "Loss 11: 0.09856395423412323\n", + "Loss 12: 0.07876671850681305\n", + "Loss 13: 0.03622499853372574\n", + "Loss 14: 0.004622247070074081\n", + "Loss 15: 0.0033412498887628317\n", + "Loss 16: 0.025381486862897873\n", + "Loss 17: 0.04659378528594971\n", + "Loss 18: 0.04799724742770195\n", + "Loss 19: 0.030167732387781143\n", + "Loss 20: 0.00887465849518776\n", + "Loss 21: 8.324929012815119e-07\n", + "Loss 22: 0.0071021090261638165\n", + "Loss 23: 0.02030024118721485\n", + "Loss 24: 0.0263848677277565\n", + "Loss 25: 0.020409852266311646\n", + "Loss 26: 0.008420398458838463\n", + "Loss 27: 0.0005593497771769762\n", + "Loss 28: 
0.0019398077856749296\n", + "Loss 29: 0.00903375819325447\n", + "Loss 30: 0.013975651003420353\n", + "Loss 31: 0.012135028839111328\n", + "Loss 32: 0.005696813575923443\n", + "Loss 33: 0.0006245278054848313\n", + "Loss 34: 0.0007463011424988508\n", + "Loss 35: 0.004612915217876434\n", + "Loss 36: 0.007606457453221083\n", + "Loss 37: 0.006707998923957348\n", + "Loss 38: 0.0030708550475537777\n", + "Loss 39: 0.00027308432618156075\n", + "Loss 40: 0.0005436749197542667\n", + "Loss 41: 0.002829235512763262\n", + "Loss 42: 0.004314153455197811\n", + "Loss 43: 0.003452245146036148\n", + "Loss 44: 0.0012931658420711756\n", + "Loss 45: 2.544407470850274e-05\n", + "Loss 46: 0.000600383267737925\n", + "Loss 47: 0.00195945892482996\n", + "Loss 48: 0.0024300753138959408\n", + "Loss 49: 0.0015480443835258484\n", + "Loss 50: 0.000346560962498188\n", + "Loss 51: 3.357574314577505e-05\n", + "Loss 52: 0.0006870031356811523\n", + "Loss 53: 0.0013502087676897645\n", + "Loss 54: 0.001224467996507883\n", + "Loss 55: 0.0005078407702967525\n", + "Loss 56: 1.73994449141901e-05\n", + "Loss 57: 0.00019082888320554048\n", + "Loss 58: 0.0006623659282922745\n", + "Loss 59: 0.000800388224888593\n", + "Loss 60: 0.0004602676199283451\n", + "Loss 61: 6.813267100369558e-05\n", + "Loss 62: 4.137941141379997e-05\n", + "Loss 63: 0.00031056831358000636\n", + "Loss 64: 0.00047908653505146503\n", + "Loss 65: 0.0003383336297702044\n", + "Loss 66: 8.088175673037767e-05\n", + "Loss 67: 6.7552518885349855e-06\n", + "Loss 68: 0.00014920612738933414\n", + "Loss 69: 0.00027894083177670836\n", + "Loss 70: 0.00022330899082589895\n", + "Loss 71: 6.565159128513187e-05\n", + "Loss 72: 9.255417126041721e-07\n", + "Loss 73: 7.803384505677968e-05\n", + "Loss 74: 0.00016383372712880373\n", + "Loss 75: 0.0001386395888403058\n", + "Loss 76: 4.3154672312084585e-05\n", + "Loss 77: 2.999283310600731e-07\n", + "Loss 78: 4.614111821865663e-05\n", + "Loss 79: 9.868330380413681e-05\n", + "Loss 80: 8.239800808951259e-05\n", + 
"Loss 81: 2.3992310161702335e-05\n", + "Loss 82: 5.38855999820953e-07\n", + "Loss 83: 3.086286596953869e-05\n", + "Loss 84: 6.0833353927591816e-05\n", + "Loss 85: 4.6722678234800696e-05\n", + "Loss 86: 1.103203248931095e-05\n", + "Loss 87: 1.3235082860774128e-06\n", + "Loss 88: 2.2550448193214834e-05\n", + "Loss 89: 3.768577516893856e-05\n", + "Loss 90: 2.4701799702597782e-05\n", + "Loss 91: 3.7494082789635286e-06\n", + "Loss 92: 2.5068302420550026e-06\n", + "Loss 93: 1.7073180060833693e-05\n", + "Loss 94: 2.2736299797543325e-05\n", + "Loss 95: 1.1604023711697664e-05\n", + "Loss 96: 6.126589369159774e-07\n", + "Loss 97: 3.6579083371179877e-06\n", + "Loss 98: 1.2684155080933124e-05\n", + "Loss 99: 1.2782886187778786e-05\n", + "Loss 100: 4.395686119096354e-06\n", + "Loss 101: 1.4398995773490242e-08\n", + "Loss 102: 4.303292371332645e-06\n", + "Loss 103: 8.784206329437438e-06\n", + "Loss 104: 6.284868959482992e-06\n", + "Loss 105: 1.043906195263844e-06\n", + "Loss 106: 5.784594918623043e-07\n", + "Loss 107: 4.177700247964822e-06\n", + "Loss 108: 5.372005944082048e-06\n", + "Loss 109: 2.417846644675592e-06\n", + "Loss 110: 3.3921040198947594e-08\n", + "Loss 111: 1.291252146984334e-06\n", + "Loss 112: 3.3414621611882467e-06\n", + "Loss 113: 2.694527893254417e-06\n", + "Loss 114: 5.485443921315891e-07\n", + "Loss 115: 1.6461275720303092e-07\n", + "Loss 116: 1.5862076452322071e-06\n", + "Loss 117: 2.1367354747781064e-06\n", + "Loss 118: 9.619462844057125e-07\n", + "Loss 119: 9.73589386887852e-09\n", + "Loss 120: 5.533798344004026e-07\n", + "Loss 121: 1.3530473097489448e-06\n", + "Loss 122: 1.0095428706335952e-06\n", + "Loss 123: 1.5624914340151008e-07\n", + "Loss 124: 1.1541985145413491e-07\n", + "Loss 125: 7.141218247852521e-07\n", + "Loss 126: 8.183931186067639e-07\n", + "Loss 127: 2.8267874085941003e-07\n", + "Loss 128: 3.273285198446274e-09\n", + "Loss 129: 3.160826054227073e-07\n", + "Loss 130: 5.62515992896806e-07\n", + "Loss 131: 3.1368051622848725e-07\n", + "Loss 
132: 1.411469607859317e-08\n", + "Loss 133: 1.1377032649306784e-07\n", + "Loss 134: 3.4310966157136136e-07\n", + "Loss 135: 2.7551737957765e-07\n", + "Loss 136: 4.547453258396672e-08\n", + "Loss 137: 2.991617620295983e-08\n", + "Loss 138: 1.9068534129473846e-07\n", + "Loss 139: 2.1063939925625164e-07\n", + "Loss 140: 6.309591782382995e-08\n", + "Loss 141: 3.804345727331793e-09\n", + "Loss 142: 9.825419056141982e-08\n", + "Loss 143: 1.4709161177961505e-07\n", + "Loss 144: 6.428952303849655e-08\n", + "Loss 145: 6.239380540007389e-11\n", + "Loss 146: 4.7505857025953446e-08\n", + "Loss 147: 9.6699508844722e-08\n", + "Loss 148: 5.569851424525041e-08\n", + "Loss 149: 2.2289132761699193e-09\n", + "Loss 150: 2.1790121706999344e-08\n", + "Loss 151: 6.116655981713848e-08\n", + "Loss 152: 4.3732992338618715e-08\n", + "Loss 153: 4.335954706391476e-09\n", + "Loss 154: 9.596959671398508e-09\n", + "Loss 155: 3.785125102240272e-08\n", + "Loss 156: 3.222789146661853e-08\n", + "Loss 157: 5.0790056604910205e-09\n", + "Loss 158: 4.125412900179981e-09\n", + "Loss 159: 2.325440462414008e-08\n", + "Loss 160: 2.277997523947306e-08\n", + "Loss 161: 4.739058034886057e-09\n", + "Loss 162: 1.7799456353273513e-09\n", + "Loss 163: 1.4346271726140003e-08\n", + "Loss 164: 1.566213292392149e-08\n", + "Loss 165: 3.874447873641884e-09\n", + "Loss 166: 8.056747491380634e-10\n", + "Loss 167: 8.978004117921046e-09\n", + "Loss 168: 1.0580302856055823e-08\n", + "Loss 169: 2.8876094848584444e-09\n", + "Loss 170: 4.070161430114183e-10\n", + "Loss 171: 5.740678155063961e-09\n", + "Loss 172: 7.059888673666137e-09\n", + "Loss 173: 2.0058337213413324e-09\n", + "Loss 174: 2.43222553208966e-10\n", + "Loss 175: 3.767480993843719e-09\n", + "Loss 176: 4.665432484785015e-09\n", + "Loss 177: 1.3064589410305416e-09\n", + "Loss 178: 1.7626082537969268e-10\n", + "Loss 179: 2.5374908840802846e-09\n", + "Loss 180: 3.056962238900951e-09\n", + "Loss 181: 7.975731741716174e-10\n", + "Loss 182: 1.5096457417484999e-10\n", + 
"Loss 183: 1.7494150572616718e-09\n", + "Loss 184: 1.9783685800689454e-09\n", + "Loss 185: 4.5333042875128626e-10\n", + "Loss 186: 1.4345685750427606e-10\n", + "Loss 187: 1.225659129744372e-09\n", + "Loss 188: 1.2603058596738492e-09\n", + "Loss 189: 2.3224130951682298e-10\n", + "Loss 190: 1.4217682586803448e-10\n", + "Loss 191: 8.68282390431574e-10\n", + "Loss 192: 7.83208831123261e-10\n", + "Loss 193: 1.0280110790406027e-10\n", + "Loss 194: 1.4245790658229396e-10\n", + "Loss 195: 6.144409425701269e-10\n", + "Loss 196: 4.676999343367072e-10\n", + "Loss 197: 3.553232708064513e-11\n", + "Loss 198: 1.3867462733685443e-10\n", + "Loss 199: 4.2995784710342377e-10\n", + "Loss 200: 2.6711752254549026e-10\n", + "Loss 201: 6.589541412527211e-12\n", + "Loss 202: 1.305643648752408e-10\n", + "Loss 203: 2.930397313694044e-10\n", + "Loss 204: 1.4090817401779532e-10\n", + "Loss 205: 9.174952464441333e-15\n", + "Loss 206: 1.1694259449690492e-10\n", + "Loss 207: 1.9275579743460725e-10\n", + "Loss 208: 6.620887171848722e-11\n", + "Loss 209: 4.214469051522229e-12\n", + "Loss 210: 9.977262072080606e-11\n", + "Loss 211: 1.2057402298815134e-10\n", + "Loss 212: 2.587291665634428e-11\n", + "Loss 213: 1.1754161879928837e-11\n", + "Loss 214: 8.035444531984126e-11\n", + "Loss 215: 6.989845507954229e-11\n", + "Loss 216: 6.716434700071439e-12\n", + "Loss 217: 1.825243428621448e-11\n", + "Loss 218: 6.076163877599683e-11\n", + "Loss 219: 3.612675436581725e-11\n", + "Loss 220: 3.573200763051787e-13\n", + "Loss 221: 2.1771079730670273e-11\n", + "Loss 222: 4.134578790448984e-11\n", + "Loss 223: 1.546416467101963e-11\n", + "Loss 224: 7.056299988761339e-13\n", + "Loss 225: 2.1127920593611016e-11\n", + "Loss 226: 2.550295565006966e-11\n", + "Loss 227: 4.966346778267905e-12\n", + "Loss 228: 3.169471629593801e-12\n", + "Loss 229: 1.8025735418203404e-11\n", + "Loss 230: 1.388602600960187e-11\n", + "Loss 231: 6.704376637189924e-13\n", + "Loss 232: 5.362476088177637e-12\n", + "Loss 233: 
1.3374122989628923e-11\n", + "Loss 234: 5.995551277671041e-12\n", + "Loss 235: 6.05696048872062e-14\n", + "Loss 236: 6.1493379832633366e-12\n", + "Loss 237: 8.423458111583848e-12\n", + "Loss 238: 1.7896031878628094e-12\n", + "Loss 239: 8.984219618257683e-13\n", + "Loss 240: 5.673635172787073e-12\n", + "Loss 241: 4.373982079308725e-12\n", + "Loss 242: 2.0531666644618696e-13\n", + "Loss 243: 1.8369264442874567e-12\n", + "Loss 244: 4.279916698823882e-12\n", + "Loss 245: 1.7418445158456919e-12\n", + "Loss 246: 6.807228392080589e-14\n", + "Loss 247: 2.2400987315096543e-12\n", + "Loss 248: 2.5243002133024106e-12\n", + "Loss 249: 3.573200763051787e-13\n", + "Loss 250: 4.722628538234019e-13\n", + "Loss 251: 1.919036110575867e-12\n", + "Loss 252: 1.1565887136910646e-12\n", + "Loss 253: 6.510417205340957e-15\n", + "Loss 254: 8.289705727415608e-13\n", + "Loss 255: 1.3337542875691e-12\n", + "Loss 256: 3.0203617384927384e-13\n", + "Loss 257: 1.310496849926679e-13\n", + "Loss 258: 8.511680943401601e-13\n", + "Loss 259: 6.271719255046548e-13\n", + "Loss 260: 2.3420501649162873e-14\n", + "Loss 261: 2.936557247368299e-13\n", + "Loss 262: 6.290679782638975e-13\n", + "Loss 263: 2.0531666644618696e-13\n", + "Loss 264: 2.9145089119886336e-14\n", + "Loss 265: 3.7457016655029207e-13\n", + "Loss 266: 3.19973214590874e-13\n", + "Loss 267: 2.3531523951625388e-14\n", + "Loss 268: 1.0997452948302566e-13\n", + "Loss 269: 2.936557247368299e-13\n", + "Loss 270: 1.28050348102704e-13\n", + "Loss 271: 1.0354564428105562e-14\n", + "Loss 272: 1.527181159310942e-13\n", + "Loss 273: 1.6880420672382712e-13\n", + "Loss 274: 2.070045523883124e-14\n", + "Loss 275: 5.701168703797777e-14\n", + "Loss 276: 1.5941588327184064e-13\n", + "Loss 277: 5.652596446470426e-14\n", + "Loss 278: 1.0354564428105562e-14\n", + "Loss 279: 8.152506447700603e-14\n", + "Loss 280: 7.380207556195728e-14\n", + "Loss 281: 4.865899350114944e-15\n", + "Loss 282: 3.3209546224099995e-14\n", + "Loss 283: 6.807228392080589e-14\n", + "Loss 
284: 1.814694228219338e-14\n", + "Loss 285: 4.274358644806853e-15\n", + "Loss 286: 3.489743216622543e-14\n", + "Loss 287: 2.9145089119886336e-14\n", + "Loss 288: 6.123573870198129e-16\n" + ] + } + ], + "source": [ + "# A simple update loop.\n", + "params = init_params\n", + "for i in range(1000):\n", + " loss_value, grads = jax.value_and_grad(loss_function)(params, model)\n", + " print(f\"Loss {i}: {loss_value}\")\n", + " if loss_value < 1e-15:\n", + " break\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "quocsproj3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Examples/Ising_AD_Equinox.ipynb b/Examples/Ising_AD_Equinox.ipynb new file mode 100644 index 0000000..10169a9 --- /dev/null +++ b/Examples/Ising_AD_Equinox.ipynb @@ -0,0 +1,873 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import equinox as eqx\n", + "import functools\n", + "import jaxopt\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "256" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.power(2, 8)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": 
[], + "source": [ + "from jax.scipy.linalg import sqrtm\n", + "from quocslib.tools.fidelities import fidelity_AD\n", + "from quocslib.timeevolution.piecewise_integrator_AD import pw_final_evolution_AD_scan, pw_final_evolution_AD\n", + "\n", + "\n", + "class IsingModel(eqx.Module):\n", + "\n", + " n_qubits: int\n", + " J: float\n", + " g: float\n", + " n_slices: int\n", + " H_drift: jnp.ndarray\n", + " H_control: jnp.ndarray\n", + " rho_0: jnp.ndarray\n", + " sqrt_rho_target: jnp.ndarray\n", + " rho_target: jnp.ndarray\n", + " rho_final: jnp.ndarray\n", + " u0: jnp.array\n", + " # _pw_evolution_transform: callable\n", + "\n", + "\n", + " def __init__(self, args_dict: dict = None):\n", + " if args_dict is None:\n", + " args_dict = {}\n", + " ################################################################################################################\n", + " # Dynamics variables\n", + " ################################################################################################################\n", + " self.n_qubits = args_dict.setdefault(\"n_qubits\", 5)\n", + " self.J = args_dict.setdefault(\"J\", 1)\n", + " self.g = args_dict.setdefault(\"g\", 2)\n", + " self.n_slices = args_dict.setdefault(\"n_slices\", 100)\n", + "\n", + " self.H_drift = jnp.asarray(get_static_hamiltonian(self.n_qubits, self.J, self.g))\n", + " self.H_control = jnp.asarray(get_control_hamiltonian(self.n_qubits))\n", + " self.rho_0 = jnp.asarray(get_initial_state(self.n_qubits))\n", + " self.rho_target = jnp.asarray(get_target_state(self.n_qubits))\n", + " self.rho_final = jnp.asarray(jnp.zeros_like(self.rho_target))\n", + " self.sqrt_rho_target = sqrtm(self.rho_target)\n", + " self.u0 = jnp.identity(2 ** self.n_qubits, dtype=np.complex128)\n", + "\n", + " # Let JAX know to jit the following function\n", + " @jax.jit\n", + " def _pw_evolution_transform(drive, dt):\n", + " \"\"\"\n", + " A wrapper function for the piecewise evolution function of QuOCS\n", + " :param drive: list of drive 
pulses\n", + " :param dt: time step\n", + " :return: final unitary propagator\n", + " \"\"\"\n", + " return pw_final_evolution_AD(drive,\n", + " self.H_drift,\n", + " jnp.asarray([self.H_control]),\n", + " self.n_slices,\n", + " dt,\n", + " jnp.identity(2 ** self.n_qubits, dtype=np.complex128))\n", + "\n", + " # self._pw_evolution_transform = _pw_evolution_transform\n", + "\n", + " def get_control_Hamiltonians(self):\n", + " return self.H_control\n", + "\n", + " def get_drift_Hamiltonian(self):\n", + " return self.H_drift\n", + "\n", + " def get_target_state(self):\n", + " return self.rho_target\n", + "\n", + " def get_initial_state(self):\n", + " return self.rho_0\n", + "\n", + " def get_propagator(self,\n", + " pulse: jnp.ndarray,\n", + " timegrid: jnp.ndarray) -> jnp.ndarray:\n", + " \"\"\"\n", + " Function to calculate the propagator from the pulses, parameters and timegrids.\n", + " :param pulses_list:\n", + " :param time_grids_list:\n", + " :param parameters_list:\n", + " :return: final propagator\n", + " \"\"\"\n", + "\n", + " # drive = pulses_list[0, :].reshape(1, len(pulses_list[0, :]))\n", + " drive = pulse.reshape(1, len(pulse))\n", + " # time_grid = time_grids_list[0, :]\n", + " dt = timegrid[-1] / len(timegrid)\n", + "\n", + " # Compute the time evolution\n", + " # propagator = self._pw_evolution_transform(drive, dt)\n", + " propagator = pw_final_evolution_AD_scan(drive,\n", + " self.H_drift,\n", + " jnp.asarray([self.H_control]),\n", + " self.n_slices,\n", + " dt,\n", + " self.u0)\n", + "\n", + " return propagator\n", + "\n", + " def get_final_state(self,\n", + " pulse: jnp.ndarray,\n", + " timegrid: jnp.ndarray) -> jnp.array:\n", + " \"\"\"\n", + " Function to calculate the final state from the pulse.\n", + " :param pulses: jnp.arrays of the pulses to be optimized.\n", + " :param timegrids: jnp.arrays of the timegrids connected to the pulses.\n", + " :param parameters: jnp.array of the parameters to be optimized.\n", + " :return dict: The figure of 
merit in a dictionary\n", + " \"\"\"\n", + " U_final = self.get_propagator(pulse=pulse, timegrid=timegrid)\n", + " # print(U_final)\n", + " rho_final = U_final @ self.rho_0 @ U_final.T.conj()\n", + " return rho_final\n", + " # fidelity = fom_funct(rho_final, self.rho_target)\n", + " # return {\"FoM\": fidelity}\n", + "\n", + "\n", + "i2 = np.eye(2)\n", + "sz = 0.5 * np.array([[1, 0], [0, -1]], dtype=np.complex128)\n", + "sx = 0.5 * np.array([[0, 1], [1, 0]], dtype=np.complex128)\n", + "psi0 = np.array([[1, 0], [0, 0]], dtype=np.complex128)\n", + "psiT = np.array([[0, 0], [0, 1]], dtype=np.complex128)\n", + "\n", + "\n", + "def tensor_together(A):\n", + " res = np.kron(A[0], A[1])\n", + " if len(A) > 2:\n", + " for two in A[2:]:\n", + " res = np.kron(res, two)\n", + " else:\n", + " res = res\n", + " return res\n", + "\n", + "\n", + "def get_static_hamiltonian(nqu, J, g):\n", + " dim = 2**nqu\n", + " H0 = np.zeros((dim, dim), dtype=np.complex128)\n", + " for j in range(nqu):\n", + " # set up holding array\n", + " rest = [i2] * nqu\n", + " # set the correct elements to sz\n", + " # check, so we can implement a loop around\n", + " if j == nqu - 1:\n", + " idx1 = j\n", + " idx2 = 0\n", + " else:\n", + " idx1 = j\n", + " idx2 = j + 1\n", + " rest[idx1] = sz\n", + " rest[idx2] = sz\n", + " H0 = H0 - J * tensor_together(rest)\n", + "\n", + " for j in range(nqu):\n", + " # set up holding array\n", + " rest = [i2] * nqu\n", + " # set the correct elements to sz\n", + " # check, so we can implement a loop around\n", + " if j == nqu - 1:\n", + " idx1 = j\n", + " idx2 = 1\n", + " elif j == nqu - 2:\n", + " idx1 = j\n", + " idx2 = 0\n", + " else:\n", + " idx1 = j\n", + " idx2 = j + 2\n", + " rest[idx1] = sz\n", + " rest[idx2] = sz\n", + " H0 = H0 - g * tensor_together(rest)\n", + " return H0\n", + "\n", + "\n", + "def get_control_hamiltonian(nqu: int):\n", + " dim = 2**nqu\n", + " H_at_t = np.zeros((dim, dim), dtype=np.complex128)\n", + " for j in range(nqu):\n", + " # set up 
holding array\n", + " rest = [i2] * nqu\n", + " # set the correct elements to sx\n", + " rest[j] = sx\n", + " H_at_t = H_at_t + tensor_together(rest)\n", + " return H_at_t\n", + "\n", + "\n", + "def get_initial_state(nqu: int):\n", + " state = [psi0] * nqu\n", + " return tensor_together(state)\n", + "\n", + "\n", + "def get_target_state(nqu: int):\n", + " state = [psiT] * nqu\n", + " return tensor_together(state)\n", + "\n", + "\n", + "@jax.jit\n", + "def fom_funct(rho_evolved, sqrt_rho_aim):\n", + " \"\"\"\n", + " Function to calculate the overlap between two density matrices.\n", + " :param rho_evolved:\n", + " :param rho_aim:\n", + " :return: overlap fidelity\n", + " \"\"\"\n", + " fidelity = fidelity_AD(sqrt_rho_aim, rho_evolved)\n", + " return fidelity" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "\n", + "# @partial(jax.jit, static_argnums=2)\n", + "@jax.jit\n", + "def loss_function(params, model, timegrid):\n", + " final_state = model.get_final_state(pulse=params, timegrid=timegrid)\n", + " loss = 1.0 - fom_funct(final_state, model.sqrt_rho_target)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" + ] + } + ], + "source": [ + "model = IsingModel(args_dict={\"n_qubits\": 5, \"J\": 1, \"g\": 2, \"n_slices\": 100})" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Initial parameters\n", + "timegrid = jnp.linspace(0.0, 1.0, model.n_slices)\n", + "key = jax.random.PRNGKey(50)\n", + "init_params = jax.random.normal(key, (model.n_slices,), dtype=jnp.float64) * 30" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# # Unconstrained optimization\n", + "# params = init_params\n", + "# solver_unconstrained = jaxopt.LBFGS(fun=loss_function, maxiter = 100, verbose=True)\n", + "# result_unconstrained = solver_unconstrained.run(params, model, timegrid)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# loss_function(result_unconstrained.params, model, timegrid)\n", + "# result_unconstrained.params" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUNNING THE L-BFGS-B CODE\n", + "\n", + " * * *\n", + "\n", + "Machine precision = 2.220D-16\n", + " N = 100 M = 10\n", + "\n", + "At X0 0 variables are exactly at the bounds\n", + "\n", + "At iterate 0 f= 9.07539D-01 |proj g|= 2.91678D-03\n", + "\n", + "At iterate 1 f= 9.07217D-01 |proj g|= 2.92581D-03\n", + " ys=-1.153E-06 -gs= 3.213E-04 BFGS update SKIPPED\n", + "\n", + "At iterate 2 f= 2.13250D-01 |proj g|= 2.46702D-03\n", + "\n", + "At iterate 3 f= 1.79902D-01 |proj g|= 3.62414D-03\n", + "\n", + "At iterate 4 f= 1.53983D-01 |proj g|= 2.13901D-03\n", + "\n", + "At iterate 5 f= 1.41952D-01 |proj g|= 1.46240D-03\n", + "\n", + "At iterate 6 f= 1.24540D-01 |proj g|= 2.40044D-03\n", + "\n", + "At iterate 7 f= 1.06012D-01 |proj g|= 2.93769D-03\n", + "\n", + "At iterate 8 f= 8.99431D-02 |proj g|= 2.42906D-03\n", + 
"\n", + "At iterate 9 f= 6.83901D-02 |proj g|= 1.30597D-03\n", + "\n", + "At iterate 10 f= 5.59505D-02 |proj g|= 5.16569D-04\n", + "\n", + "At iterate 11 f= 5.12111D-02 |proj g|= 3.01676D-04\n", + "\n", + "At iterate 12 f= 4.92608D-02 |proj g|= 3.08979D-04\n", + "\n", + "At iterate 13 f= 4.82851D-02 |proj g|= 2.61103D-04\n", + "\n", + "At iterate 14 f= 4.65644D-02 |proj g|= 2.61050D-04\n", + "\n", + "At iterate 15 f= 4.15832D-02 |proj g|= 5.38849D-04\n", + "\n", + "At iterate 16 f= 3.39687D-02 |proj g|= 5.21037D-04\n", + "\n", + "At iterate 17 f= 2.48523D-02 |proj g|= 6.53287D-04\n", + "\n", + "At iterate 18 f= 1.32613D-02 |proj g|= 5.95343D-04\n", + "\n", + "At iterate 19 f= 1.23235D-02 |proj g|= 4.54825D-04\n", + "\n", + "At iterate 20 f= 1.20367D-02 |proj g|= 5.32173D-04\n", + "\n", + "At iterate 21 f= 1.14707D-02 |proj g|= 1.95866D-04\n", + "\n", + "At iterate 22 f= 1.13951D-02 |proj g|= 1.27256D-04\n", + "\n", + "At iterate 23 f= 1.12615D-02 |proj g|= 1.41486D-04\n", + "\n", + "At iterate 24 f= 1.08380D-02 |proj g|= 1.73639D-04\n", + "\n", + "At iterate 25 f= 9.94562D-03 |proj g|= 2.32858D-04\n", + "\n", + "At iterate 26 f= 7.97085D-03 |proj g|= 2.50614D-04\n", + "\n", + "At iterate 27 f= 7.37679D-03 |proj g|= 5.02153D-04\n", + "\n", + "At iterate 28 f= 6.00745D-03 |proj g|= 2.66585D-04\n", + "\n", + "At iterate 29 f= 5.81270D-03 |proj g|= 5.68141D-05\n", + "\n", + "At iterate 30 f= 5.78955D-03 |proj g|= 5.39699D-05\n", + "\n", + "At iterate 31 f= 5.55432D-03 |proj g|= 6.18602D-05\n", + "\n", + "At iterate 32 f= 4.93623D-03 |proj g|= 1.53044D-04\n", + "\n", + "At iterate 33 f= 4.47174D-03 |proj g|= 1.43372D-04\n", + "\n", + "At iterate 34 f= 3.95106D-03 |proj g|= 5.99204D-05\n", + "\n", + "At iterate 35 f= 3.65240D-03 |proj g|= 8.55730D-05\n", + "\n", + "At iterate 36 f= 3.51046D-03 |proj g|= 9.09376D-05\n", + "\n", + "At iterate 37 f= 3.17293D-03 |proj g|= 7.96691D-05\n", + "\n", + "At iterate 38 f= 2.88185D-03 |proj g|= 4.97541D-05\n", + "\n", + "At iterate 
39 f= 2.79225D-03 |proj g|= 2.54529D-04\n", + "\n", + "At iterate 40 f= 2.62712D-03 |proj g|= 7.60501D-05\n", + "\n", + "At iterate 41 f= 2.56534D-03 |proj g|= 4.05593D-05\n", + "\n", + "At iterate 42 f= 2.54707D-03 |proj g|= 4.55404D-05\n", + "\n", + "At iterate 43 f= 2.54140D-03 |proj g|= 4.51762D-05\n", + "\n", + "At iterate 44 f= 2.53075D-03 |proj g|= 2.30345D-05\n", + "\n", + "At iterate 45 f= 2.50801D-03 |proj g|= 4.38706D-05\n", + "\n", + "At iterate 46 f= 2.47093D-03 |proj g|= 8.42758D-05\n", + "\n", + "At iterate 47 f= 2.41582D-03 |proj g|= 1.12465D-04\n", + "\n", + "At iterate 48 f= 2.33527D-03 |proj g|= 1.26286D-04\n", + "\n", + "At iterate 49 f= 2.27944D-03 |proj g|= 7.87102D-05\n", + "\n", + "At iterate 50 f= 2.22975D-03 |proj g|= 3.42571D-05\n", + "\n", + "At iterate 51 f= 2.20604D-03 |proj g|= 4.14458D-05\n", + "\n", + "At iterate 52 f= 2.19903D-03 |proj g|= 4.20987D-05\n", + "\n", + "At iterate 53 f= 2.19335D-03 |proj g|= 4.61738D-05\n", + "\n", + "At iterate 54 f= 2.18186D-03 |proj g|= 3.17922D-05\n", + "\n", + "At iterate 55 f= 2.17108D-03 |proj g|= 5.92841D-05\n", + "\n", + "At iterate 56 f= 2.16157D-03 |proj g|= 4.70063D-05\n", + "\n", + "At iterate 57 f= 2.14720D-03 |proj g|= 3.71802D-05\n", + "\n", + "At iterate 58 f= 2.11371D-03 |proj g|= 4.09236D-05\n", + "\n", + "At iterate 59 f= 2.06090D-03 |proj g|= 3.11918D-05\n", + "\n", + "At iterate 60 f= 2.03666D-03 |proj g|= 8.80935D-05\n", + "\n", + "At iterate 61 f= 1.99258D-03 |proj g|= 4.31227D-05\n", + "\n", + "At iterate 62 f= 1.96272D-03 |proj g|= 2.13248D-05\n", + "\n", + "At iterate 63 f= 1.95962D-03 |proj g|= 1.87338D-05\n", + "\n", + "At iterate 64 f= 1.94990D-03 |proj g|= 3.15452D-05\n", + "\n", + "At iterate 65 f= 1.94151D-03 |proj g|= 4.33426D-05\n", + "\n", + "At iterate 66 f= 1.93140D-03 |proj g|= 2.09992D-05\n", + "\n", + "At iterate 67 f= 1.91590D-03 |proj g|= 2.51751D-05\n", + "\n", + "At iterate 68 f= 1.89012D-03 |proj g|= 5.29816D-05\n", + "\n", + "At iterate 69 f= 1.85613D-03 
|proj g|= 3.56835D-05\n", + "\n", + "At iterate 70 f= 1.81603D-03 |proj g|= 1.08870D-04\n", + "\n", + "At iterate 71 f= 1.74239D-03 |proj g|= 5.92094D-05\n", + "\n", + "At iterate 72 f= 1.56365D-03 |proj g|= 1.30912D-04\n", + "\n", + "At iterate 73 f= 1.47260D-03 |proj g|= 1.01277D-04\n", + "\n", + "At iterate 74 f= 1.45400D-03 |proj g|= 1.17827D-04\n", + "\n", + "At iterate 75 f= 1.41175D-03 |proj g|= 3.93404D-05\n", + " Positive dir derivative in projection \n", + " Using the backtracking step \n", + "\n", + "At iterate 76 f= 1.40817D-03 |proj g|= 1.72233D-05\n", + "\n", + "At iterate 77 f= 1.40174D-03 |proj g|= 1.15739D-05\n", + "\n", + "At iterate 78 f= 1.39992D-03 |proj g|= 2.51578D-05\n", + "\n", + "At iterate 79 f= 1.39208D-03 |proj g|= 1.57015D-05\n", + "\n", + "At iterate 80 f= 1.38812D-03 |proj g|= 1.62617D-05\n", + "\n", + "At iterate 81 f= 1.38644D-03 |proj g|= 1.24040D-05\n", + "\n", + "At iterate 82 f= 1.38113D-03 |proj g|= 2.20448D-05\n", + "\n", + "At iterate 83 f= 1.37713D-03 |proj g|= 2.21080D-05\n", + "\n", + "At iterate 84 f= 1.34091D-03 |proj g|= 1.51704D-05\n", + "\n", + "At iterate 85 f= 1.34044D-03 |proj g|= 1.64328D-05\n", + "\n", + "At iterate 86 f= 1.33568D-03 |proj g|= 1.64203D-05\n", + "\n", + "At iterate 87 f= 1.33340D-03 |proj g|= 1.54351D-05\n", + "\n", + "At iterate 88 f= 1.32984D-03 |proj g|= 2.55528D-05\n", + "\n", + "At iterate 89 f= 1.32524D-03 |proj g|= 2.13558D-05\n", + "\n", + "At iterate 90 f= 1.32085D-03 |proj g|= 2.64522D-05\n", + "\n", + "At iterate 91 f= 1.31897D-03 |proj g|= 1.70917D-05\n", + "\n", + "At iterate 92 f= 1.31895D-03 |proj g|= 2.04341D-05\n", + "\n", + "At iterate 93 f= 1.31895D-03 |proj g|= 2.04427D-05\n", + "\n", + " * * *\n", + "\n", + "Tit = total number of iterations\n", + "Tnf = total number of function evaluations\n", + "Tnint = total number of segments explored during Cauchy searches\n", + "Skip = number of BFGS updates skipped\n", + "Nact = number of active bounds at final generalized Cauchy 
point\n", + "Projg = norm of the final projected gradient\n", + "F = final function value\n", + "\n", + " * * *\n", + "\n", + " N Tit Tnf Tnint Skip Nact Projg F\n", + " 100 93 115 93 1 0 2.044D-05 1.319D-03\n", + " F = 1.3189480151647554E-003\n", + "\n", + "CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH \n" + ] + } + ], + "source": [ + "# Box constrained optimization\n", + "from jaxopt import ScipyBoundedMinimize\n", + "lbounds = [-100.0] * model.n_slices\n", + "ubounds = [100.0] * model.n_slices\n", + "bounds = (lbounds, ubounds)\n", + "lbfgsb = ScipyBoundedMinimize(fun=loss_function, method=\"l-bfgs-b\", options={'disp': True})\n", + "result_bounded = lbfgsb.run(init_params, bounds=bounds, model=model, timegrid=timegrid)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0.00131895, dtype=float64)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss_function(result_bounded.params, model, timegrid)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sparisification" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9025" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "0.95 * 0.95" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from jax.experimental import sparse" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "model_8 = IsingModel(args_dict={\"n_qubits\": 8, \"J\": 1, \"g\": 2, \"n_slices\": 100})" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# Sparsify the log function\n", + "@jax.jit\n", + "def loss_function(params, model, timegrid):\n", + " 
final_state = model.get_final_state(pulse=params, timegrid=timegrid)\n", + " loss = 1.0 - fom_funct(final_state, model.sqrt_rho_target)\n", + " return loss\n", + "loss_function_sparse = sparse.sparsify(loss_function)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUNNING THE L-BFGS-B CODE\n", + "\n", + " * * *\n", + "\n", + "Machine precision = 2.220D-16\n", + " N = 100 M = 10\n", + "\n", + "At X0 0 variables are exactly at the bounds\n", + "\n", + "At iterate 0 f= 9.75748D-01 |proj g|= 1.26414D-03\n", + "\n", + "At iterate 1 f= 9.75682D-01 |proj g|= 1.26712D-03\n", + " ys=-1.613E-07 -gs= 6.579E-05 BFGS update SKIPPED\n", + "\n", + "At iterate 2 f= 2.99806D-01 |proj g|= 3.46228D-03\n", + " ys=-7.678E-02 -gs= 1.412E-01 BFGS update SKIPPED\n", + "\n", + "At iterate 3 f= 2.92351D-01 |proj g|= 2.94393D-03\n", + "\n", + "At iterate 4 f= 2.42572D-01 |proj g|= 3.22604D-03\n", + "\n", + "At iterate 5 f= 2.13329D-01 |proj g|= 2.20855D-03\n", + "\n", + "At iterate 6 f= 1.54407D-01 |proj g|= 2.29238D-03\n", + "\n", + "At iterate 7 f= 1.17306D-01 |proj g|= 2.11694D-03\n", + "\n", + "At iterate 8 f= 9.19909D-02 |proj g|= 1.96433D-03\n", + "\n", + "At iterate 9 f= 7.64864D-02 |proj g|= 1.28822D-03\n", + "\n", + "At iterate 10 f= 7.23439D-02 |proj g|= 1.03853D-03\n", + "\n", + "At iterate 11 f= 7.15126D-02 |proj g|= 5.76083D-04\n", + "\n", + "At iterate 12 f= 7.12896D-02 |proj g|= 3.30536D-04\n", + "\n", + "At iterate 13 f= 7.11391D-02 |proj g|= 2.89537D-04\n", + "\n", + "At iterate 14 f= 7.04326D-02 |proj g|= 5.03251D-04\n", + "\n", + "At iterate 15 f= 6.88293D-02 |proj g|= 8.20320D-04\n", + "\n", + "At iterate 16 f= 6.87519D-02 |proj g|= 1.03365D-03\n", + "\n", + "At iterate 17 f= 6.87494D-02 |proj g|= 1.07153D-03\n", + " Positive dir derivative in projection \n", + " Using the backtracking step \n", + "\n", + "At iterate 18 f= 6.79627D-02 |proj g|= 8.20566D-04\n", 
+ "\n", + "At iterate 19 f= 6.32043D-02 |proj g|= 3.65340D-04\n", + "\n", + "At iterate 20 f= 5.00450D-02 |proj g|= 1.53637D-03\n", + "\n", + "At iterate 21 f= 4.65249D-02 |proj g|= 5.56224D-04\n", + "\n", + "At iterate 22 f= 4.57450D-02 |proj g|= 6.71994D-04\n", + "\n", + "At iterate 23 f= 3.65657D-02 |proj g|= 2.57924D-03\n", + "\n", + "At iterate 24 f= 2.38833D-02 |proj g|= 8.07379D-04\n", + "\n", + "At iterate 25 f= 2.25032D-02 |proj g|= 7.25181D-04\n", + "\n", + "At iterate 26 f= 2.17051D-02 |proj g|= 5.26638D-04\n", + "\n", + "At iterate 27 f= 2.15264D-02 |proj g|= 3.97915D-04\n", + "\n", + "At iterate 28 f= 2.12762D-02 |proj g|= 2.36289D-04\n", + "\n", + "At iterate 29 f= 2.10099D-02 |proj g|= 2.83114D-04\n", + "\n", + "At iterate 30 f= 2.04912D-02 |proj g|= 3.19073D-04\n", + "\n", + "At iterate 31 f= 1.93402D-02 |proj g|= 3.71012D-04\n", + "\n", + "At iterate 32 f= 1.64399D-02 |proj g|= 3.56023D-04\n", + "\n", + "At iterate 33 f= 1.47177D-02 |proj g|= 7.38411D-04\n", + "\n", + "At iterate 34 f= 1.11097D-02 |proj g|= 5.54292D-04\n", + "\n", + "At iterate 35 f= 8.85840D-03 |proj g|= 1.11372D-04\n", + "\n", + "At iterate 36 f= 8.26873D-03 |proj g|= 1.31398D-04\n", + "\n", + "At iterate 37 f= 8.19953D-03 |proj g|= 1.49970D-04\n", + "\n", + "At iterate 38 f= 8.06978D-03 |proj g|= 1.07299D-04\n", + "\n", + "At iterate 39 f= 8.02892D-03 |proj g|= 8.98772D-05\n", + "\n", + "At iterate 40 f= 7.87716D-03 |proj g|= 1.53473D-04\n", + "\n", + "At iterate 41 f= 7.59293D-03 |proj g|= 1.49842D-04\n", + "\n", + "At iterate 42 f= 6.85149D-03 |proj g|= 8.36149D-05\n", + "\n", + "At iterate 43 f= 6.31969D-03 |proj g|= 6.20456D-05\n", + "\n", + "At iterate 44 f= 5.98830D-03 |proj g|= 1.37045D-04\n", + "\n", + "At iterate 45 f= 5.93577D-03 |proj g|= 1.60148D-04\n", + "\n", + "At iterate 46 f= 5.89043D-03 |proj g|= 6.74273D-05\n", + "\n", + "At iterate 47 f= 5.87736D-03 |proj g|= 8.26680D-05\n", + "\n", + "At iterate 48 f= 5.82628D-03 |proj g|= 1.04539D-04\n" + ] + }, + { + 
"ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[22], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m bounds \u001b[38;5;241m=\u001b[39m (lbounds, ubounds)\n\u001b[1;32m 4\u001b[0m lbfgsb \u001b[38;5;241m=\u001b[39m ScipyBoundedMinimize(fun\u001b[38;5;241m=\u001b[39mloss_function_sparse, method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ml-bfgs-b\u001b[39m\u001b[38;5;124m\"\u001b[39m, options\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdisp\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n\u001b[0;32m----> 5\u001b[0m result_bounded \u001b[38;5;241m=\u001b[39m \u001b[43mlbfgsb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43minit_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbounds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbounds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_8\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimegrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimegrid\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jaxopt/_src/implicit_diff.py:251\u001b[0m, in \u001b[0;36m_custom_root..wrapped_solver_fun\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 249\u001b[0m args, kwargs \u001b[38;5;241m=\u001b[39m _signature_bind(solver_fun_signature, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 250\u001b[0m keys, vals \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mkeys()), 
\u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mvalues())\n\u001b[0;32m--> 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmake_custom_vjp_solver_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[43msolver_fun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mvals\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jax/_src/custom_derivatives.py:622\u001b[0m, in \u001b[0;36mcustom_vjp.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 619\u001b[0m flat_fwd, out_trees \u001b[38;5;241m=\u001b[39m _flatten_fwd(fwd, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msymbolic_zeros, primal_name,\n\u001b[1;32m 620\u001b[0m fwd_name, in_tree, out_type)\n\u001b[1;32m 621\u001b[0m flat_bwd \u001b[38;5;241m=\u001b[39m _flatten_bwd(bwd, in_tree, in_avals, out_trees)\u001b[38;5;241m.\u001b[39mcall_wrapped\n\u001b[0;32m--> 622\u001b[0m out_flat \u001b[38;5;241m=\u001b[39m \u001b[43mcustom_vjp_call_p\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43mflat_fun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_fwd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_bwd\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 623\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs_flat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_trees\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout_trees\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 624\u001b[0m \u001b[43m 
\u001b[49m\u001b[43msymbolic_zeros\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msymbolic_zeros\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 625\u001b[0m _, (out_tree, _) \u001b[38;5;241m=\u001b[39m lu\u001b[38;5;241m.\u001b[39mmerge_linear_aux(out_type, out_trees)\n\u001b[1;32m 626\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tree_unflatten(out_tree, out_flat)\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jax/_src/custom_derivatives.py:806\u001b[0m, in \u001b[0;36mCustomVJPCallPrimitive.bind\u001b[0;34m(self, fun, fwd, bwd, out_trees, symbolic_zeros, *args)\u001b[0m\n\u001b[1;32m 804\u001b[0m tracers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmap\u001b[39m(top_trace\u001b[38;5;241m.\u001b[39mfull_raise, args) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 805\u001b[0m bwd_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39margs: bwd(\u001b[38;5;241m*\u001b[39margs)\n\u001b[0;32m--> 806\u001b[0m outs \u001b[38;5;241m=\u001b[39m \u001b[43mtop_trace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_custom_vjp_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfwd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbwd_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 807\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_trees\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout_trees\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 808\u001b[0m \u001b[43m \u001b[49m\u001b[43msymbolic_zeros\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msymbolic_zeros\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 809\u001b[0m fst, env_trace_todo \u001b[38;5;241m=\u001b[39m 
lu\u001b[38;5;241m.\u001b[39mmerge_linear_aux(env_trace_todo1, env_trace_todo2)\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fst:\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jax/_src/core.py:932\u001b[0m, in \u001b[0;36mEvalTrace.process_custom_vjp_call\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 930\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m primitive, fwd, bwd, _ \u001b[38;5;66;03m# Unused.\u001b[39;00m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m new_sublevel():\n\u001b[0;32m--> 932\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfun\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall_wrapped\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jax/_src/linear_util.py:192\u001b[0m, in \u001b[0;36mWrappedFun.call_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 189\u001b[0m gen \u001b[38;5;241m=\u001b[39m gen_static_args \u001b[38;5;241m=\u001b[39m out_store \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 191\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 192\u001b[0m ans \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 193\u001b[0m 
\u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# Some transformations yield from inside context managers, so we have to\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# interrupt them before reraising the exception. Otherwise they will only\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;66;03m# get garbage-collected at some later time, running their cleanup tasks\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;66;03m# only after this exception is handled, which can corrupt the global\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[38;5;66;03m# state.\u001b[39;00m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m stack:\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jaxopt/_src/implicit_diff.py:207\u001b[0m, in \u001b[0;36m_custom_root..make_custom_vjp_solver_fun..solver_fun_flat\u001b[0;34m(*flat_args)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;129m@jax\u001b[39m\u001b[38;5;241m.\u001b[39mcustom_vjp\n\u001b[1;32m 205\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msolver_fun_flat\u001b[39m(\u001b[38;5;241m*\u001b[39mflat_args):\n\u001b[1;32m 206\u001b[0m args, kwargs \u001b[38;5;241m=\u001b[39m _extract_kwargs(kwarg_keys, flat_args)\n\u001b[0;32m--> 207\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msolver_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jaxopt/_src/scipy_wrappers.py:459\u001b[0m, in \u001b[0;36mScipyBoundedMinimize.run\u001b[0;34m(self, init_params, bounds, *args, **kwargs)\u001b[0m\n\u001b[1;32m 443\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 444\u001b[0m init_params: 
Any,\n\u001b[1;32m 445\u001b[0m bounds: Optional[Any],\n\u001b[1;32m 446\u001b[0m \u001b[38;5;241m*\u001b[39margs,\n\u001b[1;32m 447\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m base\u001b[38;5;241m.\u001b[39mOptStep:\n\u001b[1;32m 448\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Runs the solver.\u001b[39;00m\n\u001b[1;32m 449\u001b[0m \n\u001b[1;32m 450\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 457\u001b[0m \u001b[38;5;124;03m (params, info).\u001b[39;00m\n\u001b[1;32m 458\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 459\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43minit_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbounds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jaxopt/_src/scipy_wrappers.py:343\u001b[0m, in \u001b[0;36mScipyMinimize._run\u001b[0;34m(self, init_params, bounds, *args, **kwargs)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bounds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 340\u001b[0m bounds \u001b[38;5;241m=\u001b[39m osp\u001b[38;5;241m.\u001b[39moptimize\u001b[38;5;241m.\u001b[39mBounds(lb\u001b[38;5;241m=\u001b[39mjnp_to_onp(bounds[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype),\n\u001b[1;32m 341\u001b[0m ub\u001b[38;5;241m=\u001b[39mjnp_to_onp(bounds[\u001b[38;5;241m1\u001b[39m], 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype))\n\u001b[0;32m--> 343\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mosp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimize\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mminimize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mscipy_fun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjnp_to_onp\u001b[49m\u001b[43m(\u001b[49m\u001b[43minit_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 344\u001b[0m \u001b[43m \u001b[49m\u001b[43mjac\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 345\u001b[0m \u001b[43m \u001b[49m\u001b[43mtol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtol\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 346\u001b[0m \u001b[43m \u001b[49m\u001b[43mbounds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbounds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscipy_callback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 351\u001b[0m params \u001b[38;5;241m=\u001b[39m tree_util\u001b[38;5;241m.\u001b[39mtree_map(jnp\u001b[38;5;241m.\u001b[39masarray, onp_to_jnp(res\u001b[38;5;241m.\u001b[39mx))\n\u001b[1;32m 353\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(res, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhess_inv\u001b[39m\u001b[38;5;124m'\u001b[39m):\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_minimize.py:713\u001b[0m, in \u001b[0;36mminimize\u001b[0;34m(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)\u001b[0m\n\u001b[1;32m 710\u001b[0m res \u001b[38;5;241m=\u001b[39m _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,\n\u001b[1;32m 711\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39moptions)\n\u001b[1;32m 712\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m meth \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124ml-bfgs-b\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m--> 713\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43m_minimize_lbfgsb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjac\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbounds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 714\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 715\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m meth \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtnc\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 716\u001b[0m res \u001b[38;5;241m=\u001b[39m _minimize_tnc(fun, x0, args, jac, bounds, callback\u001b[38;5;241m=\u001b[39mcallback,\n\u001b[1;32m 717\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39moptions)\n", + "File 
\u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_lbfgsb_py.py:407\u001b[0m, in \u001b[0;36m_minimize_lbfgsb\u001b[0;34m(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)\u001b[0m\n\u001b[1;32m 401\u001b[0m task_str \u001b[38;5;241m=\u001b[39m task\u001b[38;5;241m.\u001b[39mtobytes()\n\u001b[1;32m 402\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m task_str\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFG\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m 403\u001b[0m \u001b[38;5;66;03m# The minimization routine wants f and g at the current x.\u001b[39;00m\n\u001b[1;32m 404\u001b[0m \u001b[38;5;66;03m# Note that interruptions due to maxfun are postponed\u001b[39;00m\n\u001b[1;32m 405\u001b[0m \u001b[38;5;66;03m# until the completion of the current minimization iteration.\u001b[39;00m\n\u001b[1;32m 406\u001b[0m \u001b[38;5;66;03m# Overwrite f and g:\u001b[39;00m\n\u001b[0;32m--> 407\u001b[0m f, g \u001b[38;5;241m=\u001b[39m \u001b[43mfunc_and_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m task_str\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNEW_X\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m 409\u001b[0m \u001b[38;5;66;03m# new iteration\u001b[39;00m\n\u001b[1;32m 410\u001b[0m n_iterations \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:296\u001b[0m, in \u001b[0;36mScalarFunction.fun_and_grad\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np\u001b[38;5;241m.\u001b[39marray_equal(x, 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx):\n\u001b[1;32m 295\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_x_impl(x)\n\u001b[0;32m--> 296\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_update_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_grad()\n\u001b[1;32m 298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mf, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mg\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:262\u001b[0m, in \u001b[0;36mScalarFunction._update_fun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_update_fun\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 261\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mf_updated:\n\u001b[0;32m--> 262\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_update_fun_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mf_updated \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:163\u001b[0m, in \u001b[0;36mScalarFunction.__init__..update_fun\u001b[0;34m()\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mupdate_fun\u001b[39m():\n\u001b[0;32m--> 163\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mf \u001b[38;5;241m=\u001b[39m 
\u001b[43mfun_wrapped\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:145\u001b[0m, in \u001b[0;36mScalarFunction.__init__..fun_wrapped\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnfev \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;66;03m# Send a copy because the user may overwrite it.\u001b[39;00m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# Overwriting results in undefined behaviour because\u001b[39;00m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;66;03m# fun(self.x) will change self.x, with the two no longer linked.\u001b[39;00m\n\u001b[0;32m--> 145\u001b[0m fx \u001b[38;5;241m=\u001b[39m \u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;66;03m# Make sure the function returns a true scalar\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np\u001b[38;5;241m.\u001b[39misscalar(fx):\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_optimize.py:79\u001b[0m, in \u001b[0;36mMemoizeJac.__call__\u001b[0;34m(self, x, *args)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, \u001b[38;5;241m*\u001b[39margs):\n\u001b[1;32m 78\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\" returns the function value 
\"\"\"\u001b[39;00m\n\u001b[0;32m---> 79\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compute_if_needed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/scipy/optimize/_optimize.py:73\u001b[0m, in \u001b[0;36mMemoizeJac._compute_if_needed\u001b[0;34m(self, x, *args)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np\u001b[38;5;241m.\u001b[39mall(x \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mjac \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mx \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39masarray(x)\u001b[38;5;241m.\u001b[39mcopy()\n\u001b[0;32m---> 73\u001b[0m fg \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mjac \u001b[38;5;241m=\u001b[39m fg[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value \u001b[38;5;241m=\u001b[39m 
fg[\u001b[38;5;241m0\u001b[39m]\n", + "File \u001b[0;32m~/miniconda3/envs/quocsproj3/lib/python3.11/site-packages/jaxopt/_src/scipy_wrappers.py:336\u001b[0m, in \u001b[0;36mScipyMinimize._run..scipy_fun\u001b[0;34m(x_onp)\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mscipy_fun\u001b[39m(x_onp: onp\u001b[38;5;241m.\u001b[39mndarray) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[onp\u001b[38;5;241m.\u001b[39mndarray, onp\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 335\u001b[0m x_jnp \u001b[38;5;241m=\u001b[39m onp_to_jnp(x_onp)\n\u001b[0;32m--> 336\u001b[0m value, grads \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_value_and_grad_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_jnp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 337\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m onp\u001b[38;5;241m.\u001b[39masarray(value, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype), jnp_to_onp(grads, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "lbounds = [-100.0] * model.n_slices\n", + "ubounds = [100.0] * model.n_slices\n", + "bounds = (lbounds, ubounds)\n", + "lbfgsb = ScipyBoundedMinimize(fun=loss_function_sparse, method=\"l-bfgs-b\", options={'disp': True})\n", + "result_bounded = lbfgsb.run(init_params, bounds=bounds, model=model_8, timegrid=timegrid)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Stocastic optimization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "start_learning_rate = 0.9\n", + "# optimizer = 
optax.adamw(start_learning_rate)\n", + "# optimizer = optax.adamax(start_learning_rate)\n", + "optimizer = optax.sgd(start_learning_rate)\n", + "# init_params = jnp.zeros(model.n_slices, dtype=jnp.complex128)\n", + "opt_state = optimizer.init(init_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params = init_params\n", + "loss_value, grads = jax.value_and_grad(loss_function)(params, model, timegrid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params = init_params\n", + "for i in range(1000):\n", + " loss_value, grads = jax.value_and_grad(loss_function)(params, model, timegrid)\n", + " print(f\"Loss {i}: {loss_value}\")\n", + " # print(f\"Grads {i}: {grads}\")\n", + " if loss_value < 1e-6:\n", + " break\n", + " # updates, opt_state = optimizer.update(grads, opt_state) for standard adam\n", + " updates, opt_state = optimizer.update(grads, opt_state, params)\n", + " params = optax.apply_updates(params, updates)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "quocsproj3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/quocslib/timeevolution/piecewise_integrator_AD.py b/src/quocslib/timeevolution/piecewise_integrator_AD.py index 41728d0..8365aec 100644 --- a/src/quocslib/timeevolution/piecewise_integrator_AD.py +++ b/src/quocslib/timeevolution/piecewise_integrator_AD.py @@ -15,6 +15,7 @@ # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import 
jax import jax.scipy as jsp +import jax.numpy as jnp from functools import partial @@ -83,7 +84,7 @@ def pw_evolution_AD(U_store, drive, A, B, n_slices, dt): # return U -@partial(jax.jit, static_argnames=["n_slices"]) +@partial(jax.jit, static_argnames=["n_slices", "dt"]) def pw_final_evolution_AD(drive, A, B, n_slices, dt, U0): """ Computes the piecewise evolution of a system defined by the @@ -98,9 +99,9 @@ def pw_final_evolution_AD(drive, A, B, n_slices, dt, U0): :return np.matrix: The final propagator """ U = U0 - + K = len(B) def body_fun(i, val): - K = len(B) + # K = len(B) H = A for k in range(K): H = H + drive[k, i] * B[k] @@ -109,3 +110,30 @@ def body_fun(i, val): U = jax.lax.fori_loop(0, n_slices, body_fun, U) return U + +@partial(jax.jit, static_argnames=["n_slices", "n_controls"]) +def pw_final_evolution_AD_scan(drive: jnp.ndarray, A: jnp.ndarray, B: jnp.ndarray, n_slices: int, dt: float, U0: jnp.ndarray, n_controls: int): + """ + Computes the piecewise evolution of a system defined by the + Hamiltonian H = A + drive * B and concatenate all the propagators + + :param np.array drive: An array of dimension n_controls x n_slices that contains the amplitudes of the pulse + :param np.matrix A: The drift Hamiltonian + :param List[np.matrix] B: The control Hamiltonians in a list + :param int n_slices: Number of slices + :param float dt: The duration of each time slice + :param np.matrix U0: The initial propagator to start from + :return np.matrix: The final propagator + """ + U = U0 + H_drive = jnp.transpose(drive) + def body_fun(carry, drive_i): + val, i = carry + H = A + for k in range(n_controls): + H = H + drive_i[k] * B[k] + Uint = jsp.linalg.expm(-1.0j * dt * H) + return (Uint @ val, i+1), None + + (U, _), _ = jax.lax.scan(body_fun, init=(U0, 0), xs=H_drive) + return U