Skip to content

Commit 68b198a

Browse files
committed
Refactoring, adding common utils library w/ method for saving/caching PyTorch embeddings
1 parent b7a29c8 commit 68b198a

File tree

5 files changed

+41
-3
lines changed

5 files changed

+41
-3
lines changed
File renamed without changes.

benchmarking/facenet_benchmarking.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# precomputed tensors. The script will return a dataframe with your testing
99
# data in it as well as write it to a log file.
1010
import os
11+
import sys
1112
import torch
1213
import pandas as pd
1314
import torch.nn.functional as F
@@ -16,7 +17,11 @@
1617
from PIL import Image
1718
from time import time
1819

19-
from common.logging_util import LoggingUtilities # noqa: E402
20+
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
21+
sys.path.append(parent_dir)
22+
23+
from common_utils.logging_util import LoggingUtilities # noqa: E402
24+
from common_utils.general_utilities import GeneralUtils # noqa: E402
2025

2126

2227
class FacenetBenchmarking:
@@ -26,15 +31,17 @@ def __init__(self, photo_path: str, tensor_path: str):
2631
self.logger = LoggingUtilities.\
2732
log_file_logger("facenet_benchmarking")
2833

34+
self.utilities = GeneralUtils()
35+
2936
# check for cuda, configure cudnn
3037
self.cuda_check()
3138

3239
# get models
3340
self.mtcnn, self.resnet = self.get_models()\
3441

3542
# get files
36-
self.photo_files = self.get_file_lists(photo_path)
37-
self.tensor_files = self.get_file_lists(tensor_path)
43+
self.photo_files = self.utilities.get_file_list(photo_path)
44+
self.tensor_files = self.utilities.get_file_list(tensor_path)
3845

3946
self.logger.info('File lists created')
4047

common_utils/__init__.py

Whitespace-only changes.

common_utils/general_utilities.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
from torch import save as ts
3+
from common_utils.logging_util import LoggingUtilities
4+
5+
logger = LoggingUtilities.console_out_logger("general utils")
6+
7+
8+
class GeneralUtils():
9+
10+
def __init__(self):
11+
12+
pass
13+
14+
@staticmethod
15+
def get_file_list(path: str):
16+
17+
file_list = list()
18+
19+
for (dirpath, dirnames, filenames) in os.walk(path):
20+
filenames.sort()
21+
file_list += [os.path.join(dirpath, file) for file in filenames]
22+
23+
return file_list
24+
25+
# save PyTorch tensors
26+
@staticmethod
27+
def save_pytorch_tensors(tensor: object, path: str):
28+
29+
ts(tensor, path)
30+
31+
logger.info(f'Tensor saved to: {path}')
File renamed without changes.

0 commit comments

Comments
 (0)