Skip to content

Commit a44e50a

Browse files
committed
clean code
1 parent d845ef9 commit a44e50a

2 files changed

Lines changed: 7 additions & 30 deletions

File tree

examples/open_clip/difference.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,8 @@ def parse_args(args):
3333

3434

3535
def difference(a_file, b_file):
36-
file = open(a_file, "r")
37-
a_variable = eval(file.read())
38-
file.close()
39-
40-
file = open(b_file, "r")
41-
b_variable = eval(file.read())
42-
file.close()
43-
44-
if not isinstance(b_variable, np.ndarray):
45-
b_variable = np.array(b_variable)
46-
if not isinstance(a_variable, np.ndarray):
47-
a_variable = np.array(a_variable)
36+
a_variable = np.load(a_file)
37+
b_variable = np.load(b_file)
4838

4939
if b_variable.shape != a_variable.shape:
5040
raise ValueError(

examples/open_clip/test.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import sys
1414

15+
import numpy as np
1516
from PIL import Image
1617
from src.open_clip import create_model_and_transforms, get_tokenizer
1718

@@ -84,30 +85,16 @@ def main(args):
8485
root = "./" + args.model_name + args.pretrained
8586
os.mkdir(root)
8687

87-
# file = open(root + "/image.txt", "w+")
88-
# file.write(str(image.asnumpy().tolist()))
89-
# file.close()
90-
#
91-
# file = open(root + "/text.txt", "w+")
92-
# file.write(str(text.asnumpy().tolist()))
93-
# file.close()
94-
95-
file = open(root + "/image_features.txt", "w+")
96-
file.write(str(image_features.asnumpy().tolist()))
97-
file.close()
98-
99-
file = open(root + "/text_features.txt", "w+")
100-
file.write(str(text_features.asnumpy().tolist()))
101-
file.close()
88+
# save as np files.
89+
np.save(os.path.join(root, "image_features.npy"), image_features.asnumpy())
90+
np.save(os.path.join(root, "text_features.npy"), text_features.asnumpy())
10291

10392
image_features /= image_features.norm(dim=-1, keepdim=True)
10493
text_features /= text_features.norm(dim=-1, keepdim=True)
10594

10695
text_probs = ops.softmax(100.0 * image_features @ text_features.T, axis=-1)
10796

108-
file = open(root + "/text_probs.txt", "w+")
109-
file.write(str(text_probs.asnumpy().tolist()))
110-
file.close()
97+
np.save(os.path.join(root, "text_probs.npy"), text_probs.asnumpy())
11198

11299

113100
if __name__ == "__main__":

0 commit comments

Comments
 (0)