Skip to content

Commit 64e03b6

Browse files
committed
🌹 Add and for all config class.
1 parent f53ecd9 commit 64e03b6

File tree

9 files changed

+58
-10
lines changed

9 files changed

+58
-10
lines changed

tensorflow_tts/configs/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from tensorflow_tts.configs.base_config import BaseConfig
12
from tensorflow_tts.configs.fastspeech import FastSpeechConfig
23
from tensorflow_tts.configs.fastspeech2 import FastSpeech2Config
34
from tensorflow_tts.configs.melgan import (

tensorflow_tts/configs/base_config.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2020 TensorFlowTTS Team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Base Config for all config."""
16+
17+
import abc
18+
import yaml
19+
import os
20+
21+
from tensorflow_tts.utils.utils import CONFIG_FILE_NAME
22+
23+
24+
class BaseConfig(abc.ABC):
25+
def set_pretrained_config(self, config):
26+
self.config = config
27+
28+
def save_pretrained(self, saved_path):
29+
"""Save config to file"""
30+
with open(os.path.join(saved_path, CONFIG_FILE_NAME), "w") as file:
31+
yaml.dump(self.config, file)

tensorflow_tts/configs/fastspeech.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import collections
1818

19+
from tensorflow_tts.configs import BaseConfig
1920
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
2021
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
2122
from tensorflow_tts.processor.baker import BAKER_SYMBOLS as bk_symbols
@@ -44,12 +45,12 @@
4445
)
4546

4647

47-
class FastSpeechConfig(object):
48+
class FastSpeechConfig(BaseConfig):
4849
"""Initialize FastSpeech Config."""
4950

5051
def __init__(
5152
self,
52-
dataset='ljspeech',
53+
dataset="ljspeech",
5354
vocab_size=len(lj_symbols),
5455
n_speakers=1,
5556
encoder_hidden_size=384,

tensorflow_tts/configs/fastspeech2.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
"""FastSpeech2 Config object."""
1616

17+
1718
from tensorflow_tts.configs import FastSpeechConfig
1819

1920

tensorflow_tts/configs/hifigan.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
"""HifiGAN Config object."""
1616

1717

18-
class HifiGANGeneratorConfig(object):
18+
from tensorflow_tts.configs import BaseConfig
19+
20+
21+
class HifiGANGeneratorConfig(BaseConfig):
1922
"""Initialize HifiGAN Generator Config."""
2023

2124
def __init__(

tensorflow_tts/configs/melgan.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
"""MelGAN Config object."""
1616

1717

18-
class MelGANGeneratorConfig(object):
18+
from tensorflow_tts.configs import BaseConfig
19+
20+
21+
class MelGANGeneratorConfig(BaseConfig):
1922
"""Initialize MelGAN Generator Config."""
2023

2124
def __init__(

tensorflow_tts/configs/parallel_wavegan.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
"""ParallelWaveGAN Config object."""
1616

1717

18-
class ParallelWaveGANGeneratorConfig(object):
18+
from tensorflow_tts.configs import BaseConfig
19+
20+
21+
class ParallelWaveGANGeneratorConfig(BaseConfig):
1922
"""Initialize ParallelWaveGAN Generator Config."""
2023

2124
def __init__(

tensorflow_tts/configs/tacotron2.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,20 @@
1414
# limitations under the License.
1515
"""Tacotron-2 Config object."""
1616

17+
18+
from tensorflow_tts.configs import BaseConfig
1719
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
1820
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
1921
from tensorflow_tts.processor.baker import BAKER_SYMBOLS as bk_symbols
2022
from tensorflow_tts.processor.libritts import LIBRITTS_SYMBOLS as lbri_symbols
2123

2224

23-
class Tacotron2Config(object):
25+
class Tacotron2Config(BaseConfig):
2426
"""Initialize Tacotron-2 Config."""
2527

2628
def __init__(
2729
self,
28-
dataset='ljspeech',
30+
dataset="ljspeech",
2931
vocab_size=len(lj_symbols),
3032
embedding_hidden_size=512,
3133
initializer_range=0.02,
@@ -60,7 +62,7 @@ def __init__(
6062
self.vocab_size = vocab_size
6163
elif dataset == "kss":
6264
self.vocab_size = len(kss_symbols)
63-
elif dataset == 'baker':
65+
elif dataset == "baker":
6466
self.vocab_size = len(bk_symbols)
6567
elif dataset == "libritts":
6668
self.vocab_size = len(lbri_symbols)

tensorflow_tts/inference/auto_config.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
("melgan_generator", MelGANGeneratorConfig),
4242
("hifigan_generator", HifiGANGeneratorConfig),
4343
("tacotron2", Tacotron2Config),
44-
("parallel_wavegan_generator", ParallelWaveGANGeneratorConfig)
44+
("parallel_wavegan_generator", ParallelWaveGANGeneratorConfig),
4545
]
4646
)
4747

@@ -58,7 +58,9 @@ def from_pretrained(cls, pretrained_path, **kwargs):
5858
# load weights from hf hub
5959
if not os.path.isfile(pretrained_path):
6060
# retrieve correct hub url
61-
download_url = hf_hub_url(repo_id=pretrained_path, filename=CONFIG_FILE_NAME)
61+
download_url = hf_hub_url(
62+
repo_id=pretrained_path, filename=CONFIG_FILE_NAME
63+
)
6264

6365
pretrained_path = str(
6466
cached_download(
@@ -76,6 +78,7 @@ def from_pretrained(cls, pretrained_path, **kwargs):
7678
model_type = config["model_type"]
7779
config_class = CONFIG_MAPPING[model_type]
7880
config_class = config_class(**config[model_type + "_params"], **kwargs)
81+
config_class.set_pretrained_config(config)
7982
return config_class
8083
except Exception:
8184
raise ValueError(

0 commit comments

Comments
 (0)