From b26b0cf4b3296efdb401e04e32b5860958cf6164 Mon Sep 17 00:00:00 2001 From: jaberkow Date: Mon, 14 Dec 2020 18:16:33 -0800 Subject: [PATCH] added documentation to detection_utils --- LICENSE | 2 +- detection_utils.py | 34 ++++++++++++++++++++++---- hypothesis_test.ipynb | 55 ++++++++++++++++++++----------------------- 3 files changed, 57 insertions(+), 34 deletions(-) diff --git a/LICENSE b/LICENSE index 261eeb9..8d789bb 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2020 IQT Labs LLC, All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/detection_utils.py b/detection_utils.py index 2ba2801..f69d1d1 100644 --- a/detection_utils.py +++ b/detection_utils.py @@ -5,7 +5,21 @@ def MMD_test(source_data,target_data,p_val=0.05,preprocess_kwargs={},chunk_size=100, n_permutations=20): """ - Functional wrapper around alibi_detect MMDDrift class, uses gaussian kernel + Functional wrapper around alibi_detect MMDDrift class that uses a Gaussian kernel + (https://docs.seldon.io/projects/alibi-detect/en/stable/api/alibi_detect.cd.mmd.html) + + + Inputs: + source_data - numpy.ndarray of shape (number of source samples,embedding dimension), + samples from the source distribution + target_data - numpy.ndarray of shape (number of target samples,embedding dimension), + samples from the target distribution + p_val - p-value used for the significance of the permutation test. + preprocess_kwargs - Kwargs for a preprocessing function, pass callables under "model" key + chunk_size - Chunk size if dask is used to parallelise the computation. + n_permutations - Number of permutations used in the permutation test. 
+ Outputs: + p - float, empirical p-value determined using the permutation test """ source_size,source_dim = np.shape(source_data) target_size,target_dim = np.shape(target_data) @@ -24,8 +38,20 @@ def MMD_test(source_data,target_data,p_val=0.05,preprocess_kwargs={},chunk_size= def repeated_MMD_test(source_data,target_data,p_val=0.05,preprocess_kwargs={},chunk_size=100, n_permutations=20,n_samples=100,n_splits=5): """ - Repeatedly carry out the MMD test, subsampling the data each time. Returns mean and standard - deviation of the p_values + Repeatedly carry out the MMD test, subsampling the data each time. Returns an array of p-values + Inputs: + source_data - numpy.ndarray of shape (number of source samples,embedding dimension), + samples from the source distribution + target_data - numpy.ndarray of shape (number of target samples,embedding dimension), + samples from the target distribution + p_val - p-value used for the significance of the permutation test. + preprocess_kwargs - Kwargs for a preprocessing function, pass callables under "model" key + chunk_size - Chunk size if dask is used to parallelise the computation. + n_permutations - Number of permutations used in the permutation test. 
+ n_samples - number of samples to use from the source and target data in each subsampling + n_splits - number of different subsamplings to carry out + Outputs: + p_array - np.ndarray of shape (n_splits,), the set of p-values computed """ source_size,source_dim = np.shape(source_data) target_size,target_dim = np.shape(target_data) @@ -38,7 +64,7 @@ def repeated_MMD_test(source_data,target_data,p_val=0.05,preprocess_kwargs={},ch preprocess_kwargs=preprocess_kwargs,n_permutations=n_permutations) p_list.append(p_temp) p_array = np.array(p_list) - return np.mean(p_array),np.std(p_array) + return p_array diff --git a/hypothesis_test.ipynb b/hypothesis_test.ipynb index ebfbd4e..8414d57 100644 --- a/hypothesis_test.ipynb +++ b/hypothesis_test.ipynb @@ -36,7 +36,6 @@ "source": [ "# Change this component to the root of the VOiCES dataset\n", "DATASET_ROOT = '/Users/jberkowitz/Datasets/VOiCES_devkit'\n", - "#DATASET_ROOT = '/mnt/fs03/shared/datasets/VOiCES_devkit'\n", "# Convenience function to add root to data path\n", "add_root = lambda x: os.path.join(DATASET_ROOT,x)" ] @@ -3967,7 +3966,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -3988,42 +3987,40 @@ " temp_model = w2v_uae\n", " else:\n", " temp_model = trill_uae\n", - " p_mean,p_std= detection_utils.repeated_MMD_test(source_data,target_data,preprocess_kwargs={'model':temp_model,'batch_size':128},n_permutations=100,n_samples=50,n_splits=10)\n", - " res = {'distractor':distractor_type,'embedding':embedding_type,'p value':p_mean,'p error':p_std}\n", - " res_list.append(res)" + " p_array = detection_utils.repeated_MMD_test(source_data,target_data,preprocess_kwargs={'model':temp_model,'batch_size':128},n_permutations=100,n_samples=50,n_splits=10)\n", + " for p in p_array:\n", + " res = {'distractor':distractor_type,'embedding':embedding_type,'p value':p}\n", + " res_list.append(res)" ] }, { "cell_type": "code", - "execution_count": 88, + 
"execution_count": 97, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ - " distractor embedding p value p error\n", - "0 none w2v_mean 0.518 0.302185\n", - "1 none w2v_pca 0.612 0.252460\n", - "2 none trill_mean 0.419 0.287453\n", - "3 none trill_pca 0.452 0.281950\n", - "4 musi w2v_mean 0.001 0.003000\n", - "5 musi w2v_pca 0.139 0.152476\n", - "6 musi trill_mean 0.000 0.000000\n", - "7 musi trill_pca 0.184 0.221142\n", - "8 tele w2v_mean 0.001 0.003000\n", - "9 tele w2v_pca 0.319 0.211587\n", - "10 tele trill_mean 0.157 0.253576\n", - "11 tele trill_pca 0.246 0.277712\n", - "12 babb w2v_mean 0.000 0.000000\n", - "13 babb w2v_pca 0.070 0.115672\n", - "14 babb trill_mean 0.000 0.000000\n", - "15 babb trill_pca 0.021 0.046787" + " distractor embedding p value\n", + "0 none w2v_mean 0.41\n", + "1 none w2v_mean 0.31\n", + "2 none w2v_mean 0.38\n", + "3 none w2v_mean 0.53\n", + "4 none w2v_mean 0.88\n", + ".. ... ... ...\n", + "155 babb trill_pca 0.00\n", + "156 babb trill_pca 0.05\n", + "157 babb trill_pca 0.00\n", + "158 babb trill_pca 0.07\n", + "159 babb trill_pca 0.01\n", + "\n", + "[160 rows x 3 columns]" ], - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
distractorembeddingp valuep error
0nonew2v_mean0.5180.302185
1nonew2v_pca0.6120.252460
2nonetrill_mean0.4190.287453
3nonetrill_pca0.4520.281950
4musiw2v_mean0.0010.003000
5musiw2v_pca0.1390.152476
6musitrill_mean0.0000.000000
7musitrill_pca0.1840.221142
8telew2v_mean0.0010.003000
9telew2v_pca0.3190.211587
10teletrill_mean0.1570.253576
11teletrill_pca0.2460.277712
12babbw2v_mean0.0000.000000
13babbw2v_pca0.0700.115672
14babbtrill_mean0.0000.000000
15babbtrill_pca0.0210.046787
\n
" + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
distractorembeddingp value
0nonew2v_mean0.41
1nonew2v_mean0.31
2nonew2v_mean0.38
3nonew2v_mean0.53
4nonew2v_mean0.88
............
155babbtrill_pca0.00
156babbtrill_pca0.05
157babbtrill_pca0.00
158babbtrill_pca0.07
159babbtrill_pca0.01
\n

160 rows × 3 columns

\n
" }, "metadata": {}, - "execution_count": 88 + "execution_count": 97 } ], "source": [ @@ -4033,15 +4030,15 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 101, "metadata": {}, "outputs": [ { "output_type": "display_data", "data": { "text/plain": "
", - "image/svg+xml": "\n\n\n\n \n \n \n \n 2020-12-14T00:46:48.592307\n image/svg+xml\n \n \n Matplotlib v3.3.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", - "image/png": "\n" + "image/svg+xml": "\n\n\n\n \n \n \n \n 2020-12-14T17:48:08.125699\n image/svg+xml\n \n \n Matplotlib v3.3.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" }, "metadata": {} } @@ -4051,7 +4048,7 @@ "g = sns.catplot(\n", " data=res_df, kind=\"bar\",\n", " x=\"distractor\", y=\"p value\", hue=\"embedding\",\n", - " ci=\"p error\", palette=\"dark\", alpha=.6, height=6\n", + " palette=\"dark\", alpha=.6, height=6\n", ")\n", "g.despine(left=True)\n", "g.set_axis_labels(\"\", \"p value\",fontsize=15)\n",