Skip to content

Commit 9d486cd

Browse files
committed
Initial commit
0 parents  commit 9d486cd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+3310
-0
lines changed

.gitignore

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
env/
12+
build/
13+
develop-eggs/
14+
dist/
15+
downloads/
16+
eggs/
17+
.eggs/
18+
lib/
19+
lib64/
20+
parts/
21+
sdist/
22+
var/
23+
wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
49+
# Translations
50+
*.mo
51+
*.pot
52+
53+
# Django stuff:
54+
*.log
55+
local_settings.py
56+
57+
# Flask stuff:
58+
instance/
59+
.webassets-cache
60+
61+
# Scrapy stuff:
62+
.scrapy
63+
64+
# Sphinx documentation
65+
docs/_build/
66+
67+
# PyBuilder
68+
target/
69+
70+
# Jupyter Notebook
71+
.ipynb_checkpoints
72+
73+
# pyenv
74+
.python-version
75+
76+
# celery beat schedule file
77+
celerybeat-schedule
78+
79+
# SageMath parsed files
80+
*.sage.py
81+
82+
# dotenv
83+
.env
84+
85+
# virtualenv
86+
.venv
87+
venv/
88+
ENV/
89+
90+
# Spyder project settings
91+
.spyderproject
92+
.spyproject
93+
94+
# Rope project settings
95+
.ropeproject
96+
97+
# mkdocs documentation
98+
/site
99+
100+
# mypy
101+
.mypy_cache/
102+
103+
# input data, saved log, checkpoints
104+
data/
105+
input/
106+
saved/
107+
108+
# editor, os cache directory
109+
.vscode/
110+
.idea/
111+
__MACOSX/

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2021 Daniil Ivanov
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# ASR project barebones
2+
3+
## Installation guide
4+
5+
< Write your installation guide here >
6+
7+
```shell
8+
pip install -r ./requirements.txt
9+
```
10+
11+
## Implementation guide
12+
13+
1) Search project for `raise NotImplementedError` and implement missing functionality
14+
2) Make sure all tests work without errors
15+
```shell
16+
python -m unittest discover hw_asr/tests
17+
```
18+
19+
3) Make sure `test.py` works fine and works as expected.
20+
You should create files `default_test_config.json` and your installation guide should download your model
21+
checpoint and configs in `default_test_model/checkpoint.pth` and `default_test_model/config.json`.
22+
```shell
23+
python test.py \
24+
-c default_test_config.json \
25+
-r default_test_model/checkpoint.pth \
26+
-t test_data \
27+
-o test_result.json
28+
```
29+
4) Use `train.py` for training
30+
31+
## Credits
32+
this repository is based on a heavily modified fork of [pytorch-template](https://github.com/victoresque/pytorch-template) repository.

hw_asr/__init__.py

Whitespace-only changes.

hw_asr/augmentations/__init__.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from collections import Callable
2+
from typing import List
3+
4+
import hw_asr.augmentations.spectrogram_augmentations
5+
import hw_asr.augmentations.wave_augmentations
6+
from hw_asr.augmentations.sequential import SequentialAugmentation
7+
from hw_asr.utils.parse_config import ConfigParser
8+
9+
10+
def from_configs(configs: ConfigParser):
11+
wave_augs = []
12+
if "augmentations" in configs.config and "wave" in configs.config["augmentations"]:
13+
for aug_dict in configs.config["augmentations"]["wave"]:
14+
wave_augs.append(
15+
configs.init_obj(aug_dict, hw_asr.augmentations.wave_augmentations)
16+
)
17+
18+
spec_augs = []
19+
if "augmentations" in configs.config and "spectrogram" in configs.config["augmentations"]:
20+
for aug_dict in configs.config["augmentations"]["spectrogram"]:
21+
spec_augs.append(
22+
configs.init_obj(aug_dict, hw_asr.augmentations.spectrogram_augmentations)
23+
)
24+
return _to_function(wave_augs), _to_function(spec_augs)
25+
26+
27+
def _to_function(augs_list: List[Callable]):
28+
if len(augs_list) == 0:
29+
return None
30+
elif len(augs_list) == 1:
31+
return augs_list[0]
32+
else:
33+
return SequentialAugmentation(augs_list)

hw_asr/augmentations/base.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from torch import Tensor
2+
3+
4+
class AugmentationBase:
5+
def __call__(self, data: Tensor) -> Tensor:
6+
raise NotImplementedError

hw_asr/augmentations/random_apply.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import random
2+
from typing import Callable
3+
4+
from torch import Tensor
5+
6+
7+
class RandomApply:
8+
def __init__(self, augmentation: Callable, p: float):
9+
assert 0 <= p <= 1
10+
self.augmentation = augmentation
11+
self.p = p
12+
13+
def __call__(self, data: Tensor) -> Tensor:
14+
if random.random() < self.p:
15+
return self.augmentation(data)
16+
else:
17+
return data

hw_asr/augmentations/sequential.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import List, Callable
2+
3+
from torch import Tensor
4+
5+
from hw_asr.augmentations.base import AugmentationBase
6+
7+
8+
class SequentialAugmentation(AugmentationBase):
9+
def __init__(self, augmentation_list: List[Callable]):
10+
self.augmentation_list = augmentation_list
11+
12+
def __call__(self, data: Tensor) -> Tensor:
13+
x = data
14+
for augmentation in self.augmentation_list:
15+
x = augmentation(data)
16+
return x
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch_audiomentations
2+
from torch import Tensor
3+
4+
from hw_asr.augmentations.base import AugmentationBase
5+
6+
7+
class Gain(AugmentationBase):
8+
def __init__(self, *args, **kwargs):
9+
self._aug = torch_audiomentations.Gain(*args, **kwargs)
10+
11+
def __call__(self, data: Tensor):
12+
x = data.unsqueeze(1)
13+
return self._aug(x).squeeze(1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from hw_asr.augmentations.wave_augmentations.Gain import Gain
2+
3+
__all__ = [
4+
"Gain"
5+
]

hw_asr/base/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base_model import *
2+
from .base_trainer import *

0 commit comments

Comments
 (0)