Skip to content

Commit 3f198ea

Browse files
committedMar 28, 2024
add cloome analysis
1 parent 7a4922f commit 3f198ea

5 files changed

+1515
-0
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Export Data for CLOOB Ablation Study\n",
8+
"### This notebook exports data for the CLOOB ablation analysis done after the interactive article was accepted by VISxAI. "
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": null,
14+
"metadata": {},
15+
"outputs": [],
16+
"source": [
17+
"! pip install git+https://github.com/ginihumer/Amumo.git"
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": 3,
23+
"metadata": {},
24+
"outputs": [],
25+
"source": [
26+
"import amumo\n",
27+
"from amumo import data as am_data\n",
28+
"from amumo import utils as am_utils\n",
29+
"from amumo import model as am_model"
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": 4,
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"import os\n",
39+
"def create_dir_if_not_exists(dir):\n",
40+
" if not os.path.exists(dir):\n",
41+
" os.mkdir(dir)\n",
42+
" return dir"
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": 5,
48+
"metadata": {},
49+
"outputs": [
50+
{
51+
"data": {
52+
"text/plain": [
53+
"'./exported_data_checkpoints/'"
54+
]
55+
},
56+
"execution_count": 5,
57+
"metadata": {},
58+
"output_type": "execute_result"
59+
}
60+
],
61+
"source": [
62+
"export_directory = './exported_data_checkpoints/'\n",
63+
"create_dir_if_not_exists(export_directory)"
64+
]
65+
},
66+
{
67+
"cell_type": "markdown",
68+
"metadata": {},
69+
"source": [
70+
"### Text-Image"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 34,
76+
"metadata": {},
77+
"outputs": [],
78+
"source": [
79+
"\n",
80+
"\n",
81+
"def export_data(dataset_name, images, prompts, models):\n",
82+
"\n",
83+
" # create folder structure\n",
84+
" dataset_directory = create_dir_if_not_exists(export_directory + dataset_name)\n",
85+
" similarities_dir = create_dir_if_not_exists(dataset_directory + '/similarities')\n",
86+
"\n",
87+
" # export projections and similarities\n",
88+
" import torch\n",
89+
" from sklearn.decomposition import PCA\n",
90+
" from openTSNE import TSNE\n",
91+
" from umap import UMAP\n",
92+
" import numpy as np\n",
93+
" import pandas as pd\n",
94+
" import json\n",
95+
"\n",
96+
" # if there already exists a dataset with projections from prior exports, load it\n",
97+
" if not os.path.exists(dataset_directory + '/projections.csv'):\n",
98+
" projections_df = pd.DataFrame({'emb_id': list(np.arange(0,len(images),1))+list(np.arange(0,len(prompts),1)), 'data_type':['image']*len(images)+['text']*len(prompts)})\n",
99+
" else:\n",
100+
" projections_df = pd.read_csv(dataset_directory + '/projections.csv')\n",
101+
" \n",
102+
"\n",
103+
" for model in models:\n",
104+
" # compute embeddings\n",
105+
" image_embedding_gap, text_embedding_gap, logit_scale = am_utils.get_embedding(model, dataset_name, images, prompts)\n",
106+
" image_embedding_nogap, text_embedding_nogap = am_utils.get_closed_modality_gap(image_embedding_gap, text_embedding_gap)\n",
107+
" \n",
108+
" for image_embedding, text_embedding, mode in [(image_embedding_gap, text_embedding_gap, ''), (image_embedding_nogap, text_embedding_nogap, '_nogap')]:\n",
109+
" \n",
110+
" # compute similarities\n",
111+
" similarity_image_text, similarity = am_utils.get_similarity(image_embedding, text_embedding)\n",
112+
" np.savetxt('%s/%s%s.csv'%(similarities_dir,model.model_name,mode), similarity, delimiter=',')\n",
113+
" \n",
114+
" # compute meta information and similarity clustering\n",
115+
" meta_info = {}\n",
116+
" meta_info['gap_distance'] = float(am_utils.get_modality_distance(image_embedding, text_embedding))\n",
117+
" meta_info['loss'] = float(am_utils.calculate_val_loss(image_embedding, text_embedding, logit_scale.exp()))\n",
118+
"\n",
119+
" idcs, clusters, clusters_unsorted = am_utils.get_cluster_sorting(similarity_image_text)\n",
120+
" cluster_labels = []\n",
121+
" cluster_sizes = []\n",
122+
" for c in set(clusters):\n",
123+
" cluster_size = int(np.count_nonzero(clusters==c))\n",
124+
" cluster_label = am_utils.get_textual_label_for_cluster(np.where(clusters_unsorted==c)[0], prompts)\n",
125+
" cluster_labels.append(cluster_label)\n",
126+
" cluster_sizes.append(cluster_size)\n",
127+
"\n",
128+
" idcs_reverse = np.argsort(idcs)\n",
129+
" meta_info['cluster_sort_idcs'] = idcs.tolist()\n",
130+
" meta_info['cluster_sort_idcs_reverse'] = idcs_reverse.tolist()\n",
131+
" meta_info['cluster_sizes'] = cluster_sizes\n",
132+
" meta_info['cluster_labels'] = cluster_labels\n",
133+
" # print(meta_info)\n",
134+
"\n",
135+
" with open(\"%s/%s%s_meta_info.json\"%(similarities_dir, model.model_name, mode), \"w\") as file:\n",
136+
" json.dump(meta_info, file)\n",
137+
"\n",
138+
" # compute projections\n",
139+
" embedding = np.array(torch.concatenate([image_embedding, text_embedding]))\n",
140+
"\n",
141+
" projection_methods = {\n",
142+
" 'PCA': PCA,\n",
143+
" 'UMAP': UMAP,\n",
144+
" 'TSNE': TSNE\n",
145+
" }\n",
146+
" for method in projection_methods.keys():\n",
147+
" if method == 'PCA':\n",
148+
" proj = projection_methods[method](n_components=2)\n",
149+
" else:\n",
150+
" proj = projection_methods[method](n_components=2, metric='cosine', random_state=31415)\n",
151+
" \n",
152+
" if method == 'TSNE':\n",
153+
" low_dim_data = proj.fit(embedding)\n",
154+
" else:\n",
155+
" low_dim_data = proj.fit_transform(embedding)\n",
156+
" \n",
157+
" projections_df['%s%s_%s_x'%(model.model_name, mode, method)] = low_dim_data[:,0]\n",
158+
" projections_df['%s%s_%s_y'%(model.model_name, mode, method)] = low_dim_data[:,1]\n",
159+
"\n",
160+
"\n",
161+
" projections_df.to_csv(dataset_directory + '/projections.csv')"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": 35,
167+
"metadata": {},
168+
"outputs": [
169+
{
170+
"name": "stderr",
171+
"output_type": "stream",
172+
"text": [
173+
"C:\\Users\\Christina\\AppData\\Local\\Temp\\ipykernel_31664\\330881050.py:20: FutureWarning: The input object of type 'Image' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Image', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.\n",
174+
" self.all_images = np.array(all_images)\n",
175+
"C:\\Users\\Christina\\AppData\\Local\\Temp\\ipykernel_31664\\330881050.py:20: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
176+
" self.all_images = np.array(all_images)\n"
177+
]
178+
}
179+
],
180+
"source": [
181+
"\n",
182+
"# reuse mscoco subset from previous analysis\n",
183+
"from PIL import Image\n",
184+
"import numpy as np\n",
185+
"\n",
186+
"class Custom_Dataset(am_data.DatasetInterface):\n",
187+
" name = 'MSCOCO-Val'\n",
188+
"\n",
189+
" def __init__(self, path, seed=54, batch_size=None):\n",
190+
" # create triplet dataset if it does not exist\n",
191+
" super().__init__(path, seed, batch_size)\n",
192+
" # path: path to the triplet dataset\n",
193+
" image_paths = [path + \"images/%i.jpg\"%i for i in range(100)]\n",
194+
"\n",
195+
" all_images = []\n",
196+
" for image_path in image_paths:\n",
197+
" with open(image_path, \"rb\") as fopen:\n",
198+
" image = Image.open(fopen).convert(\"RGB\")\n",
199+
" all_images.append(image)\n",
200+
"\n",
201+
" self.all_images = np.array(all_images)\n",
202+
" \n",
203+
" with open(path + \"/prompts.txt\", \"r\") as file:\n",
204+
" self.all_prompts = file.read().splitlines()\n",
205+
"\n",
206+
"mscoco_val_dataset_name = \"MSCOCO-Val_size-100\"\n",
207+
"dataset_mscoco_val = Custom_Dataset(export_directory + mscoco_val_dataset_name + '/')\n",
208+
"mscoco_val_images, mscoco_val_prompts = dataset_mscoco_val.get_data()"
209+
]
210+
},
211+
{
212+
"cell_type": "code",
213+
"execution_count": 37,
214+
"metadata": {},
215+
"outputs": [
216+
{
217+
"name": "stdout",
218+
"output_type": "stream",
219+
"text": [
220+
"found cached embeddings for MSCOCO-Val_size-100_ImageBind_huge\n"
221+
]
222+
}
223+
],
224+
"source": [
225+
"# TODO: export data for the models from the ablation study\n",
226+
"export_data(mscoco_val_dataset_name, mscoco_val_images, mscoco_val_prompts, [am_model.ImageBind_Model()])"
227+
]
228+
},
229+
{
230+
"cell_type": "code",
231+
"execution_count": null,
232+
"metadata": {},
233+
"outputs": [],
234+
"source": []
235+
}
236+
],
237+
"metadata": {
238+
"kernelspec": {
239+
"display_name": "myenv3",
240+
"language": "python",
241+
"name": "python3"
242+
},
243+
"language_info": {
244+
"codemirror_mode": {
245+
"name": "ipython",
246+
"version": 3
247+
},
248+
"file_extension": ".py",
249+
"mimetype": "text/x-python",
250+
"name": "python",
251+
"nbconvert_exporter": "python",
252+
"pygments_lexer": "ipython3",
253+
"version": "3.9.18"
254+
},
255+
"orig_nbformat": 4
256+
},
257+
"nbformat": 4,
258+
"nbformat_minor": 2
259+
}

0 commit comments

Comments
 (0)
Please sign in to comment.