From 6f24901bb1f5cf035cab25d6b885a3bae1d7d3e7 Mon Sep 17 00:00:00 2001 From: SimonRobertPike Date: Mon, 23 Sep 2024 15:00:51 +0100 Subject: [PATCH 1/6] updates for predict contributions --- .../distributions/distribution_utils.py | 105 ++++++++++++------ lightgbmlss/model.py | 1 + tests/test_model/test_model.py | 20 +++- 3 files changed, 92 insertions(+), 34 deletions(-) diff --git a/lightgbmlss/distributions/distribution_utils.py b/lightgbmlss/distributions/distribution_utils.py index a8e90ee..4882045 100644 --- a/lightgbmlss/distributions/distribution_utils.py +++ b/lightgbmlss/distributions/distribution_utils.py @@ -43,6 +43,7 @@ class DistributionClass: penalize_crossing: bool Whether to include a penalty term to discourage crossing of expectiles. Only used for Expectile distribution. """ + def __init__(self, distribution: torch.distributions.Distribution = None, univariate: bool = True, @@ -375,51 +376,89 @@ def predict_dist(self, Predictions. """ + kwargs = dict() + if pred_type == "contributions": + kwargs["pred_contrib"] = True + n_outputs_per_dist = data.shape[1] + 1 + else: + n_outputs_per_dist = 1 + predt = torch.tensor( - booster.predict(data, raw_score=True), + booster.predict(data, raw_score=True, **kwargs), dtype=torch.float32 - ).reshape(-1, self.n_dist_param) + ).reshape(-1, self.n_dist_param * n_outputs_per_dist) # Set init_score as starting point for each distributional parameter. init_score_pred = torch.tensor( - np.ones(shape=(data.shape[0], 1))*start_values, + np.ones(shape=(data.shape[0], 1)) * start_values, dtype=torch.float32 ) - # The predictions don't include the init_score specified in creating the train data. - # Hence, it needs to be added manually with the corresponding transform for each distributional parameter. - dist_params_predt = np.concatenate( - [ - response_fun( - predt[:, i].reshape(-1, 1) + init_score_pred[:, i].reshape(-1, 1)).numpy() - for i, (dist_param, response_fun) in enumerate(self.param_dict.items()) - ], - axis=1, - ) - dist_params_predt = pd.DataFrame(dist_params_predt) - dist_params_predt.columns = self.param_dict.keys() + if pred_type == "contributions": + CONST_COL = "Const" + COLUMN_LEVELS = ["distribution_arg", "FeatureContribution"] + + feature_columns = data.columns.tolist() + [CONST_COL] + contributions_predt = pd.DataFrame( + predt, + columns=pd.MultiIndex.from_product( + [self.distribution_arg_names, feature_columns], + names=COLUMN_LEVELS + ), + index=data.index, + ) + + init_score_pred_df = pd.DataFrame( + init_score_pred, + columns=pd.MultiIndex.from_product( + [self.distribution_arg_names, ["Const"]], + names=COLUMN_LEVELS + ), + index=data.index + ) + contributions_predt[init_score_pred_df.columns] = ( + contributions_predt[init_score_pred_df.columns] + init_score_pred_df + ) + # Cant include response function on individual feature contributions + return contributions_predt + else: + # The predictions don't include the init_score specified in creating the train data. + # Hence, it needs to be added manually with the corresponding transform for each distributional parameter. 
+ dist_params_predt = np.concatenate( + [ + response_fun( + predt[:, i].reshape(-1, 1) + init_score_pred[:, i].reshape(-1, 1)).numpy() + for i, (dist_param, response_fun) in enumerate(self.param_dict.items()) + ], + axis=1, + ) + dist_params_predt = pd.DataFrame(dist_params_predt) + dist_params_predt.columns = self.param_dict.keys() - # Draw samples from predicted response distribution - pred_samples_df = self.draw_samples(predt_params=dist_params_predt, - n_samples=n_samples, - seed=seed) + if pred_type == "parameters": + return dist_params_predt - if pred_type == "parameters": - return dist_params_predt + elif pred_type == "expectiles": + return dist_params_predt + else: - elif pred_type == "expectiles": - return dist_params_predt + # Draw samples from predicted response distribution + pred_samples_df = self.draw_samples(predt_params=dist_params_predt, + n_samples=n_samples, + seed=seed) - elif pred_type == "samples": - return pred_samples_df + if pred_type == "samples": + return pred_samples_df - elif pred_type == "quantiles": - # Calculate quantiles from predicted response distribution - pred_quant_df = pred_samples_df.quantile(quantiles, axis=1).T - pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))] - if self.discrete: - pred_quant_df = pred_quant_df.astype(int) - return pred_quant_df + elif pred_type == "quantiles": + # Calculate quantiles from predicted response distribution + pred_quant_df = pred_samples_df.quantile(quantiles, axis=1).T + pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))] + if self.discrete: + pred_quant_df = pred_quant_df.astype(int) + return pred_quant_df + else: + raise RuntimeError(f"{pred_type=} not supported") def compute_gradients_and_hessians(self, loss: torch.tensor, @@ -635,7 +674,7 @@ def dist_select(self, try: loss, params = dist_sel.calculate_start_values(target=target.reshape(-1, 1), max_iter=max_iter) fit_df = pd.DataFrame.from_dict( - {self.loss_fn: loss.reshape(-1,), + {self.loss_fn: loss.reshape(-1, ), "distribution": str(dist_name), "params": [params] } diff --git a/lightgbmlss/model.py b/lightgbmlss/model.py index 33896ce..9a006bc 100644 --- a/lightgbmlss/model.py +++ b/lightgbmlss/model.py @@ -452,6 +452,7 @@ def predict(self, - "quantiles" calculates the quantiles from the predicted distribution. - "parameters" returns the predicted distributional parameters. - "expectiles" returns the predicted expectiles. + - "contributions" returns constibutions of each feature and a constant by calling booster.predict(pred_contrib=True) n_samples : int Number of samples to draw from the predicted distribution. 
quantiles : List[float] diff --git a/tests/test_model/test_model.py b/tests/test_model/test_model.py index c1f12a1..4fe43ce 100644 --- a/tests/test_model/test_model.py +++ b/tests/test_model/test_model.py @@ -1,3 +1,5 @@ +import numpy as np + from lightgbmlss.model import * from lightgbmlss.distributions.Gaussian import * from lightgbmlss.distributions.Mixture import * @@ -6,6 +8,7 @@ from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data import pytest from pytest import approx +from lightgbmlss.utils import identity_fn @pytest.fixture @@ -109,7 +112,7 @@ def test_model_univ_train_eval(self, univariate_data, univariate_lgblss, univari # Assertions assert isinstance(lgblss.booster, lgb.Booster) - def test_model_hpo(self, univariate_data, univariate_lgblss,): + def test_model_hpo(self, univariate_data, univariate_lgblss, ): # Unpack dtrain, _, _, _ = univariate_data lgblss = univariate_lgblss @@ -155,6 +158,7 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para pred_params = lgblss.predict(X_test, pred_type="parameters") pred_samples = lgblss.predict(X_test, pred_type="samples", n_samples=n_samples) pred_quantiles = lgblss.predict(X_test, pred_type="quantiles", quantiles=quantiles) + pred_contributions = lgblss.predict(X_test, pred_type="contributions") # Assertions assert isinstance(pred_params, (pd.DataFrame, type(None))) @@ -173,6 +177,20 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para assert not np.isinf(pred_quantiles).any().any() assert pred_quantiles.shape[1] == len(quantiles) + assert isinstance(pred_contributions, (pd.DataFrame, type(None))) + assert not pred_contributions.isna().any().any() + assert not np.isinf(pred_contributions).any().any() + assert (pred_contributions.shape[1] == + lgblss.dist.n_dist_param * lgblss.dist.n_dist_param * (X_test.shape[1] + 1) + ) + + for key, func in lgblss.dist.param_dict.items(): + if func == identity_fn: + assert np.allclose( + pred_contributions.xs(key, level="distribution_arg", axis=1).sum(axis=1), + pred_params[key], atol=1e-5 + ) + def test_model_plot(self, univariate_data, univariate_lgblss, univariate_params): # Unpack dtrain, dtest, _, X_test = univariate_data From 023211def7b32bf4aa85c73eb314d0b11270cdb2 Mon Sep 17 00:00:00 2001 From: SimonRobertPike Date: Tue, 24 Sep 2024 08:02:20 +0100 Subject: [PATCH 2/6] update the test to check all response functions --- tests/test_model/test_model.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_model/test_model.py b/tests/test_model/test_model.py index 4fe43ce..b2a26ba 100644 --- a/tests/test_model/test_model.py +++ b/tests/test_model/test_model.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd from lightgbmlss.model import * from lightgbmlss.distributions.Gaussian import * @@ -184,12 +185,17 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para lgblss.dist.n_dist_param * lgblss.dist.n_dist_param * (X_test.shape[1] + 1) ) - for key, func in lgblss.dist.param_dict.items(): - if func == identity_fn: - assert np.allclose( - pred_contributions.xs(key, level="distribution_arg", axis=1).sum(axis=1), - pred_params[key], atol=1e-5 - ) + for key, response_func in lgblss.dist.param_dict.items(): + pred_contributions_combined = ( + pd.Series(response_func( + torch.tensor( + pred_contributions.xs(key, level="distribution_arg", axis=1).sum(axis=1).values) + ))) + assert np.allclose( + pred_contributions_combined, + pred_params[key], 
atol=1e-5 + ) + def test_model_plot(self, univariate_data, univariate_lgblss, univariate_params): # Unpack From 98db0100d50a9a61321f73f634331181084011eb Mon Sep 17 00:00:00 2001 From: SimonRobertPike Date: Thu, 26 Sep 2024 11:08:29 +0100 Subject: [PATCH 3/6] add an example --- docs/examples/Predict_Contributions.ipynb | 1153 +++++++++++++++++++++ 1 file changed, 1153 insertions(+) create mode 100644 docs/examples/Predict_Contributions.ipynb diff --git a/docs/examples/Predict_Contributions.ipynb b/docs/examples/Predict_Contributions.ipynb new file mode 100644 index 0000000..54cac6e --- /dev/null +++ b/docs/examples/Predict_Contributions.ipynb @@ -0,0 +1,1153 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Show how to extract prediction contributions for each distribution parameter\n", + "\n", + "This example shows how to get the contribution of every feature for each distributional parameter for a given data set. This allows similar inferences as one might get from SHAP but comes from lightGBM's internal workings. We can use output for example to get for a given prediction which features are causing the most impact to a given distributional parameter.\n", + "\n", + "These contributions are created before the response function is applied. As such in the case of the identity function, for a given row of data the sum of the contributions should equal the parameter value.\n" + ], + "id": "bf95ab4267d5a34" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Imports\n", + "\n", + "First, we import necessary functions. " + ], + "id": "bbea43740b87eb" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:33:07.019505Z", + "start_time": "2024-09-26T09:33:01.235342Z" + } + }, + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "from lightgbmlss.model import *\n", + "from lightgbmlss.distributions.Gaussian import *\n", + "from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data\n", + "from scipy.stats import norm\n", + "\n", + "import plotnine\n", + "from plotnine import *\n", + "\n", + "plotnine.options.figure_size = (12, 8)" + ], + "id": "b5f2d07ce70bb24b", + "outputs": [], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Data", + "id": "bd7bba77a5e0fa2f" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:33:07.067189Z", + "start_time": "2024-09-26T09:33:07.019505Z" + } + }, + "cell_type": "code", + "source": [ + "train, test = load_simulated_gaussian_data()\n", + "\n", + "X_train, y_train = train.filter(regex=\"x\"), train[\"y\"].values\n", + "X_test, y_test = test.filter(regex=\"x\"), test[\"y\"].values\n", + "\n", + "dtrain = lgb.Dataset(X_train, label=y_train)" + ], + "id": "1062b4b851a12bc9", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Get a Trained Model\n", + "\n", + "As this example is about th uses of a trained model, we wont do any hyper-parameter searching. We will also use a Gaussian distribution as the response function of the loc parameter is the identity function, this will allow us to more easily compare the outputs of a standard parameter prediction to a contributions prediction." 
+ ], + "id": "170feafe1dccf85c" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:58:18.921694Z", + "start_time": "2024-09-26T09:57:36.453028Z" + } + }, + "cell_type": "code", + "source": [ + "lgblss = LightGBMLSS(\n", + " Gaussian()\n", + ")\n", + "lgblss.train(\n", + " params=dict(),\n", + " train_set=dtrain\n", + ")\n", + "\n", + "param_dict = {\n", + " \"eta\": [\"float\", {\"low\": 1e-5, \"high\": 1, \"log\": True}],\n", + " \"max_depth\": [\"int\", {\"low\": 1, \"high\": 10, \"log\": False}],\n", + " \"num_leaves\": [\"int\", {\"low\": 255, \"high\": 255, \"log\": False}], # set to constant for this example\n", + " \"min_data_in_leaf\": [\"int\", {\"low\": 20, \"high\": 20, \"log\": False}], # set to constant for this example\n", + " \"min_gain_to_split\": [\"float\", {\"low\": 1e-8, \"high\": 40, \"log\": False}],\n", + " \"min_sum_hessian_in_leaf\": [\"float\", {\"low\": 1e-8, \"high\": 500, \"log\": True}],\n", + " \"subsample\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n", + " \"feature_fraction\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n", + " \"boosting\": [\"categorical\", [\"gbdt\"]],\n", + "}\n", + "\n", + "np.random.seed(123)\n", + "opt_param = lgblss.hyper_opt(param_dict,\n", + " dtrain,\n", + " num_boost_round=100, # Number of boosting iterations.\n", + " nfold=5, # Number of cv-folds.\n", + " early_stopping_rounds=20, # Number of early-stopping rounds\n", + " max_minutes=10, # Time budget in minutes, i.e., stop study after the given number of minutes.\n", + " n_trials=30 , # The number of trials. If this argument is set to None, there is no limitation on the number of trials.\n", + " silence=True, # Controls the verbosity of the trail, i.e., user can silence the outputs of the trail.\n", + " seed=123, # Seed used to generate cv-folds.\n", + " hp_seed=123 # Seed for random number generator used in the Bayesian hyperparameter search.\n", + " )\n", + "\n", + "np.random.seed(123)\n", + "\n", + "opt_params = opt_param.copy()\n", + "n_rounds = opt_params[\"opt_rounds\"]\n", + "del opt_params[\"opt_rounds\"]\n", + "\n", + "# Train Model with optimized hyperparameters\n", + "lgblss.train(opt_params,\n", + " dtrain,\n", + " num_boost_round=n_rounds\n", + " )\n" + ], + "id": "f45c868160f1f08b", + "outputs": [ + { + "data": { + "text/plain": [ + " 0%| | 0/30 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 28 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Show contributions for each feature for scale parameter", + "id": "eaf2ad3ecc736152" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:58:55.637299Z", + "start_time": "2024-09-26T09:58:55.621654Z" + } + }, + "cell_type": "code", + "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\")\n", + "id": "c5453f7e5a378096", + "outputs": [ + { + "data": { + "text/plain": [ + "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 \\\n", + "0 0.410550 0.002106 0.0 0.0 0.000034 \n", + "1 0.411261 0.000684 0.0 0.0 -0.000340 \n", + "2 -0.597674 0.002106 0.0 0.0 -0.000340 \n", + "3 0.848748 0.002832 0.0 0.0 0.000034 \n", + "4 0.414522 0.001565 0.0 0.0 0.000866 \n", + "... ... ... ... ... ... \n", + "2995 0.411230 0.002832 0.0 0.0 -0.000340 \n", + "2996 0.380649 0.002106 0.0 0.0 -0.000340 \n", + "2997 -0.597582 0.001647 0.0 0.0 0.000034 \n", + "2998 -0.607346 -0.001425 0.0 0.0 -0.001143 \n", + "2999 0.410550 0.002106 0.0 0.0 0.000034 \n", + "\n", + "FeatureContribution x_noise5 x_noise6 x_noise7 x_noise8 x_noise9 \\\n", + "0 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", + "1 0.000197 0.004813 -0.000127 0.0 -0.000608 \n", + "2 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", + "3 0.000197 0.001399 -0.000127 0.0 -0.000608 \n", + "4 0.000123 0.002716 -0.004167 0.0 0.053916 \n", + "... ... ... ... ... ... \n", + "2995 0.000197 0.002135 -0.000127 0.0 -0.000608 \n", + "2996 0.000197 0.004400 -0.000127 0.0 -0.000608 \n", + "2997 0.000197 -0.004547 -0.000127 0.0 -0.000700 \n", + "2998 0.000887 0.002013 0.003194 0.0 -0.000029 \n", + "2999 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", + "\n", + "FeatureContribution x_noise10 Const \n", + "0 -0.000503 0.653589 \n", + "1 -0.000129 0.653589 \n", + "2 -0.000129 0.653589 \n", + "3 0.001529 0.653589 \n", + "4 0.001894 0.653589 \n", + "... ... ... \n", + "2995 0.000432 0.653589 \n", + "2996 0.002376 0.653589 \n", + "2997 -0.000893 0.653589 \n", + "2998 -0.004395 0.653589 \n", + "2999 -0.000503 0.653589 \n", + "\n", + "[3000 rows x 12 columns]" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 29 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Show Mean Feature Impact for Data set", + "id": "394e64d247168fa0" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:02:54.858851Z", + "start_time": "2024-09-26T10:02:54.838744Z" + } + }, + "cell_type": "code", + "source": [ + "sum_of_contributions_column = \"SumOfContributions\"\n", + "mean_parameter_contribution = pred_param_contributions.abs().mean().unstack(\"distribution_arg\")\n", + "mean_parameter_contribution[sum_of_contributions_column] = mean_parameter_contribution.sum(1)\n", + "\n", + "mean_parameter_contribution.sort_values(sum_of_contributions_column, ascending=False).drop(columns=sum_of_contributions_column)\n" + ], + "id": "54d4970cf1957735", + "outputs": [ + { + "data": { + "text/plain": [ + "distribution_arg loc scale\n", + "FeatureContribution \n", + "Const 9.97998 0.653589\n", + "x_true 0.00000 0.591846\n", + "x_noise6 0.00000 0.004865\n", + "x_noise7 0.00000 0.004410\n", + "x_noise1 0.00000 0.003991\n", + "x_noise10 0.00000 0.002688\n", + "x_noise9 0.00000 0.002582\n", + "x_noise4 0.00000 0.001666\n", + "x_noise5 0.00000 0.000585\n", + "x_noise2 0.00000 0.000000\n", + "x_noise3 0.00000 0.000000\n", + "x_noise8 0.00000 0.000000" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 36 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Get correlation between contributions for the scale parameter ", + "id": "f7c73f303f04d4ff" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "", + "id": "8d5dc9e448d5c322" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:07:21.265976Z", + "start_time": "2024-09-26T10:07:21.249801Z" + } + }, + "cell_type": "code", + "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\").corr().dropna(how=\"all\").dropna(axis=1,how=\"all\")\n", + "id": "f331d8603042908", + "outputs": [ + { + "data": { + "text/plain": [ + "FeatureContribution x_true x_noise1 x_noise4 x_noise5 x_noise6 \\\n", + "FeatureContribution \n", + "x_true 1.000000 0.007743 -0.001227 -0.047812 -0.021568 \n", + "x_noise1 0.007743 1.000000 -0.006627 -0.022206 0.136683 \n", + "x_noise4 -0.001227 -0.006627 1.000000 -0.015965 -0.030661 \n", + "x_noise5 -0.047812 -0.022206 -0.015965 1.000000 0.006217 \n", + "x_noise6 -0.021568 0.136683 -0.030661 0.006217 1.000000 \n", + "x_noise7 0.015344 0.002144 0.474505 0.021826 0.029863 \n", + "x_noise9 0.024361 -0.006972 0.013089 0.016001 0.009558 \n", + "x_noise10 0.035479 0.012114 -0.035713 -0.001433 0.028450 \n", + "\n", + "FeatureContribution x_noise7 x_noise9 x_noise10 \n", + "FeatureContribution \n", + "x_true 0.015344 0.024361 0.035479 \n", + "x_noise1 0.002144 -0.006972 0.012114 \n", + "x_noise4 0.474505 0.013089 -0.035713 \n", + "x_noise5 0.021826 0.016001 -0.001433 \n", + "x_noise6 0.029863 0.009558 0.028450 \n", + "x_noise7 1.000000 0.023556 -0.015318 \n", + "x_noise9 0.023556 1.000000 -0.030408 \n", + "x_noise10 -0.015318 -0.030408 1.000000 " + ], + "text/html": [ + "
" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 44 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "ae0a0247ad688b42" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 9877944e8f3725ea7cdb1c611857ee2e1b125047 Mon Sep 17 00:00:00 2001 From: SimonRobertPike Date: Thu, 26 Sep 2024 11:34:36 +0100 Subject: [PATCH 4/6] add an example --- docs/examples/Predict_Contributions.ipynb | 493 +++++++++++----------- 1 file changed, 249 insertions(+), 244 deletions(-) diff --git a/docs/examples/Predict_Contributions.ipynb b/docs/examples/Predict_Contributions.ipynb index 54cac6e..57ecfec 100644 --- a/docs/examples/Predict_Contributions.ipynb +++ b/docs/examples/Predict_Contributions.ipynb @@ -25,8 +25,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T09:33:07.019505Z", - "start_time": "2024-09-26T09:33:01.235342Z" + "end_time": "2024-09-26T10:08:32.499517Z", + "start_time": "2024-09-26T10:08:32.484942Z" } }, "cell_type": "code", @@ -45,7 +45,7 @@ ], "id": "b5f2d07ce70bb24b", "outputs": [], - "execution_count": 1 + "execution_count": 45 }, { "metadata": {}, @@ -56,8 +56,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T09:33:07.067189Z", - "start_time": "2024-09-26T09:33:07.019505Z" + "end_time": "2024-09-26T10:08:32.595603Z", + "start_time": "2024-09-26T10:08:32.563920Z" } }, "cell_type": "code", @@ -71,7 +71,7 @@ ], "id": "1062b4b851a12bc9", "outputs": [], - "execution_count": 2 + "execution_count": 46 }, { "metadata": {}, @@ -86,8 +86,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T09:58:18.921694Z", - "start_time": "2024-09-26T09:57:36.453028Z" + "end_time": "2024-09-26T10:09:16.586326Z", + "start_time": "2024-09-26T10:08:32.595603Z" } }, "cell_type": "code", @@ -147,7 +147,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "fe4d614b202c4516931ba7cd69d2c733" + "model_id": "45def0cbae7345c2af90d41ce5c331b0" } }, "metadata": {}, @@ -161,7 +161,7 @@ "Hyper-Parameter Optimization successfully finished.\n", " Number of finished trials: 30\n", " Best trial:\n", - " Value: 2.0839056900194977\n", + " Value: 2.0839106241730967\n", " Params: \n", " eta: 0.042322345196562056\n", " max_depth: 3\n", @@ -176,7 +176,7 @@ ] } ], - "execution_count": 25 + "execution_count": 47 }, { "metadata": {}, @@ -187,8 +187,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T09:58:28.354001Z", - "start_time": "2024-09-26T09:58:28.322325Z" + "end_time": "2024-09-26T10:09:16.618477Z", + "start_time": "2024-09-26T10:09:16.586326Z" } }, "cell_type": "code", @@ -198,7 +198,7 @@ ], "id": "c0bab6ad5807cd8d", "outputs": [], - "execution_count": 26 + "execution_count": 48 }, { "metadata": {}, @@ -209,8 +209,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T09:58:30.937621Z", - "start_time": "2024-09-26T09:58:30.930606Z" + "end_time": "2024-09-26T10:09:16.639879Z", + "start_time": "2024-09-26T10:09:16.618477Z" } }, "cell_type": "code", @@ -241,7 +241,7 @@ ] } ], - "execution_count": 27 + "execution_count": 49 
}, { "metadata": {}, @@ -252,8 +252,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T09:58:40.765411Z", - "start_time": "2024-09-26T09:58:40.738780Z" + "end_time": "2024-09-26T10:09:16.672598Z", + "start_time": "2024-09-26T10:09:16.642194Z" } }, "cell_type": "code", @@ -290,17 +290,17 @@ "2999 0.0 0.0 0.0 0.0 0.0 \n", "\n", "FeatureContribution Const \n", - "0 9.979979 \n", - "1 9.979979 \n", - "2 9.979979 \n", - "3 9.979979 \n", - "4 9.979979 \n", + "0 9.979578 \n", + "1 9.979578 \n", + "2 9.979578 \n", + "3 9.979578 \n", + "4 9.979578 \n", "... ... \n", - "2995 9.979979 \n", - "2996 9.979979 \n", - "2997 9.979979 \n", - "2998 9.979979 \n", - "2999 9.979979 \n", + "2995 9.979578 \n", + "2996 9.979578 \n", + "2997 9.979578 \n", + "2998 9.979578 \n", + "2999 9.979578 \n", "\n", "[3000 rows x 12 columns]" ], @@ -351,7 +351,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 1\n", @@ -366,7 +366,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 2\n", @@ -381,7 +381,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 3\n", @@ -396,7 +396,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 4\n", @@ -411,7 +411,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " ...\n", @@ -441,7 +441,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 2996\n", @@ -456,7 +456,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 2997\n", @@ -471,7 +471,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 2998\n", @@ -486,7 +486,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", " 2999\n", @@ -501,7 +501,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 9.979979\n", + " 9.979578\n", " \n", " \n", "\n", @@ -509,12 +509,12 @@ "" ] }, - "execution_count": 28, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 28 + "execution_count": 50 }, { "metadata": {}, @@ -525,8 +525,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T09:58:55.637299Z", - "start_time": "2024-09-26T09:58:55.621654Z" + "end_time": "2024-09-26T10:09:16.706240Z", + "start_time": "2024-09-26T10:09:16.673598Z" } }, "cell_type": "code", @@ -537,43 +537,43 @@ "data": { "text/plain": [ "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 \\\n", - "0 0.410550 0.002106 0.0 0.0 0.000034 \n", - "1 0.411261 0.000684 0.0 0.0 -0.000340 \n", - "2 -0.597674 0.002106 0.0 0.0 -0.000340 \n", - "3 0.848748 0.002832 0.0 0.0 0.000034 \n", - "4 0.414522 0.001565 0.0 0.0 0.000866 \n", + "0 0.410556 0.002107 0.0 0.0 0.000034 \n", + "1 0.411267 0.000683 0.0 0.0 -0.000340 \n", + "2 -0.597710 0.002107 0.0 0.0 -0.000340 \n", + "3 0.848812 0.002835 0.0 0.0 0.000034 \n", + "4 0.414533 0.001566 0.0 0.0 0.000867 \n", "... ... ... ... ... ... 
\n", - "2995 0.411230 0.002832 0.0 0.0 -0.000340 \n", - "2996 0.380649 0.002106 0.0 0.0 -0.000340 \n", - "2997 -0.597582 0.001647 0.0 0.0 0.000034 \n", - "2998 -0.607346 -0.001425 0.0 0.0 -0.001143 \n", - "2999 0.410550 0.002106 0.0 0.0 0.000034 \n", + "2995 0.411235 0.002835 0.0 0.0 -0.000340 \n", + "2996 0.380668 0.002107 0.0 0.0 -0.000340 \n", + "2997 -0.597620 0.001648 0.0 0.0 0.000034 \n", + "2998 -0.607374 -0.001427 0.0 0.0 -0.001144 \n", + "2999 0.410556 0.002107 0.0 0.0 0.000034 \n", "\n", "FeatureContribution x_noise5 x_noise6 x_noise7 x_noise8 x_noise9 \\\n", - "0 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", - "1 0.000197 0.004813 -0.000127 0.0 -0.000608 \n", - "2 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", - "3 0.000197 0.001399 -0.000127 0.0 -0.000608 \n", - "4 0.000123 0.002716 -0.004167 0.0 0.053916 \n", + "0 0.000197 0.004104 -0.000126 0.0 -0.000608 \n", + "1 0.000197 0.004816 -0.000126 0.0 -0.000608 \n", + "2 0.000197 0.004104 -0.000126 0.0 -0.000608 \n", + "3 0.000197 0.001399 -0.000126 0.0 -0.000608 \n", + "4 0.000123 0.002717 -0.004173 0.0 0.053938 \n", "... ... ... ... ... ... \n", - "2995 0.000197 0.002135 -0.000127 0.0 -0.000608 \n", - "2996 0.000197 0.004400 -0.000127 0.0 -0.000608 \n", - "2997 0.000197 -0.004547 -0.000127 0.0 -0.000700 \n", - "2998 0.000887 0.002013 0.003194 0.0 -0.000029 \n", - "2999 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", + "2995 0.000197 0.002135 -0.000126 0.0 -0.000608 \n", + "2996 0.000197 0.004402 -0.000126 0.0 -0.000608 \n", + "2997 0.000197 -0.004548 -0.000126 0.0 -0.000700 \n", + "2998 0.000888 0.002017 0.003200 0.0 -0.000029 \n", + "2999 0.000197 0.004104 -0.000126 0.0 -0.000608 \n", "\n", "FeatureContribution x_noise10 Const \n", - "0 -0.000503 0.653589 \n", - "1 -0.000129 0.653589 \n", - "2 -0.000129 0.653589 \n", - "3 0.001529 0.653589 \n", - "4 0.001894 0.653589 \n", + "0 -0.000503 0.653625 \n", + "1 -0.000130 0.653625 \n", + "2 -0.000130 0.653625 \n", + "3 0.001530 0.653625 \n", + "4 0.001895 0.653625 \n", "... ... ... 
\n", - "2995 0.000432 0.653589 \n", - "2996 0.002376 0.653589 \n", - "2997 -0.000893 0.653589 \n", - "2998 -0.004395 0.653589 \n", - "2999 -0.000503 0.653589 \n", + "2995 0.000432 0.653625 \n", + "2996 0.002378 0.653625 \n", + "2997 -0.000892 0.653625 \n", + "2998 -0.004399 0.653625 \n", + "2999 -0.000503 0.653625 \n", "\n", "[3000 rows x 12 columns]" ], @@ -613,78 +613,78 @@ " \n", " \n", " 0\n", - " 0.410550\n", - " 0.002106\n", + " 0.410556\n", + " 0.002107\n", " 0.0\n", " 0.0\n", " 0.000034\n", " 0.000197\n", - " 0.004102\n", - " -0.000127\n", + " 0.004104\n", + " -0.000126\n", " 0.0\n", " -0.000608\n", " -0.000503\n", - " 0.653589\n", + " 0.653625\n", " \n", " \n", " 1\n", - " 0.411261\n", - " 0.000684\n", + " 0.411267\n", + " 0.000683\n", " 0.0\n", " 0.0\n", " -0.000340\n", " 0.000197\n", - " 0.004813\n", - " -0.000127\n", + " 0.004816\n", + " -0.000126\n", " 0.0\n", " -0.000608\n", - " -0.000129\n", - " 0.653589\n", + " -0.000130\n", + " 0.653625\n", " \n", " \n", " 2\n", - " -0.597674\n", - " 0.002106\n", + " -0.597710\n", + " 0.002107\n", " 0.0\n", " 0.0\n", " -0.000340\n", " 0.000197\n", - " 0.004102\n", - " -0.000127\n", + " 0.004104\n", + " -0.000126\n", " 0.0\n", " -0.000608\n", - " -0.000129\n", - " 0.653589\n", + " -0.000130\n", + " 0.653625\n", " \n", " \n", " 3\n", - " 0.848748\n", - " 0.002832\n", + " 0.848812\n", + " 0.002835\n", " 0.0\n", " 0.0\n", " 0.000034\n", " 0.000197\n", " 0.001399\n", - " -0.000127\n", + " -0.000126\n", " 0.0\n", " -0.000608\n", - " 0.001529\n", - " 0.653589\n", + " 0.001530\n", + " 0.653625\n", " \n", " \n", " 4\n", - " 0.414522\n", - " 0.001565\n", + " 0.414533\n", + " 0.001566\n", " 0.0\n", " 0.0\n", - " 0.000866\n", + " 0.000867\n", " 0.000123\n", - " 0.002716\n", - " -0.004167\n", + " 0.002717\n", + " -0.004173\n", " 0.0\n", - " 0.053916\n", - " 0.001894\n", - " 0.653589\n", + " 0.053938\n", + " 0.001895\n", + " 0.653625\n", " \n", " \n", " ...\n", @@ -703,78 +703,78 @@ " \n", " \n", " 2995\n", - " 0.411230\n", - " 0.002832\n", + " 0.411235\n", + " 0.002835\n", " 0.0\n", " 0.0\n", " -0.000340\n", " 0.000197\n", " 0.002135\n", - " -0.000127\n", + " -0.000126\n", " 0.0\n", " -0.000608\n", " 0.000432\n", - " 0.653589\n", + " 0.653625\n", " \n", " \n", " 2996\n", - " 0.380649\n", - " 0.002106\n", + " 0.380668\n", + " 0.002107\n", " 0.0\n", " 0.0\n", " -0.000340\n", " 0.000197\n", - " 0.004400\n", - " -0.000127\n", + " 0.004402\n", + " -0.000126\n", " 0.0\n", " -0.000608\n", - " 0.002376\n", - " 0.653589\n", + " 0.002378\n", + " 0.653625\n", " \n", " \n", " 2997\n", - " -0.597582\n", - " 0.001647\n", + " -0.597620\n", + " 0.001648\n", " 0.0\n", " 0.0\n", " 0.000034\n", " 0.000197\n", - " -0.004547\n", - " -0.000127\n", + " -0.004548\n", + " -0.000126\n", " 0.0\n", " -0.000700\n", - " -0.000893\n", - " 0.653589\n", + " -0.000892\n", + " 0.653625\n", " \n", " \n", " 2998\n", - " -0.607346\n", - " -0.001425\n", + " -0.607374\n", + " -0.001427\n", " 0.0\n", " 0.0\n", - " -0.001143\n", - " 0.000887\n", - " 0.002013\n", - " 0.003194\n", + " -0.001144\n", + " 0.000888\n", + " 0.002017\n", + " 0.003200\n", " 0.0\n", " -0.000029\n", - " -0.004395\n", - " 0.653589\n", + " -0.004399\n", + " 0.653625\n", " \n", " \n", " 2999\n", - " 0.410550\n", - " 0.002106\n", + " 0.410556\n", + " 0.002107\n", " 0.0\n", " 0.0\n", " 0.000034\n", " 0.000197\n", - " 0.004102\n", - " -0.000127\n", + " 0.004104\n", + " -0.000126\n", " 0.0\n", " -0.000608\n", " -0.000503\n", - " 0.653589\n", + " 0.653625\n", " \n", " \n", "\n", @@ -782,12 +782,12 @@ "" ] }, - "execution_count": 
29, + "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 29 + "execution_count": 51 }, { "metadata": {}, @@ -798,8 +798,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T10:02:54.858851Z", - "start_time": "2024-09-26T10:02:54.838744Z" + "end_time": "2024-09-26T10:09:16.722325Z", + "start_time": "2024-09-26T10:09:16.706781Z" } }, "cell_type": "code", @@ -815,20 +815,20 @@ { "data": { "text/plain": [ - "distribution_arg loc scale\n", - "FeatureContribution \n", - "Const 9.97998 0.653589\n", - "x_true 0.00000 0.591846\n", - "x_noise6 0.00000 0.004865\n", - "x_noise7 0.00000 0.004410\n", - "x_noise1 0.00000 0.003991\n", - "x_noise10 0.00000 0.002688\n", - "x_noise9 0.00000 0.002582\n", - "x_noise4 0.00000 0.001666\n", - "x_noise5 0.00000 0.000585\n", - "x_noise2 0.00000 0.000000\n", - "x_noise3 0.00000 0.000000\n", - "x_noise8 0.00000 0.000000" + "distribution_arg loc scale\n", + "FeatureContribution \n", + "Const 9.979577 0.653625\n", + "x_true 0.000000 0.591884\n", + "x_noise6 0.000000 0.004868\n", + "x_noise7 0.000000 0.004415\n", + "x_noise1 0.000000 0.003994\n", + "x_noise10 0.000000 0.002689\n", + "x_noise9 0.000000 0.002583\n", + "x_noise4 0.000000 0.001668\n", + "x_noise5 0.000000 0.000585\n", + "x_noise2 0.000000 0.000000\n", + "x_noise3 0.000000 0.000000\n", + "x_noise8 0.000000 0.000000" ], "text/html": [ "
\n", @@ -861,62 +861,62 @@ " \n", " \n", " Const\n", - " 9.97998\n", - " 0.653589\n", + " 9.979577\n", + " 0.653625\n", " \n", " \n", " x_true\n", - " 0.00000\n", - " 0.591846\n", + " 0.000000\n", + " 0.591884\n", " \n", " \n", " x_noise6\n", - " 0.00000\n", - " 0.004865\n", + " 0.000000\n", + " 0.004868\n", " \n", " \n", " x_noise7\n", - " 0.00000\n", - " 0.004410\n", + " 0.000000\n", + " 0.004415\n", " \n", " \n", " x_noise1\n", - " 0.00000\n", - " 0.003991\n", + " 0.000000\n", + " 0.003994\n", " \n", " \n", " x_noise10\n", - " 0.00000\n", - " 0.002688\n", + " 0.000000\n", + " 0.002689\n", " \n", " \n", " x_noise9\n", - " 0.00000\n", - " 0.002582\n", + " 0.000000\n", + " 0.002583\n", " \n", " \n", " x_noise4\n", - " 0.00000\n", - " 0.001666\n", + " 0.000000\n", + " 0.001668\n", " \n", " \n", " x_noise5\n", - " 0.00000\n", + " 0.000000\n", " 0.000585\n", " \n", " \n", " x_noise2\n", - " 0.00000\n", + " 0.000000\n", " 0.000000\n", " \n", " \n", " x_noise3\n", - " 0.00000\n", + " 0.000000\n", " 0.000000\n", " \n", " \n", " x_noise8\n", - " 0.00000\n", + " 0.000000\n", " 0.000000\n", " \n", " \n", @@ -924,12 +924,12 @@ "
" ] }, - "execution_count": 36, + "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 36 + "execution_count": 52 }, { "metadata": {}, @@ -946,8 +946,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-09-26T10:07:21.265976Z", - "start_time": "2024-09-26T10:07:21.249801Z" + "end_time": "2024-09-26T10:09:16.737953Z", + "start_time": "2024-09-26T10:09:16.722325Z" } }, "cell_type": "code", @@ -959,25 +959,25 @@ "text/plain": [ "FeatureContribution x_true x_noise1 x_noise4 x_noise5 x_noise6 \\\n", "FeatureContribution \n", - "x_true 1.000000 0.007743 -0.001227 -0.047812 -0.021568 \n", - "x_noise1 0.007743 1.000000 -0.006627 -0.022206 0.136683 \n", - "x_noise4 -0.001227 -0.006627 1.000000 -0.015965 -0.030661 \n", - "x_noise5 -0.047812 -0.022206 -0.015965 1.000000 0.006217 \n", - "x_noise6 -0.021568 0.136683 -0.030661 0.006217 1.000000 \n", - "x_noise7 0.015344 0.002144 0.474505 0.021826 0.029863 \n", - "x_noise9 0.024361 -0.006972 0.013089 0.016001 0.009558 \n", - "x_noise10 0.035479 0.012114 -0.035713 -0.001433 0.028450 \n", + "x_true 1.000000 0.007743 -0.001231 -0.047812 -0.021563 \n", + "x_noise1 0.007743 1.000000 -0.006635 -0.022209 0.136772 \n", + "x_noise4 -0.001231 -0.006635 1.000000 -0.015972 -0.030669 \n", + "x_noise5 -0.047812 -0.022209 -0.015972 1.000000 0.006214 \n", + "x_noise6 -0.021563 0.136772 -0.030669 0.006214 1.000000 \n", + "x_noise7 0.015347 0.002129 0.474525 0.021844 0.029845 \n", + "x_noise9 0.024365 -0.006972 0.013082 0.015998 0.009551 \n", + "x_noise10 0.035477 0.012110 -0.035711 -0.001439 0.028450 \n", "\n", "FeatureContribution x_noise7 x_noise9 x_noise10 \n", "FeatureContribution \n", - "x_true 0.015344 0.024361 0.035479 \n", - "x_noise1 0.002144 -0.006972 0.012114 \n", - "x_noise4 0.474505 0.013089 -0.035713 \n", - "x_noise5 0.021826 0.016001 -0.001433 \n", - "x_noise6 0.029863 0.009558 0.028450 \n", - "x_noise7 1.000000 0.023556 -0.015318 \n", - "x_noise9 0.023556 1.000000 -0.030408 \n", - "x_noise10 -0.015318 -0.030408 1.000000 " + "x_true 0.015347 0.024365 0.035477 \n", + "x_noise1 0.002129 -0.006972 0.012110 \n", + "x_noise4 0.474525 0.013082 -0.035711 \n", + "x_noise5 0.021844 0.015998 -0.001439 \n", + "x_noise6 0.029845 0.009551 0.028450 \n", + "x_noise7 1.000000 0.023553 -0.015334 \n", + "x_noise9 0.023553 1.000000 -0.030410 \n", + "x_noise10 -0.015334 -0.030410 1.000000 " ], "text/html": [ "
\n", @@ -1024,88 +1024,88 @@ " x_true\n", " 1.000000\n", " 0.007743\n", - " -0.001227\n", + " -0.001231\n", " -0.047812\n", - " -0.021568\n", - " 0.015344\n", - " 0.024361\n", - " 0.035479\n", + " -0.021563\n", + " 0.015347\n", + " 0.024365\n", + " 0.035477\n", " \n", " \n", " x_noise1\n", " 0.007743\n", " 1.000000\n", - " -0.006627\n", - " -0.022206\n", - " 0.136683\n", - " 0.002144\n", + " -0.006635\n", + " -0.022209\n", + " 0.136772\n", + " 0.002129\n", " -0.006972\n", - " 0.012114\n", + " 0.012110\n", " \n", " \n", " x_noise4\n", - " -0.001227\n", - " -0.006627\n", + " -0.001231\n", + " -0.006635\n", " 1.000000\n", - " -0.015965\n", - " -0.030661\n", - " 0.474505\n", - " 0.013089\n", - " -0.035713\n", + " -0.015972\n", + " -0.030669\n", + " 0.474525\n", + " 0.013082\n", + " -0.035711\n", " \n", " \n", " x_noise5\n", " -0.047812\n", - " -0.022206\n", - " -0.015965\n", + " -0.022209\n", + " -0.015972\n", " 1.000000\n", - " 0.006217\n", - " 0.021826\n", - " 0.016001\n", - " -0.001433\n", + " 0.006214\n", + " 0.021844\n", + " 0.015998\n", + " -0.001439\n", " \n", " \n", " x_noise6\n", - " -0.021568\n", - " 0.136683\n", - " -0.030661\n", - " 0.006217\n", + " -0.021563\n", + " 0.136772\n", + " -0.030669\n", + " 0.006214\n", " 1.000000\n", - " 0.029863\n", - " 0.009558\n", + " 0.029845\n", + " 0.009551\n", " 0.028450\n", " \n", " \n", " x_noise7\n", - " 0.015344\n", - " 0.002144\n", - " 0.474505\n", - " 0.021826\n", - " 0.029863\n", + " 0.015347\n", + " 0.002129\n", + " 0.474525\n", + " 0.021844\n", + " 0.029845\n", " 1.000000\n", - " 0.023556\n", - " -0.015318\n", + " 0.023553\n", + " -0.015334\n", " \n", " \n", " x_noise9\n", - " 0.024361\n", + " 0.024365\n", " -0.006972\n", - " 0.013089\n", - " 0.016001\n", - " 0.009558\n", - " 0.023556\n", + " 0.013082\n", + " 0.015998\n", + " 0.009551\n", + " 0.023553\n", " 1.000000\n", - " -0.030408\n", + " -0.030410\n", " \n", " \n", " x_noise10\n", - " 0.035479\n", - " 0.012114\n", - " -0.035713\n", - " -0.001433\n", + " 0.035477\n", + " 0.012110\n", + " -0.035711\n", + " -0.001439\n", " 0.028450\n", - " -0.015318\n", - " -0.030408\n", + " -0.015334\n", + " -0.030410\n", " 1.000000\n", " \n", " \n", @@ -1113,20 +1113,25 @@ "
" ] }, - "execution_count": 44, + "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], - "execution_count": 44 + "execution_count": 53 }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:09:16.753632Z", + "start_time": "2024-09-26T10:09:16.738083Z" + } + }, "cell_type": "code", - "outputs": [], - "execution_count": null, "source": "", - "id": "ae0a0247ad688b42" + "id": "ae0a0247ad688b42", + "outputs": [], + "execution_count": 53 } ], "metadata": { From b393dcead7a1cacded08607b8534e5e30e78d6aa Mon Sep 17 00:00:00 2001 From: SimonRobertPike Date: Thu, 3 Oct 2024 09:52:21 +0100 Subject: [PATCH 5/6] change name of multi index to parameters to allign with pred_type argument --- lightgbmlss/distributions/distribution_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightgbmlss/distributions/distribution_utils.py b/lightgbmlss/distributions/distribution_utils.py index 4882045..ee02590 100644 --- a/lightgbmlss/distributions/distribution_utils.py +++ b/lightgbmlss/distributions/distribution_utils.py @@ -396,7 +396,7 @@ def predict_dist(self, if pred_type == "contributions": CONST_COL = "Const" - COLUMN_LEVELS = ["distribution_arg", "FeatureContribution"] + COLUMN_LEVELS = ["parameters", "FeatureContribution"] feature_columns = data.columns.tolist() + [CONST_COL] contributions_predt = pd.DataFrame( From f1780b6a4cdb1b2afb1e345079cbe44a923a111b Mon Sep 17 00:00:00 2001 From: SimonRobertPike Date: Thu, 3 Oct 2024 10:04:20 +0100 Subject: [PATCH 6/6] give columns level names for easier pandas manipulations --- lightgbmlss/distributions/distribution_utils.py | 16 ++++++++++++---- tests/test_model/test_model.py | 10 ++++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/lightgbmlss/distributions/distribution_utils.py b/lightgbmlss/distributions/distribution_utils.py index ee02590..62ca5ae 100644 --- a/lightgbmlss/distributions/distribution_utils.py +++ b/lightgbmlss/distributions/distribution_utils.py @@ -14,6 +14,7 @@ import warnings + class DistributionClass: """ Generic class that contains general functions for univariate distributions. 
@@ -396,7 +397,7 @@ def predict_dist(self, if pred_type == "contributions": CONST_COL = "Const" - COLUMN_LEVELS = ["parameters", "FeatureContribution"] + COLUMN_LEVELS = ["parameters", "feature_contributions"] feature_columns = data.columns.tolist() + [CONST_COL] contributions_predt = pd.DataFrame( @@ -432,8 +433,14 @@ def predict_dist(self, ], axis=1, ) - dist_params_predt = pd.DataFrame(dist_params_predt) - dist_params_predt.columns = self.param_dict.keys() + dist_params_predt = pd.DataFrame( + index=data.index, + data=dist_params_predt, + columns=pd.Index( + self.param_dict.keys(), + name=pred_type if pred_type == "expectiles" else "parameters" + ) + ) if pred_type == "parameters": return dist_params_predt @@ -446,7 +453,7 @@ def predict_dist(self, pred_samples_df = self.draw_samples(predt_params=dist_params_predt, n_samples=n_samples, seed=seed) - + pred_samples_df.columns.name = "samples" if pred_type == "samples": return pred_samples_df @@ -456,6 +463,7 @@ def predict_dist(self, pred_quant_df.columns = [str("quant_") + str(quantiles[i]) for i in range(len(quantiles))] if self.discrete: pred_quant_df = pred_quant_df.astype(int) + pred_quant_df.columns.name = "quantiles" return pred_quant_df else: raise RuntimeError(f"{pred_type=} not supported") diff --git a/tests/test_model/test_model.py b/tests/test_model/test_model.py index b2a26ba..60cd6a6 100644 --- a/tests/test_model/test_model.py +++ b/tests/test_model/test_model.py @@ -167,29 +167,35 @@ def test_model_predict(self, univariate_data, univariate_lgblss, univariate_para assert not np.isinf(pred_params).any().any() assert pred_params.shape[1] == lgblss.dist.n_dist_param assert approx(pred_params["loc"].mean(), abs=0.2) == 10.0 + assert pred_params.columns.name == "parameters" assert isinstance(pred_samples, (pd.DataFrame, type(None))) assert not pred_samples.isna().any().any() assert not np.isinf(pred_samples).any().any() assert pred_samples.shape[1] == n_samples + assert pred_samples.columns.name == "samples" assert isinstance(pred_quantiles, (pd.DataFrame, type(None))) assert not pred_quantiles.isna().any().any() assert not np.isinf(pred_quantiles).any().any() assert pred_quantiles.shape[1] == len(quantiles) + assert pred_quantiles.columns.name == "quantiles" assert isinstance(pred_contributions, (pd.DataFrame, type(None))) assert not pred_contributions.isna().any().any() assert not np.isinf(pred_contributions).any().any() assert (pred_contributions.shape[1] == - lgblss.dist.n_dist_param * lgblss.dist.n_dist_param * (X_test.shape[1] + 1) + lgblss.dist.n_dist_param * (X_test.shape[1] + 1) ) + assert pred_contributions.columns.names == ["parameters", "feature_contributions"] + for key, response_func in lgblss.dist.param_dict.items(): + # Sum contributions for each parameter and apply response function pred_contributions_combined = ( pd.Series(response_func( torch.tensor( - pred_contributions.xs(key, level="distribution_arg", axis=1).sum(axis=1).values) + pred_contributions.xs(key, level="parameters", axis=1).sum(axis=1).values) ))) assert np.allclose( pred_contributions_combined,
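# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patches above): how the new
# pred_type="contributions" output can be recombined into parameter
# predictions, mirroring the assertions in the tests. It assumes a trained
# LightGBMLSS model `lgblss` and a feature DataFrame `X_test`.
import torch

contribs = lgblss.predict(X_test, pred_type="contributions")

# Columns form a MultiIndex named ("parameters", "feature_contributions");
# the extra "Const" column carries the init_score added for each parameter.
for param, response_fn in lgblss.dist.param_dict.items():
    raw_sum = contribs.xs(param, level="parameters", axis=1).sum(axis=1)
    # Contributions are on the raw (pre-response) scale, so applying the
    # response function recovers the corresponding pred_type="parameters" column.
    recovered = response_fn(torch.tensor(raw_sum.values)).numpy()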