Skip to content

Commit 1804e96

Browse files
0.9.14
1 parent a514e8f commit 1804e96

19 files changed

+535
-248
lines changed

torchstudio/datasetanalyze.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,23 @@
2727
print("Analyzing...\n", file=sys.stderr)
2828

2929
analysis_server, address = tc.generate_server()
30+
tc.send_msg(app_socket, 'ServerRequestingDataset', tc.encode_strings(address))
31+
32+
dataset_socket=tc.start_server(analysis_server)
33+
34+
tc.send_msg(dataset_socket, 'RequestMetaInfos')
3035

3136
if analyzer_env['analyzer'].train is None:
32-
request_msg='AnalysisServerRequestingAllSamples'
37+
request_msg='RequestAllSamples'
3338
elif analyzer_env['analyzer'].train==True:
34-
request_msg='AnalysisServerRequestingTrainingSamples'
39+
request_msg='RequestTrainingSamples'
3540
elif analyzer_env['analyzer'].train==False:
36-
request_msg='AnalysisServerRequestingValidationSamples'
37-
tc.send_msg(app_socket, request_msg, tc.encode_strings(address))
38-
dataset_socket=tc.start_server(analysis_server)
41+
request_msg='RequestValidationSamples'
42+
tc.send_msg(dataset_socket, request_msg, tc.encode_strings(address))
3943

4044
while True:
4145
dataset_msg_type, dataset_msg_data = tc.recv_msg(dataset_socket)
4246

43-
if dataset_msg_type == 'NumSamples':
44-
num_samples=tc.decode_ints(dataset_msg_data)[0]
45-
pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters
46-
4747
if dataset_msg_type == 'InputTensorsID':
4848
input_tensors_id=tc.decode_ints(dataset_msg_data)
4949

@@ -53,6 +53,10 @@
5353
if dataset_msg_type == 'Labels':
5454
labels=tc.decode_strings(dataset_msg_data)
5555

56+
if dataset_msg_type == 'NumSamples':
57+
num_samples=tc.decode_ints(dataset_msg_data)[0]
58+
pbar=tqdm(total=num_samples, desc='Analyzing...', bar_format='{l_bar}{bar}| {remaining} left\n\n') #see https://github.com/tqdm/tqdm#parameters
59+
5660
if dataset_msg_type == 'StartSending':
5761
error_msg, return_value = safe_exec(analyzer_env['analyzer'].start_analysis, (num_samples, input_tensors_id, output_tensors_id, labels), description='analyzer definition')
5862
if error_msg is not None:
@@ -85,7 +89,7 @@
8589
if dataset_msg_type == 'DoneSending':
8690
pbar.close()
8791
error_msg, return_value = safe_exec(analyzer_env['analyzer'].finish_analysis, description='analyzer definition')
88-
tc.send_msg(dataset_socket, 'DoneReceiving')
92+
tc.send_msg(dataset_socket, 'DisconnectFromWorkerServer')
8993
dataset_socket.close()
9094
analysis_server.close()
9195
if error_msg is not None:
@@ -106,12 +110,14 @@
106110

107111
if msg_type == 'RequestAnalysisReport':
108112
resolution = tc.decode_ints(msg_data)
109-
if 'analyzer' in analyzer_env:
113+
if 'analyzer' in analyzer_env and resolution[0]>0 and resolution[1]>0:
110114
error_msg, return_value = safe_exec(analyzer_env['analyzer'].generate_report, (resolution[0:2],resolution[2]), description='analyzer definition')
111115
if error_msg is not None:
112116
print(error_msg, file=sys.stderr)
113117
if return_value is not None:
114118
tc.send_msg(app_socket, 'ReportImage', tc.encode_image(return_value))
119+
else:
120+
tc.send_msg(app_socket, 'ReportImage')
115121

116122
if msg_type == 'Exit':
117123
break

torchstudio/datasetload.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717
from collections.abc import Iterable
1818
from tqdm.auto import tqdm
19+
import hashlib
1920

