Commit 0ebbf2a

0.9.15
1 parent 2b41e37 commit 0ebbf2a

6 files changed: +62 additions, -27 deletions

torchstudio/datasets/genericloader.py

Lines changed: 18 additions & 13 deletions
@@ -34,14 +34,15 @@ class GenericLoader(Dataset):
         (samples in one folder: 1.ext, 2.ext, ...)
 
     extensions (str):
-        file extension to filters (such as: .jpg, .jpeg, .png, .mp3, .wav, .npy, .npz)
+        file extension to filters
+        (supported: .jpg, .jpeg, .png, .webp, .tif, .tiff, .mp3, .wav, .ogg, .flac, .npy, .npz)
 
     transforms (list):
         list of transforms to apply to the different components of each sample (use None is some components need no transform)
         (ie: [torchvision.transforms.Compose([transforms.Resize(64)]), torchaudio.transforms.Spectrogram()])
     """
 
-    def __init__(self, path:str='', classification:bool=True, separator:str='/', extensions:str='.jpg, .jpeg, .png, .mp3, .wav, .npy, .npz', transforms=[]):
+    def __init__(self, path:str='', classification:bool=True, separator:str='/', extensions:str='.jpg, .jpeg, .png, .webp, .tif, .tiff, .mp3, .wav, .ogg, .flac, .npy, .npz', transforms=[]):
         exts = tuple(extensions.replace(' ','').split(','))
         paths = []
         self.samples = []
@@ -126,27 +127,31 @@ def __init__(self, path:str='', classification:bool=True, separator:str='/', ext
                 self.samples[samples_index[sample_name]].append(path)
 
     def to_tensors(self, path:str):
-        if path.endswith('.jpg') or path.endswith('.jpeg') or path.endswith('.png'):
+        tensors = []
+        if path.endswith('.jpg') or path.endswith('.jpeg') or path.endswith('.png') or path.endswith('.webp') or path.endswith('.tif') or path.endswith('.tiff'):
             img=Image.open(path)
-            if img.mode=='1' or img.mode=='L' or img.mode=='P':
-                return [torch.from_numpy(np.array(img, dtype=np.uint8))]
-            else:
-                trans=torchvision.transforms.ToTensor()
-                return [trans(img)]
+            for i in range(img.n_frames):
+                if img.mode=='1' or img.mode=='L' or img.mode=='P':
+                    tensors.append(torch.from_numpy(np.array(img, dtype=np.uint8)))
+                else:
+                    trans=torchvision.transforms.ToTensor()
+                    tensors.append(trans(img))
+                if i<(img.n_frames-1):
+                    img.seek(img.tell()+1)
 
-        if path.endswith('.mp3') or path.endswith('.wav'):
+        if path.endswith('.mp3') or path.endswith('.wav') or path.endswith('.ogg') or path.endswith('.flac'):
             waveform, sample_rate = torchaudio.load(path)
-            return [waveform]
+            tensors.append(waveform)
 
         if path.endswith('.npy') or path.endswith('.npz'):
             arrays = np.load(path)
             if type(arrays) == dict:
-                tensors = []
                 for array in arrays:
                     tensors.append(torch.from_numpy(arrays[array]))
-                return tensors
             else:
-                return [torch.from_numpy(arrays)]
+                tensors.append(torch.from_numpy(arrays))
+
+        return tensors
 
     def __len__(self):
         return len(self.samples)
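
The updated to_tensors now returns a list of tensors and walks every frame of multi-frame image formats via Pillow's n_frames/seek API. A minimal standalone sketch of that frame-iteration pattern outside GenericLoader (the helper name, the file name and the n_frames fallback are assumptions, not part of the commit):

import numpy as np
import torch
import torchvision.transforms
from PIL import Image

def frames_to_tensors(path: str):
    """Return one tensor per frame: uint8 HW for binary/gray/palette images, float CHW otherwise."""
    tensors = []
    img = Image.open(path)
    n_frames = getattr(img, 'n_frames', 1)  # single-frame formats may not expose n_frames
    for i in range(n_frames):
        if img.mode in ('1', 'L', 'P'):
            tensors.append(torch.from_numpy(np.array(img, dtype=np.uint8)))
        else:
            tensors.append(torchvision.transforms.ToTensor()(img))
        if i < n_frames - 1:
            img.seek(img.tell() + 1)  # advance to the next frame
    return tensors

# hypothetical usage: stack = frames_to_tensors('volume_stack.tif')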

torchstudio/modeltrain.py

Lines changed: 1 addition & 1 deletion
@@ -447,7 +447,7 @@ def send_results_back():
         if error_msg:
             print("Error exporting:", error_msg, file=sys.stderr)
         else:
-            error_msg, torchscript_model = safe_exec(torch.onnx.export,{'model':torchscript_model, 'args':input_tensors, 'f':tc.decode_strings(msg_data)[0], 'opset_version':12})
+            error_msg, torchscript_model = safe_exec(torch.onnx.export,{'model':torchscript_model, 'args':input_tensors, 'f':tc.decode_strings(msg_data)[0], 'input_names': eval(tc.decode_strings(msg_data)[1]), 'output_names': eval(tc.decode_strings(msg_data)[2]), 'dynamic_axes': eval(tc.decode_strings(msg_data)[3]), 'opset_version':17})
         if error_msg:
             print("Error exporting:", error_msg, file=sys.stderr)
         else:
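
The ONNX export path now decodes input names, output names and dynamic axes from the request and targets opset 17 instead of 12. Stripped of TorchStudio's safe_exec/message plumbing, the equivalent direct call looks roughly like this (the model, file name and axis names are placeholders):

import torch

model = torch.nn.Linear(16, 4)        # placeholder model
example_input = torch.randn(1, 16)    # example input used for tracing

torch.onnx.export(
    model,
    example_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},  # keep the batch dimension dynamic
    opset_version=17,
)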

torchstudio/pythoninstall.py

Lines changed: 13 additions & 6 deletions
@@ -1,3 +1,8 @@
+#otherwise conda install may fail
+del __file__
+__package__=None
+__spec__=None
+
 import sys
 import importlib
 import importlib.util
@@ -25,7 +30,7 @@ def init_patch(self, **kwargs):
 
 if not args.package:
     #https://edcarp.github.io/introduction-to-conda-for-data-scientists/03-using-packages-and-channels/index.html#alternative-syntax-for-installing-packages-from-specific-channels
-    conda_install=f"{args.channel}::pytorch {args.channel}::torchvision {args.channel}::torchaudio {args.channel}::torchtext"
+    conda_install=f"pytorch torchvision torchaudio torchtext"
     if (sys.platform.startswith('win') or sys.platform.startswith('linux')):
         if args.cuda:
             print("Checking the latest supported CUDA version...")
@@ -47,26 +52,28 @@ def init_patch(self, **kwargs):
             highest_cuda_string='.'.join([str(value) for value in highest_cuda_version])
             print("Using CUDA "+highest_cuda_string)
             print("")
-            conda_install+=f" {args.channel}::pytorch-cuda="+highest_cuda_string+" -c nvidia"
+            conda_install+=" pytorch-cuda="+highest_cuda_string+" -c "+args.channel+" -c nvidia"
         else:
-            conda_install+=f" {args.channel}::cpuonly"
+            conda_install+=" cpuonly -c "+args.channel
+    else:
+        conda_install+=" -c "+args.channel
     print(f"Downloading and installing {args.channel} packages...")
     print("")
-    conda_install+=" -k" #allow insecure ssl connections
     # https://stackoverflow.com/questions/41767340/using-conda-install-within-a-python-script
     (stdout_str, stderr_str, return_code_int) = Conda.run_command(Conda.Commands.INSTALL,conda_install.split(),use_exception_handler=True,stdout=sys.stdout,stderr=sys.stderr)
     if return_code_int!=0:
         exit(return_code_int)
     print("")
 
+    # onnx required for onnx export
     # datasets(+huggingface_hub) is required by hugging face hub
     # scipy required by torchvision: Caltech ImageNet SBD SVHN datasets and Inception v3 GoogLeNet models
     # pandas required by the dataset tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
     # matplotlib-base required by torchstudio renderers
     # python-graphviz required by torchstudio graph
     # paramiko required for ssh connections (+updated cffi required on intel mac)
     # pysoundfile required by torchaudio datasets: https://pytorch.org/audio/stable/backend.html#soundfile-backend
-    conda_install="datasets scipy pandas matplotlib-base python-graphviz paramiko pysoundfile"
+    conda_install="onnx datasets scipy pandas matplotlib-base python-graphviz paramiko pysoundfile"
     if sys.platform.startswith('darwin'):
         conda_install+=" cffi"

@@ -75,7 +82,7 @@ def init_patch(self, **kwargs):
 
     print("Downloading and installing conda-forge packages...")
     print("")
-    conda_install+=" -c conda-forge -k"
+    conda_install+=" -c conda-forge"
     (stdout_str, stderr_str, return_code_int) = Conda.run_command(Conda.Commands.INSTALL,conda_install.split(),use_exception_handler=True,stdout=sys.stdout,stderr=sys.stderr)
     if return_code_int!=0:
         exit(return_code_int)
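
The channel handling switches from the per-package channel::name prefix to trailing -c flags appended once to the whole install string, and the -k (insecure SSL) flag is dropped. A sketch mirroring the script's conda.cli.python_api call, with a hypothetical package list and channel (the import alias is an assumption):

import sys
import conda.cli.python_api as Conda

channel = "pytorch"                                   # hypothetical channel name
conda_install = "pytorch torchvision -c " + channel   # channels passed as trailing -c flags

(stdout_str, stderr_str, return_code_int) = Conda.run_command(
    Conda.Commands.INSTALL,
    conda_install.split(),           # same argument form as in pythoninstall.py
    use_exception_handler=True,
    stdout=sys.stdout,
    stderr=sys.stderr)
if return_code_int != 0:
    sys.exit(return_code_int)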

torchstudio/renderers/bitmap.py

Lines changed: 10 additions & 3 deletions
@@ -20,18 +20,19 @@ class Bitmap(Renderer):
         Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis'
     colors: List of colors for each channel for multi channels bitmaps (looped if necessary)
     rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise)
-    invert (bool): Invert vertical axis.
+    invert (bool): Invert vertical axis
+    normalize (bool): Normalize values
     """
-    def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False):
+    def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False, normalize=False):
         super().__init__()
         self.colormap=colormap
         self.colors=colors
         self.rotate=rotate
         self.invert=invert
+        self.normalize=normalize
 
     def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
         #check dimensions
-        print(str(tensor.dtype))
         if len(tensor.shape)!=3 and (len(tensor.shape)!=2 or 'int' not in str(tensor.dtype)):
             print("Bitmap renderer requires a 3D tensor or 2D tensor of ints, got a "+str(len(tensor.shape))+"D tensor.", file=sys.stderr)
             return None
@@ -50,6 +51,12 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
         if self.rotate>0:
             tensor=np.rot90(tensor, self.rotate, axes=(1, 2))
 
+        tensor=tensor.astype(np.float32)
+        if self.normalize:
+            max_value=np.amax(tensor)
+            if max_value>0:
+                tensor=tensor/max_value
+
         #apply brightness, gamma and conversion to uint8, then transform CHW to HWC
         tensor = np.multiply(np.clip(np.power(np.clip(tensor*scale[0],0,1),1/scale[3]),0,1),255).astype(np.uint8)
         tensor = tensor.transpose((1, 2, 0))
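
All three renderers gain the same pre-display step: cast to float32 and, when normalize is set, divide by the tensor's maximum so values land in [0, 1] before the brightness/gamma mapping. A standalone sketch of that step (the input array is made up):

import numpy as np

tensor = np.array([[[0, 50], [100, 200]]], dtype=np.int64)  # CHW tensor with an arbitrary value range

tensor = tensor.astype(np.float32)
normalize = True
if normalize:
    max_value = np.amax(tensor)
    if max_value > 0:              # guard against an all-zero tensor
        tensor = tensor / max_value

print(tensor)  # values now within [0, 1], ready for the gamma/brightness mapping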

torchstudio/renderers/spectrogram.py

Lines changed: 11 additions & 3 deletions
@@ -20,13 +20,15 @@ class Spectrogram(Renderer):
         Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis'
     colors: List of colors for each channel for multi channels spectrograms (looped if necessary)
     rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise)
+    normalize (bool): Normalize values
     """
-    def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False):
+    def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False, normalize=False):
         super().__init__()
         self.colormap=colormap
         self.colors=colors
         self.rotate=rotate
         self.invert=invert
+        self.normalize=normalize
 
     def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
         #check dimensions
@@ -35,8 +37,8 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
             return None
 
         if np.iscomplexobj(tensor)==False and tensor.shape[0]%2!=0:
-            print("Spectrogram renderer requires a complex tensor or a tensor with an even number of channels", file=sys.stderr)
-            return None
+            #add missing channel (needs pairs to be interpred as complex channels)
+            tensor=np.append(tensor, np.zeros((1,tensor.shape[1],tensor.shape[2])), axis=0)
 
         #convert complex spectrogram to amplitude spectrogram
         if np.iscomplexobj(tensor):
@@ -55,6 +57,12 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
         if self.rotate>0:
             tensor=np.rot90(tensor, self.rotate, axes=(1, 2))
 
+        tensor=tensor.astype(np.float32)
+        if self.normalize:
+            max_value=np.amax(tensor)
+            if max_value>0:
+                tensor=tensor/max_value
+
         #apply brightness, gamma and conversion to uint8, then transform CHW to HWC
         tensor = np.multiply(np.clip(np.power(tensor*scale[0],1/scale[3]),0,1),255).astype(np.uint8)
         tensor = tensor.transpose((1, 2, 0))
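
Rather than rejecting real-valued tensors with an odd channel count, the Spectrogram renderer now pads a zero channel so consecutive channel pairs can be read as complex values. A minimal sketch of that padding (the input shape is chosen arbitrarily):

import numpy as np

tensor = np.random.rand(3, 128, 64).astype(np.float32)  # 3 real channels: odd count, one channel short

if not np.iscomplexobj(tensor) and tensor.shape[0] % 2 != 0:
    # append a zero channel so channels can be paired into (real, imaginary) parts
    tensor = np.append(tensor, np.zeros((1, tensor.shape[1], tensor.shape[2])), axis=0)

print(tensor.shape)  # (4, 128, 64)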

torchstudio/renderers/volume.py

Lines changed: 9 additions & 1 deletion
@@ -21,13 +21,15 @@ class Volume(Renderer):
         Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis'
     colors: List of colors for each channel for multi channels volumes (looped if necessary)
     rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise)
+    normalize (bool): Normalize values
     """
-    def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False):
+    def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False, normalize=False):
         super().__init__()
         self.colormap=colormap
         self.colors=colors
         self.rotate=rotate
         self.invert=invert
+        self.normalize=normalize
 
     def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
         #check dimensions
@@ -52,6 +54,12 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
         if self.rotate>0:
             tensor=np.rot90(tensor, self.rotate, axes=(1, 2))
 
+        tensor=tensor.astype(np.float32)
+        if self.normalize:
+            max_value=np.amax(tensor)
+            if max_value>0:
+                tensor=tensor/max_value
+
         #apply luminosity and conversion to uint8, then transform CHW to HWC
         tensor = np.multiply(np.clip(np.power(np.clip(tensor*scale[0],0,1),1/scale[3]),0,1),255).astype(np.uint8)
         tensor = tensor.transpose((1, 2, 0))
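
Since Bitmap, Spectrogram and Volume all expose the new keyword the same way, enabling it is just a constructor argument; a hypothetical usage sketch (module paths assumed from the file layout above):

from torchstudio.renderers.bitmap import Bitmap
from torchstudio.renderers.spectrogram import Spectrogram
from torchstudio.renderers.volume import Volume

# normalize defaults to False in every renderer; set it per renderer as needed
bitmap = Bitmap(colormap='viridis', normalize=True)
spectrogram = Spectrogram(normalize=True)
volume = Volume(rotate=1, normalize=True)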
