Skip to content

Commit

Permalink
Removed dependency from residual flow repository
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentStimper committed Jul 26, 2022
1 parent ff33d12 commit 314984d
Show file tree
Hide file tree
Showing 6 changed files with 807 additions and 31 deletions.
35 changes: 28 additions & 7 deletions example/residual.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Residual Flow"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Import required packages\n",
Expand All @@ -29,6 +37,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"scrolled": false
},
"outputs": [],
Expand Down Expand Up @@ -71,6 +82,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"scrolled": false
},
"outputs": [],
Expand Down Expand Up @@ -103,19 +117,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"scrolled": false
},
"outputs": [],
"source": [
"# Train model\n",
"max_iter = 10000\n",
"max_iter = 20000\n",
"num_samples = 2 ** 9\n",
"show_iter = 500\n",
"\n",
"\n",
"loss_hist = np.array([])\n",
"\n",
"optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-3, weight_decay=1e-5)\n",
"optimizer = torch.optim.Adam(nfm.parameters(), lr=3e-4, weight_decay=1e-5)\n",
"for it in tqdm(range(max_iter)):\n",
" optimizer.zero_grad()\n",
" \n",
Expand All @@ -132,7 +149,7 @@
" optimizer.step()\n",
" \n",
" # Make layers Lipschitz continuous\n",
" nf.utils.update_lipschitz(nfm, 5)\n",
" nf.utils.update_lipschitz(nfm, 50)\n",
" \n",
" # Log loss\n",
" loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())\n",
Expand Down Expand Up @@ -160,7 +177,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Plot learned posterior distribution\n",
Expand Down Expand Up @@ -193,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
"version": "3.7.11"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 314984d

Please sign in to comment.