diff --git a/examples/open_clip/difference.py b/examples/open_clip/difference.py index b2ca1a1b8..4117cea57 100644 --- a/examples/open_clip/difference.py +++ b/examples/open_clip/difference.py @@ -33,18 +33,8 @@ def parse_args(args): def difference(a_file, b_file): - file = open(a_file, "r") - a_variable = eval(file.read()) - file.close() - - file = open(b_file, "r") - b_variable = eval(file.read()) - file.close() - - if not isinstance(b_variable, np.ndarray): - b_variable = np.array(b_variable) - if not isinstance(a_variable, np.ndarray): - a_variable = np.array(a_variable) + a_variable = np.load(a_file) + b_variable = np.load(b_file) if b_variable.shape != a_variable.shape: raise ValueError( diff --git a/examples/open_clip/test.py b/examples/open_clip/test.py index 99ea731c8..1306c6a49 100644 --- a/examples/open_clip/test.py +++ b/examples/open_clip/test.py @@ -12,6 +12,7 @@ import os import sys +import numpy as np from PIL import Image from src.open_clip import create_model_and_transforms, get_tokenizer @@ -84,30 +85,16 @@ def main(args): root = "./" + args.model_name + args.pretrained os.mkdir(root) - # file = open(root + "/image.txt", "w+") - # file.write(str(image.asnumpy().tolist())) - # file.close() - # - # file = open(root + "/text.txt", "w+") - # file.write(str(text.asnumpy().tolist())) - # file.close() - - file = open(root + "/image_features.txt", "w+") - file.write(str(image_features.asnumpy().tolist())) - file.close() - - file = open(root + "/text_features.txt", "w+") - file.write(str(text_features.asnumpy().tolist())) - file.close() + # save as np files. + np.save(os.path.join(root, "image_features.npy"), image_features.asnumpy()) + np.save(os.path.join(root, "text_features.npy"), text_features.asnumpy()) image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) text_probs = ops.softmax(100.0 * image_features @ text_features.T, axis=-1) - file = open(root + "/text_probs.txt", "w+") - file.write(str(text_probs.asnumpy().tolist())) - file.close() + np.save(os.path.join(root, "text_probs.npy"), text_probs.asnumpy()) if __name__ == "__main__":