16
16
import argparse
17
17
from FlowersDataset import ImageCaption102FlowersDataset
18
18
import utils
19
+ import datasets
19
20
from PIL import Image
20
- from sentence_transformers import SentenceTransformer
21
21
22
22
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Code ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23
23
@@ -42,9 +42,6 @@ def generate(args, dataset, transformations, device):
42
42
txt2im_model .eval ()
43
43
im2txt_model .eval ()
44
44
45
- sentence_similarity_model = SentenceTransformer (args ["sentence_similarity_name" ])
46
-
47
-
48
45
# generate new images and sentence from the test set
49
46
_ , _ , test_loader = data_utils .get_loaders (args , dataset )
50
47
@@ -55,7 +52,7 @@ def generate(args, dataset, transformations, device):
55
52
os .makedirs (gens_dir , exist_ok = True )
56
53
57
54
if args ["text" ] is None and args ["img_path" ] is None :
58
- generate_test_examples (device , gens_dir , im2txt_model , test_loader , txt2im_model , sentence_similarity_model )
55
+ generate_test_examples (device , gens_dir , im2txt_model , test_loader , txt2im_model )
59
56
else :
60
57
if args ["text" ] is not None :
61
58
generate_custom_images_examples (args ["text" ], args ["amount" ], device , gens_dir , txt2im_model )
@@ -113,18 +110,20 @@ def generate_custom_images_examples(text, amount, device, gens_dir, txt2im_model
113
110
plt .imsave (os .path .join (gens_dir , f'im_{ " " .join (text .split (" " )[:5 ])} _{ i } .png' ), gen_im )
114
111
115
112
116
- def generate_test_examples (device , gens_dir , im2txt_model , test_loader , txt2im_model , sentence_similarity_model ):
113
+ def generate_test_examples (device , gens_dir , im2txt_model , test_loader , txt2im_model ):
117
114
"""Generate new text and image from the test set
118
115
119
116
:param device: device to use
120
117
:param gens_dir: save the generated text and image in this file
121
118
:param im2txt_model: trained im2txt model
122
119
:param test_loader: test set data loader
123
120
:param txt2im_model: trained txt2im model
124
- :param sentence_similarity_model: model for calculating sentences similarity
125
121
"""
126
122
deTensor = transforms .ToPILImage ()
127
- sentence_similarity = 0.0
123
+ bleu = datasets .load_metric ('bleu' )
124
+ rouge = datasets .load_metric ('rouge' )
125
+ meteor = datasets .load_metric ('meteor' )
126
+
128
127
with torch .no_grad ():
129
128
for i , (gt_im , txt_tokens , _ , im_idx , txt_idx ) in enumerate (test_loader ):
130
129
torch .cuda .empty_cache ()
@@ -145,11 +144,11 @@ def generate_test_examples(device, gens_dir, im2txt_model, test_loader, txt2im_m
145
144
gt_im = [deTensor (x ) for x in gt_im ]
146
145
147
146
# feed the generated image to the im2txt model to generate new sentences
148
- gen_tokens = im2txt_model .generate (gen_im )
147
+ gen_tokens = im2txt_model .generate (gt_im )
149
148
gen_sentence = im2txt_model .decode_text (gen_tokens )
150
149
gen_sentence = [s .strip () for s in gen_sentence ]
151
150
152
- sentence_similarity += utils . sentence_similarity ( sentence_similarity_model , gen_sentence , gt_sentence )
151
+ #bleu.add_batch(predictions=, references= )
153
152
154
153
# create an image with the gt image and sentence and gen image and sentence
155
154
for j in range (len (gen_im )):
0 commit comments