2021
#monkey patch ssl to fix ssl certificate fail when downloading datasets on some configurations: https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error
2122
import ssl
@@ -207,9 +208,7 @@ def __getitem__(self, id):
207208
if msg_type == 'OutputTensorsID':
208209
output_tensors_id = tc.decode_ints(msg_data)
209210

210-
if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples':
211-
train_set=True if msg_type == 'ConnectAndSendTrainingSamples' or msg_type == 'ConnectAndSendAllSamples' else False
212-
valid_set=True if msg_type == 'ConnectAndSendValidationSamples' or msg_type == 'ConnectAndSendAllSamples' else False
211+
if msg_type == 'ConnectToWorkerServer':
213212
name, sshaddress, sshport, username, password, keydata, address, port = tc.decode_strings(msg_data)
214213
port=int(port)
215214

@@ -241,30 +240,49 @@ def __getitem__(self, id):
241240

242241
try:
243242
worker_socket = tc.connect((address,port),timeout=10)
244-
num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
245-
tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))
246-
tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
247-
tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id))
248-
tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes))
249-
250-
tc.send_msg(worker_socket, 'StartSending')
251-
with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar:
252-
if train_set:
253-
meta_dataset.train()
254-
for i in range(len(meta_dataset)):
255-
tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i]))
256-
pbar.update(1)
257-
if valid_set:
258-
meta_dataset.valid()
259-
for i in range(len(meta_dataset)):
260-
tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i]))
261-
pbar.update(1)
262-
263-
tc.send_msg(worker_socket, 'DoneSending')
264-
train_msg_type, train_msg_data = tc.recv_msg(worker_socket)
265-
if train_msg_type == 'DoneReceiving':
266-
worker_socket.close()
267-
print('Samples transfer to '+name+' completed')
243+
while True:
244+
worker_msg_type, worker_msg_data = tc.recv_msg(worker_socket)
245+
246+
if worker_msg_type == 'RequestMetaInfos':
247+
tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
248+
tc.send_msg(worker_socket, 'OutputTensorsID', tc.encode_ints(output_tensors_id))
249+
tc.send_msg(worker_socket, 'Labels', tc.encode_strings(meta_dataset.classes))
250+
251+
if worker_msg_type == 'RequestHash':
252+
dataset_hash = hashlib.md5()
253+
dataset_hash.update(int(len(meta_dataset.train())).to_bytes(4, 'little'))
254+
if len(meta_dataset)>0:
255+
dataset_hash.update(tc.encode_torch_tensors(meta_dataset[0]))
256+
dataset_hash.update(int(len(meta_dataset.valid())).to_bytes(4, 'little'))
257+
if len(meta_dataset)>0:
258+
dataset_hash.update(tc.encode_torch_tensors(meta_dataset[0]))
259+
tc.send_msg(worker_socket, 'DatasetHash', dataset_hash.digest())
260+
261+
if worker_msg_type == 'RequestTrainingSamples' or worker_msg_type == 'RequestValidationSamples' or worker_msg_type == 'RequestAllSamples':
262+
train_set=True if worker_msg_type == 'RequestTrainingSamples' or worker_msg_type == 'RequestAllSamples' else False
263+
valid_set=True if worker_msg_type == 'RequestValidationSamples' or worker_msg_type == 'RequestAllSamples' else False
264+
num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
265+
tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))
266+
267+
tc.send_msg(worker_socket, 'StartSending')
268+
with tqdm(total=num_samples, desc='Sending samples to '+name+'...', bar_format='{l_bar}{bar}| {remaining} left\n\n') as pbar:
269+
if train_set:
270+
meta_dataset.train()
271+
for i in range(len(meta_dataset)):
272+
tc.send_msg(worker_socket, 'TrainingSample', tc.encode_torch_tensors(meta_dataset[i]))
273+
pbar.update(1)
274+
if valid_set:
275+
meta_dataset.valid()
276+
for i in range(len(meta_dataset)):
277+
tc.send_msg(worker_socket, 'ValidationSample', tc.encode_torch_tensors(meta_dataset[i]))
278+
pbar.update(1)
279+
280+
tc.send_msg(worker_socket, 'DoneSending')
281+
282+
if worker_msg_type == 'DisconnectFromWorkerServer':
283+
worker_socket.close()
284+
print('Samples transfer to '+name+' completed')
285+
break
268286

