Use black[jupyter] in notebooks (pyro-ppl#1162)
* use black[jupyter]

* update version number
MarcoGorelli authored Sep 25, 2021
1 parent a8c1cf4 commit 968349e
Showing 11 changed files with 2,688 additions and 2,556 deletions.
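For anyone reproducing this formatting pass: installing black's jupyter extra (pip install "black[jupyter]") lets black parse and rewrite the code cells inside .ipynb files. A minimal sketch of the likely invocation, assuming it is run from the repository root (the exact command is not recorded in this commit):

# Hypothetical reproduction of the formatting step behind this commit;
# the exact invocation is not recorded here.
# Requires: pip install "black[jupyter]"
import subprocess

# With the jupyter extra installed, black reformats code cells inside
# .ipynb notebooks as well as ordinary .py files.
subprocess.run(["black", "notebooks/"], check=True)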
86 changes: 44 additions & 42 deletions notebooks/source/bad_posterior_geometry.ipynb
@@ -49,7 +49,8 @@
"from numpyro.diagnostics import summary\n",
"\n",
"from numpyro.infer import MCMC, NUTS\n",
"assert numpyro.__version__.startswith('0.7.2')\n",
"\n",
"assert numpyro.__version__.startswith(\"0.7.2\")\n",
"\n",
"# NB: replace cpu by gpu to run this notebook on gpu\n",
"numpyro.set_platform(\"cpu\")"
@@ -68,15 +69,11 @@
"metadata": {},
"outputs": [],
"source": [
"def run_inference(model, \n",
" num_warmup=1000, \n",
" num_samples=1000,\n",
" max_tree_depth=10, \n",
" dense_mass=False):\n",
" \n",
" kernel = NUTS(model, \n",
" max_tree_depth=max_tree_depth,\n",
" dense_mass=dense_mass)\n",
"def run_inference(\n",
" model, num_warmup=1000, num_samples=1000, max_tree_depth=10, dense_mass=False\n",
"):\n",
"\n",
" kernel = NUTS(model, max_tree_depth=max_tree_depth, dense_mass=dense_mass)\n",
" mcmc = MCMC(\n",
" kernel,\n",
" num_warmup=num_warmup,\n",
@@ -85,12 +82,12 @@
" progress_bar=False,\n",
" )\n",
" mcmc.run(random.PRNGKey(0))\n",
" summary_dict = summary(mcmc.get_samples(), group_by_chain=False) \n",
" \n",
" summary_dict = summary(mcmc.get_samples(), group_by_chain=False)\n",
"\n",
" # print the largest r_hat for each variable\n",
" for k, v in summary_dict.items():\n",
" spaces = \" \" * max(12 - len(k), 0)\n",
" print(\"[{}] {} \\t max r_hat: {:.4f}\".format(k, spaces, np.max(v['r_hat'])))"
" print(\"[{}] {} \\t max r_hat: {:.4f}\".format(k, spaces, np.max(v[\"r_hat\"])))"
]
},
{
@@ -170,15 +167,17 @@
"outputs": [],
"source": [
"# In this reparameterized model none of the parameters of the distributions\n",
"# explicitly depend on other parameters. This model is exactly equivalent \n",
"# explicitly depend on other parameters. This model is exactly equivalent\n",
"# to _unrep_hs_model but is expressed in a different coordinate system.\n",
"def _rep_hs_model1(X, Y):\n",
" lambdas = numpyro.sample(\"lambdas\", dist.HalfCauchy(jnp.ones(X.shape[1])))\n",
" tau = numpyro.sample(\"tau\", dist.HalfCauchy(jnp.ones(1)))\n",
" unscaled_betas = numpyro.sample(\"unscaled_betas\", dist.Normal(scale=jnp.ones(X.shape[1])))\n",
" unscaled_betas = numpyro.sample(\n",
" \"unscaled_betas\", dist.Normal(scale=jnp.ones(X.shape[1]))\n",
" )\n",
" scaled_betas = numpyro.deterministic(\"betas\", tau * lambdas * unscaled_betas)\n",
" mean_function = jnp.dot(X, scaled_betas)\n",
" numpyro.sample(\"Y\", dist.Normal(mean_function, 0.05), obs=Y) "
" numpyro.sample(\"Y\", dist.Normal(mean_function, 0.05), obs=Y)"
]
},
{
@@ -196,7 +195,8 @@
"metadata": {},
"outputs": [],
"source": [
"from numpyro.infer.reparam import LocScaleReparam \n",
"from numpyro.infer.reparam import LocScaleReparam\n",
"\n",
"# LocScaleReparam with centered=0 fully \"decenters\" the prior over betas.\n",
"config = {\"betas\": LocScaleReparam(centered=0)}\n",
"# The coordinate system of this model is equivalent to that in _rep_hs_model1 above.\n",
@@ -217,30 +217,31 @@
"outputs": [],
"source": [
"from numpyro.distributions.transforms import AffineTransform\n",
"from numpyro.infer.reparam import TransformReparam \n",
"from numpyro.infer.reparam import TransformReparam\n",
"\n",
"# In this reparameterized model none of the parameters of the distributions\n",
"# explicitly depend on other parameters. This model is exactly equivalent \n",
"# explicitly depend on other parameters. This model is exactly equivalent\n",
"# to _unrep_hs_model but is expressed in a different coordinate system.\n",
"def _rep_hs_model3(X, Y):\n",
" lambdas = numpyro.sample(\"lambdas\", dist.HalfCauchy(jnp.ones(X.shape[1])))\n",
" tau = numpyro.sample(\"tau\", dist.HalfCauchy(jnp.ones(1)))\n",
" \n",
"\n",
" # instruct NumPyro to do the reparameterization automatically.\n",
" reparam_config = {\"betas\": TransformReparam()}\n",
" with numpyro.handlers.reparam(config=reparam_config):\n",
" betas_root_variance = tau * lambdas\n",
" # in order to use TransformReparam we have to express the prior\n",
" # over betas as a TransformedDistribution\n",
" betas = numpyro.sample(\"betas\", \n",
" betas = numpyro.sample(\n",
" \"betas\",\n",
" dist.TransformedDistribution(\n",
" dist.Normal(0., jnp.ones(X.shape[1])), \n",
" AffineTransform(0., betas_root_variance)\n",
" )\n",
" dist.Normal(0.0, jnp.ones(X.shape[1])),\n",
" AffineTransform(0.0, betas_root_variance),\n",
" ),\n",
" )\n",
" \n",
"\n",
" mean_function = jnp.dot(X, betas)\n",
" numpyro.sample(\"Y\", dist.Normal(mean_function, 0.05), obs=Y) "
" numpyro.sample(\"Y\", dist.Normal(mean_function, 0.05), obs=Y)"
]
},
{
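A quick way to see that the TransformedDistribution in _rep_hs_model3 is the same prior in different coordinates: the AffineTransform pushforward of a standard normal has exactly the log-density of a normal with that scale. A hedged sanity check (values hypothetical, not part of the commit):

# Hypothetical sanity check: AffineTransform(0, s) applied to N(0, 1)
# yields the same log-density as N(0, s).
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform

scale = jnp.array([0.5, 2.0])
d_transformed = dist.TransformedDistribution(
    dist.Normal(0.0, jnp.ones(2)), AffineTransform(0.0, scale)
)
d_direct = dist.Normal(0.0, scale)
x = jnp.array([0.3, -1.2])
print(d_transformed.log_prob(x))  # matches the line below up to float error
print(d_direct.log_prob(x))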
@@ -349,17 +350,17 @@
}
],
"source": [
"# Because rho is very close to 1.0 the posterior geometry \n",
"# is extremely skewed and using the \"diagonal\" coordinate system \n",
"# Because rho is very close to 1.0 the posterior geometry\n",
"# is extremely skewed and using the \"diagonal\" coordinate system\n",
"# implied by dense_mass=False leads to bad results\n",
"rho = 0.9999\n",
"cov = jnp.array([[10.0, rho], [rho, 0.1]])\n",
"\n",
"\n",
"def mvn_model():\n",
" numpyro.sample(\"x\", \n",
" dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
" )\n",
" \n",
" numpyro.sample(\"x\", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov))\n",
"\n",
"\n",
"print(\"dense_mass = False (bad r_hat)\")\n",
"run_inference(mvn_model, dense_mass=False, max_tree_depth=3)\n",
"\n",
@@ -391,12 +392,12 @@
"# In this model x1 and x2 are highly correlated with one another\n",
"# but not correlated with y at all.\n",
"def partially_correlated_model():\n",
" x1 = numpyro.sample(\"x1\", \n",
" dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
" x1 = numpyro.sample(\n",
" \"x1\", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
" )\n",
" x2 = numpyro.sample(\n",
" \"x2\", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
" )\n",
" x2 = numpyro.sample(\"x2\", \n",
" dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)\n",
" ) \n",
" y = numpyro.sample(\"y\", dist.Normal(jnp.zeros(100), 1.0))\n",
" numpyro.sample(\"obs\", dist.Normal(x1 - x2, 0.1), jnp.ones(2))"
]
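For a posterior with this block structure, a dense mass matrix over all ~104 latent dimensions is wasteful. NUTS in NumPyro also accepts a structured specification, dense only over the named block; a hedged sketch (the option is real NumPyro API, but this particular call is illustrative and sits outside the hunks shown):

# Hypothetical: dense mass matrix over the correlated (x1, x2) block only,
# diagonal elsewhere -- cheaper than a full dense matrix over all latents.
from numpyro.infer import NUTS

kernel = NUTS(partially_correlated_model, dense_mass=[("x1", "x2")])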
@@ -484,18 +485,19 @@
}
],
"source": [
"# Because rho is very close to 1.0 the posterior geometry is extremely \n",
"# Because rho is very close to 1.0 the posterior geometry is extremely\n",
"# skewed and using small max_tree_depth leads to bad results.\n",
"rho = 0.999\n",
"dim = 200\n",
"cov = rho * jnp.ones((dim, dim)) + (1 - rho) * jnp.eye(dim)\n",
"\n",
"\n",
"def mvn_model():\n",
" x = numpyro.sample(\"x\", \n",
" dist.MultivariateNormal(jnp.zeros(dim), \n",
" covariance_matrix=cov)\n",
" x = numpyro.sample(\n",
" \"x\", dist.MultivariateNormal(jnp.zeros(dim), covariance_matrix=cov)\n",
" )\n",
" \n",
"\n",
"\n",
"print(\"max_tree_depth = 5 (bad r_hat)\")\n",
"run_inference(mvn_model, max_tree_depth=5)\n",
"\n",
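Context for the max_tree_depth knob exercised in this notebook: NUTS doubles its trajectory at most max_tree_depth times, so each iteration takes at most 2**max_tree_depth leapfrog steps. A small illustration of why depth 5 is restrictive for the skewed 200-dimensional Gaussian above (general NUTS semantics, not commit-specific):

# Maximum leapfrog steps per NUTS iteration is 2 ** max_tree_depth.
for depth in (3, 5, 10):
    print(depth, 2**depth)  # -> 8, 32, and 1024 steps respectively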
128 changes: 64 additions & 64 deletions notebooks/source/bayesian_hierarchical_linear_regression.ipynb
@@ -145,10 +145,12 @@
}
],
"source": [
"train = pd.read_csv('https://gist.githubusercontent.com/ucals/'\n",
" '2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/'\n",
" '43034c39052dcf97d4b894d2ec1bc3f90f3623d9/'\n",
" 'osic_pulmonary_fibrosis.csv')\n",
"train = pd.read_csv(\n",
" \"https://gist.githubusercontent.com/ucals/\"\n",
" \"2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/\"\n",
" \"43034c39052dcf97d4b894d2ec1bc3f90f3623d9/\"\n",
" \"osic_pulmonary_fibrosis.csv\"\n",
")\n",
"train.head()"
]
},
@@ -181,17 +183,17 @@
],
"source": [
"def chart(patient_id, ax):\n",
" data = train[train['Patient'] == patient_id]\n",
" x = data['Weeks']\n",
" y = data['FVC']\n",
" data = train[train[\"Patient\"] == patient_id]\n",
" x = data[\"Weeks\"]\n",
" y = data[\"FVC\"]\n",
" ax.set_title(patient_id)\n",
" ax = sns.regplot(x, y, ax=ax, ci=None, line_kws={'color':'red'})\n",
" \n",
" ax = sns.regplot(x, y, ax=ax, ci=None, line_kws={\"color\": \"red\"})\n",
"\n",
"\n",
"f, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"chart('ID00007637202177411956430', axes[0])\n",
"chart('ID00009637202177434476278', axes[1])\n",
"chart('ID00010637202177584971671', axes[2])"
"chart(\"ID00007637202177411956430\", axes[0])\n",
"chart(\"ID00009637202177434476278\", axes[1])\n",
"chart(\"ID00010637202177584971671\", axes[2])"
]
},
{
@@ -242,7 +244,7 @@
"import numpyro.distributions as dist\n",
"from jax import random\n",
"\n",
"assert numpyro.__version__.startswith('0.7.2')"
"assert numpyro.__version__.startswith(\"0.7.2\")"
]
},
{
@@ -252,21 +254,21 @@
"outputs": [],
"source": [
"def model(PatientID, Weeks, FVC_obs=None):\n",
" μ_α = numpyro.sample(\"μ_α\", dist.Normal(0., 100.))\n",
" σ_α = numpyro.sample(\"σ_α\", dist.HalfNormal(100.))\n",
" μ_β = numpyro.sample(\"μ_β\", dist.Normal(0., 100.))\n",
" σ_β = numpyro.sample(\"σ_β\", dist.HalfNormal(100.))\n",
" \n",
" μ_α = numpyro.sample(\"μ_α\", dist.Normal(0.0, 100.0))\n",
" σ_α = numpyro.sample(\"σ_α\", dist.HalfNormal(100.0))\n",
" μ_β = numpyro.sample(\"μ_β\", dist.Normal(0.0, 100.0))\n",
" σ_β = numpyro.sample(\"σ_β\", dist.HalfNormal(100.0))\n",
"\n",
" unique_patient_IDs = np.unique(PatientID)\n",
" n_patients = len(unique_patient_IDs)\n",
" \n",
"\n",
" with numpyro.plate(\"plate_i\", n_patients):\n",
" α = numpyro.sample(\"α\", dist.Normal(μ_α, σ_α))\n",
" β = numpyro.sample(\"β\", dist.Normal(μ_β, σ_β))\n",
" \n",
" σ = numpyro.sample(\"σ\", dist.HalfNormal(100.))\n",
"\n",
" σ = numpyro.sample(\"σ\", dist.HalfNormal(100.0))\n",
" FVC_est = α[PatientID] + β[PatientID] * Weeks\n",
" \n",
"\n",
" with numpyro.plate(\"data\", len(PatientID)):\n",
" numpyro.sample(\"obs\", dist.Normal(FVC_est, σ), obs=FVC_obs)"
]
@@ -292,11 +294,11 @@
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"le = LabelEncoder()\n",
"train['PatientID'] = le.fit_transform(train['Patient'].values)\n",
"train[\"PatientID\"] = le.fit_transform(train[\"Patient\"].values)\n",
"\n",
"FVC_obs = train['FVC'].values\n",
"Weeks = train['Weeks'].values\n",
"PatientID = train['PatientID'].values"
"FVC_obs = train[\"FVC\"].values\n",
"Weeks = train[\"Weeks\"].values\n",
"PatientID = train[\"PatientID\"].values"
]
},
{
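The next hunks jump ahead to posterior predictions; the cell that actually runs inference and produces posterior_samples is collapsed. A hedged sketch of what that step typically looks like in NumPyro (warmup and sample counts are assumptions, not taken from the notebook):

# Hypothetical sketch of the collapsed inference cell; counts are assumed.
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(model), num_warmup=2000, num_samples=2000)
mcmc.run(random.PRNGKey(0), PatientID, Weeks, FVC_obs=FVC_obs)
posterior_samples = mcmc.get_samples()  # dict of site name -> samples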
@@ -380,10 +382,10 @@
"outputs": [],
"source": [
"pred_template = []\n",
"for i in range(train['Patient'].nunique()):\n",
" df = pd.DataFrame(columns=['PatientID', 'Weeks'])\n",
" df['Weeks'] = np.arange(-12, 134)\n",
" df['PatientID'] = i\n",
"for i in range(train[\"Patient\"].nunique()):\n",
" df = pd.DataFrame(columns=[\"PatientID\", \"Weeks\"])\n",
" df[\"Weeks\"] = np.arange(-12, 134)\n",
" df[\"PatientID\"] = i\n",
" pred_template.append(df)\n",
"pred_template = pd.concat(pred_template, ignore_index=True)"
]
@@ -401,12 +403,10 @@
"metadata": {},
"outputs": [],
"source": [
"PatientID = pred_template['PatientID'].values\n",
"Weeks = pred_template['Weeks'].values\n",
"predictive = Predictive(model, posterior_samples, \n",
" return_sites=['σ', 'obs'])\n",
"samples_predictive = predictive(random.PRNGKey(0), \n",
" PatientID, Weeks, None)"
"PatientID = pred_template[\"PatientID\"].values\n",
"Weeks = pred_template[\"Weeks\"].values\n",
"predictive = Predictive(model, posterior_samples, return_sites=[\"σ\", \"obs\"])\n",
"samples_predictive = predictive(random.PRNGKey(0), PatientID, Weeks, None)"
]
},
{
@@ -528,16 +528,17 @@
}
],
"source": [
"df = pd.DataFrame(columns=['Patient', 'Weeks', 'FVC_pred', 'sigma'])\n",
"df['Patient'] = le.inverse_transform(pred_template['PatientID'])\n",
"df['Weeks'] = pred_template['Weeks']\n",
"df['FVC_pred'] = samples_predictive['obs'].T.mean(axis=1)\n",
"df['sigma'] = samples_predictive['obs'].T.std(axis=1)\n",
"df['FVC_inf'] = df['FVC_pred'] - df['sigma']\n",
"df['FVC_sup'] = df['FVC_pred'] + df['sigma']\n",
"df = pd.merge(df, train[['Patient', 'Weeks', 'FVC']], \n",
" how='left', on=['Patient', 'Weeks'])\n",
"df = df.rename(columns={'FVC': 'FVC_true'})\n",
"df = pd.DataFrame(columns=[\"Patient\", \"Weeks\", \"FVC_pred\", \"sigma\"])\n",
"df[\"Patient\"] = le.inverse_transform(pred_template[\"PatientID\"])\n",
"df[\"Weeks\"] = pred_template[\"Weeks\"]\n",
"df[\"FVC_pred\"] = samples_predictive[\"obs\"].T.mean(axis=1)\n",
"df[\"sigma\"] = samples_predictive[\"obs\"].T.std(axis=1)\n",
"df[\"FVC_inf\"] = df[\"FVC_pred\"] - df[\"sigma\"]\n",
"df[\"FVC_sup\"] = df[\"FVC_pred\"] + df[\"sigma\"]\n",
"df = pd.merge(\n",
" df, train[[\"Patient\", \"Weeks\", \"FVC\"]], how=\"left\", on=[\"Patient\", \"Weeks\"]\n",
")\n",
"df = df.rename(columns={\"FVC\": \"FVC_true\"})\n",
"df.head()"
]
},
@@ -568,21 +569,20 @@
],
"source": [
"def chart(patient_id, ax):\n",
" data = df[df['Patient'] == patient_id]\n",
" x = data['Weeks']\n",
" data = df[df[\"Patient\"] == patient_id]\n",
" x = data[\"Weeks\"]\n",
" ax.set_title(patient_id)\n",
" ax.plot(x, data['FVC_true'], 'o')\n",
" ax.plot(x, data['FVC_pred'])\n",
" ax = sns.regplot(x, data['FVC_true'], ax=ax, ci=None, \n",
" line_kws={'color':'red'})\n",
" ax.fill_between(x, data[\"FVC_inf\"], data[\"FVC_sup\"],\n",
" alpha=0.5, color='#ffcd3c')\n",
" ax.set_ylabel('FVC')\n",
" ax.plot(x, data[\"FVC_true\"], \"o\")\n",
" ax.plot(x, data[\"FVC_pred\"])\n",
" ax = sns.regplot(x, data[\"FVC_true\"], ax=ax, ci=None, line_kws={\"color\": \"red\"})\n",
" ax.fill_between(x, data[\"FVC_inf\"], data[\"FVC_sup\"], alpha=0.5, color=\"#ffcd3c\")\n",
" ax.set_ylabel(\"FVC\")\n",
"\n",
"\n",
"f, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"chart('ID00007637202177411956430', axes[0])\n",
"chart('ID00009637202177434476278', axes[1])\n",
"chart('ID00011637202177653955184', axes[2])"
"chart(\"ID00007637202177411956430\", axes[0])\n",
"chart(\"ID00009637202177434476278\", axes[1])\n",
"chart(\"ID00011637202177653955184\", axes[2])"
]
},
{
@@ -628,15 +628,15 @@
],
"source": [
"y = df.dropna()\n",
"rmse = ((y['FVC_pred'] - y['FVC_true']) ** 2).mean() ** (1/2)\n",
"print(f'RMSE: {rmse:.1f} ml')\n",
"rmse = ((y[\"FVC_pred\"] - y[\"FVC_true\"]) ** 2).mean() ** (1 / 2)\n",
"print(f\"RMSE: {rmse:.1f} ml\")\n",
"\n",
"sigma_c = y['sigma'].values\n",
"sigma_c = y[\"sigma\"].values\n",
"sigma_c[sigma_c < 70] = 70\n",
"delta = (y['FVC_pred'] - y['FVC_true']).abs()\n",
"delta = (y[\"FVC_pred\"] - y[\"FVC_true\"]).abs()\n",
"delta[delta > 1000] = 1000\n",
"lll = - np.sqrt(2) * delta / sigma_c - np.log(np.sqrt(2) * sigma_c)\n",
"print(f'Laplace Log Likelihood: {lll.mean():.4f}')"
"lll = -np.sqrt(2) * delta / sigma_c - np.log(np.sqrt(2) * sigma_c)\n",
"print(f\"Laplace Log Likelihood: {lll.mean():.4f}\")"
]
},
{
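For reference, the final cell above implements the evaluation metric from the OSIC Pulmonary Fibrosis Progression competition, a clipped Laplace log likelihood. In formula form, matching the clipping in the diff:

sigma_c = max(sigma, 70)
delta   = min(|FVC_true - FVC_pred|, 1000)
LLL     = -sqrt(2) * delta / sigma_c - ln(sqrt(2) * sigma_c)

Larger (less negative) values are better; the floor on sigma caps the reward for overconfident predictions.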
(9 of the 11 changed files are not shown above.)
