Skip to content

Commit f1e0bb9

Browse files
committed
generation fixes
1 parent c69592e commit f1e0bb9

File tree

3 files changed

+9
-27
lines changed

3 files changed

+9
-27
lines changed

code/generating.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import argparse
1717
from FlowersDataset import ImageCaption102FlowersDataset
1818
import utils
19+
import datasets
1920
from PIL import Image
20-
from sentence_transformers import SentenceTransformer
2121

2222
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Code ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2323

@@ -42,9 +42,6 @@ def generate(args, dataset, transformations, device):
4242
txt2im_model.eval()
4343
im2txt_model.eval()
4444

45-
sentence_similarity_model = SentenceTransformer(args["sentence_similarity_name"])
46-
47-
4845
# generate new images and sentence from the test set
4946
_, _, test_loader = data_utils.get_loaders(args, dataset)
5047

@@ -55,7 +52,7 @@ def generate(args, dataset, transformations, device):
5552
os.makedirs(gens_dir, exist_ok=True)
5653

5754
if args["text"] is None and args["img_path"] is None:
58-
generate_test_examples(device, gens_dir, im2txt_model, test_loader, txt2im_model, sentence_similarity_model)
55+
generate_test_examples(device, gens_dir, im2txt_model, test_loader, txt2im_model)
5956
else:
6057
if args["text"] is not None:
6158
generate_custom_images_examples(args["text"], args["amount"], device, gens_dir, txt2im_model)
@@ -113,18 +110,20 @@ def generate_custom_images_examples(text, amount, device, gens_dir, txt2im_model
113110
plt.imsave(os.path.join(gens_dir, f'im_{" ".join(text.split(" ")[:5])}_{i}.png'), gen_im)
114111

115112

116-
def generate_test_examples(device, gens_dir, im2txt_model, test_loader, txt2im_model, sentence_similarity_model):
113+
def generate_test_examples(device, gens_dir, im2txt_model, test_loader, txt2im_model):
117114
"""Generate new text and image from the test set
118115
119116
:param device: device to use
120117
:param gens_dir: save the generated text and image in this file
121118
:param im2txt_model: trained im2txt model
122119
:param test_loader: test set data loader
123120
:param txt2im_model: trained txt2im model
124-
:param sentence_similarity_model: model for calculating sentences similarity
125121
"""
126122
deTensor = transforms.ToPILImage()
127-
sentence_similarity = 0.0
123+
bleu = datasets.load_metric('bleu')
124+
rouge = datasets.load_metric('rouge')
125+
meteor = datasets.load_metric('meteor')
126+
128127
with torch.no_grad():
129128
for i, (gt_im, txt_tokens, _, im_idx, txt_idx) in enumerate(test_loader):
130129
torch.cuda.empty_cache()
@@ -145,11 +144,11 @@ def generate_test_examples(device, gens_dir, im2txt_model, test_loader, txt2im_m
145144
gt_im = [deTensor(x) for x in gt_im]
146145

147146
# feed the generated image to the im2txt model to generate new sentences
148-
gen_tokens = im2txt_model.generate(gen_im)
147+
gen_tokens = im2txt_model.generate(gt_im)
149148
gen_sentence = im2txt_model.decode_text(gen_tokens)
150149
gen_sentence = [s.strip() for s in gen_sentence]
151150

152-
sentence_similarity += utils.sentence_similarity(sentence_similarity_model, gen_sentence, gt_sentence)
151+
#bleu.add_batch(predictions=, references=)
153152

154153
# create an image with the gt image and sentence and gen image and sentence
155154
for j in range(len(gen_im)):

code/utils.py

-16
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
function _gen_unique_out_dir_path - generate a unique name for each new out dir
44
function create_output_dir - create a new output dir
55
function set_seed - set a seed for reproducibility
6-
function sentence_similarity - calculate similarity between sentences
76
"""
87

98
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ Imports ~~~~~~~~~~~~~~~~~~~~~~~
@@ -12,7 +11,6 @@
1211
import torch
1312
import random
1413
import numpy as np
15-
from sentence_transformers import util
1614

1715
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ Code ~~~~~~~~~~~~~~~~~~~~~~~~~~
1816

@@ -65,17 +63,3 @@ def set_seed(seed=42):
6563
torch.cuda.manual_seed_all(seed)
6664
torch.backends.cudnn.deterministic = True
6765
torch.backends.cudnn.benchmark = False
68-
69-
70-
def sentence_similarity(model, gen_sentences, gt_sentences):
71-
"""Calculate the cosine similarity between sentences
72-
73-
:param model: sentence transformer model
74-
:param gen_sentences: generated sentences
75-
:param gt_sentences: ground truth sentences
76-
:return: the cosine similarity between the generated sentences and the gt sentences a
77-
"""
78-
gt_sentences_embeddings = model.encode(gt_sentences)
79-
gen_sentences_embeddings = model.encode(gen_sentences)
80-
81-
return util.cos_sim(gt_sentences_embeddings, gen_sentences_embeddings)

configs/config.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
output_dir: './results'
22
db_path: '../databases/102flowers'
33
db_type: 'flowers'
4-
sentence_similarity_name: 'bert-base-nli-mean-tokens'
54

65
training_args:
76
batch_size: 12

0 commit comments

Comments
 (0)