269287
except:
270288
if sshaddress and sshport and username:
@@ -277,7 +295,7 @@ def __getitem__(self, id):
277295
except:
278296
pass
279297
try:
280-
sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DoneReceiving ping pong
298+
sshclient.close() #ssh connection must be closed only when all tcp socket data was received on the remote side, hence the DoneSending/DisconnectFromWorkerServer ping pong
281299
except:
282300
pass
283301

torchstudio/datasets/genericloader.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import os
2+
import torch
3+
from torch.utils.data import Dataset
4+
from PIL import Image
5+
import torchvision
6+
import torchaudio
7+
import numpy as np
8+
import sys
9+
10+
class GenericLoader(Dataset):
11+
"""A generic dataset loader.
12+
Suitable for classification, segmentation and regression datasets.
13+
Supports image, audio, and numpy array files.
14+
15+
Args:
16+
path (str):
17+
path to the dataset
18+
19+
classification (bool):
20+
True: classification dataset (single class prediction: class1, class2, ...)
21+
False: segmentation or regression dataset (multiple components: input, target, ...)
22+
23+
separator (str or None):
24+
'/': folders will be used to determine classes or components
25+
(classes: class1/1.ext, class1/2.ext, class2/1.ext, class2/2.ext, ...)
26+
(components: inputs/1.ext, inputs/2.ext, targets/1.ext, targets/2.ext, ...)
27+
28+
'_' or other separator: file name parts will be used to determine classes or components
29+
(classes: class1_1.ext, class1_2.ext, class2_1.ext, class2_2.ext, ...)
30+
(components: 1_input.ext, 1_output.ext, 2_input.ext, 2_output.ext, ...)
31+
32+
'' or None: file names or their content will be used to determine components
33+
(one sample per folder: 1/input.ext, 1/output.ext, 2/input.ext, 2/output.ext, ...)
34+
(samples in one folder: 1.ext, 2.ext, ...)
35+
36+
extensions (str):
37+
file extension to filters (such as: .jpg, .jpeg, .png, .mp3, .wav, .npy, .npz)
38+
39+
transforms (list):
40+
list of transforms to apply to the different components of each sample (use None is some components need no transform)
41+
(ie: [torchvision.transforms.Compose([transforms.Resize(64)]), torchaudio.transforms.Spectrogram()])
42+
"""
43+
44+
def __init__(self, path:str='', classification:bool=True, separator:str='/', extensions:str='.jpg, .jpeg, .png, .mp3, .wav, .npy, .npz', transforms=[]):
45+
exts = tuple(extensions.replace(' ','').split(','))
46+
paths = []
47+
self.samples = []
48+
self.classes = []
49+
self.transforms = transforms
50+
if not os.path.exists(path):
51+
print("Path not found.", file=sys.stderr)
52+
return
53+
for root, dirs, files in os.walk(path):
54+
for file in files:
55+
if file.endswith(exts):
56+
paths.append(os.path.join(root, file).replace('\\','/'))
57+
paths=sorted(paths)
58+
if not paths:
59+
print("No files found.", file=sys.stderr)
60+
return
61+
self.classification=classification
62+
if classification:
63+
if separator == '/':
64+
for path in paths:
65+
class_name=path.split('/')[-2]
66+
if class_name not in self.classes:
67+
self.classes.append(class_name)
68+
self.samples.append([path, self.classes.index(class_name)])
69+
elif separator:
70+
for path in paths:
71+
class_name = path.split('/')[-1].split(separator)[0]
72+
if class_name not in self.classes:
73+
self.classes.append(class_name)
74+
self.samples.append([path, self.classes.index(class_name)])
75+
else:
76+
print("You need a separator with classication datasets", file=sys.stderr)
77+
return
78+
else:
79+
samples_index = dict()
80+
if separator == '/':
81+
for path in paths:
82+
components_name=path.split('/')[-2]
83+
sample_name = path.split('/')[-1].split('.')[-2]
84+
if sample_name not in samples_index:
85+
samples_index[sample_name] = len(self.samples)
86+
self.samples.append([])
87+
self.samples[samples_index[sample_name]].append(path)
88+
elif separator:
89+
for path in paths:
90+
components_name = path.split('.')[-2].split(separator)[-1]
91+
sample_name = path.split('/')[-1].split(separator)[0]
92+
if sample_name not in samples_index:
93+
samples_index[sample_name] = len(self.samples)
94+
self.samples.append([])
95+
self.samples[samples_index[sample_name]].append(path)
96+
else:
97+
single_folder=True
98+
file_root=path[:path.rfind("/")]
99+
for path in paths:
100+
if not path.startswith(file_root):
101+
single_folder=False
102+
break
103+
if single_folder:
104+
for path in paths:
105+
sample_name = path.split('/')[-1].split('.')[-2]
106+
if sample_name not in samples_index:
107+
samples_index[sample_name] = len(self.samples)
108+
self.samples.append([])
109+
self.samples[samples_index[sample_name]].append(path)
110+
else:
111+
for path in paths:
112+
components_name = path.split('/')[-1].split('.')[-2]
113+
sample_name = path.split('/')[-2]
114+
if sample_name not in samples_index:
115+
samples_index[sample_name] = len(self.samples)
116+
self.samples.append([])
117+
self.samples[samples_index[sample_name]].append(path)
118+
119+
def to_tensors(self, path:str):
120+
if path.endswith('.jpg') or path.endswith('.jpeg') or path.endswith('.png'):
121+
img=Image.open(path)
122+
if img.getpalette():
123+
return [torch.from_numpy(np.array(img, dtype=np.uint8))]
124+
else:
125+
trans=torchvision.transforms.ToTensor()
126+
return [trans(img)]
127+
128+
if path.endswith('.mp3') or path.endswith('.wav'):
129+
waveform, sample_rate = torchaudio.load(path)
130+
return [waveform]
131+
132+
if path.endswith('.npy') or path.endswith('.npz'):
133+
arrays = np.load(path)
134+
if type(arrays) == dict:
135+
tensors = []
136+
for array in arrays:
137+
tensors.append(torch.from_numpy(arrays[array]))
138+
return tensors
139+
else:
140+
return [torch.from_numpy(arrays)]
141+
142+
def __len__(self):
143+
return len(self.samples)
144+
145+
def __getitem__(self, id):
146+
"""
147+
Returns:
148+
A tuple of tensors.
149+
"""
150+
151+
if id < 0 or id >= len(self):
152+
raise IndexError
153+
154+
components = []
155+
for component in self.samples[id]:
156+
if type(component) is str:
157+
components.extend(self.to_tensors(component))
158+
else:
159+
components.extend([torch.tensor(component)])
160+
161+
if self.transforms:
162+
if type(self.transforms) is not list and type(self.transforms) is not tuple:
163+
self.transforms = [self.transforms]
164+
for i, transform in enumerate(self.transforms):
165+
if i < len(components) and transform is not None:
166+
components[i] = transform(components[i])
167+
168+
return tuple(components)

torchstudio/metrics/accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class Accuracy(Metric):
99
threshold: error threshold below which predictions are considered accurate (not used in multiclass)
1010
normalize: if set to True, normalize predictions with sigmoid or softmax before calculating the accuracy
1111
"""
12-
def __init__(self, threshold: float = 0.1, normalize: bool = False):
12+
def __init__(self, threshold: float = 0.01, normalize: bool = False):
1313
self.threshold = threshold
1414
self.normalize = normalize
1515
self.reset()

0 commit comments

Comments
 (0)