Skip to content

Commit 1a5731a

Browse files
committed
🐽 Add save_pretrained for TFAutoModel.
1 parent 77f6ac5 commit 1a5731a

12 files changed

+70
-21
lines changed

tensorflow_tts/configs/base_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222

2323

2424
class BaseConfig(abc.ABC):
25-
def set_pretrained_config(self, config):
26-
self.config = config
25+
def set_config_params(self, config_params):
26+
self.config_params = config_params
2727

2828
def save_pretrained(self, saved_path):
2929
"""Save config to file"""
3030
os.makedirs(saved_path, exist_ok=True)
3131
with open(os.path.join(saved_path, CONFIG_FILE_NAME), "w") as file:
32-
yaml.dump(self.config, file)
32+
yaml.dump(self.config_params, file, Dumper=yaml.Dumper)

tensorflow_tts/inference/auto_config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ def from_pretrained(cls, pretrained_path, **kwargs):
7272
)
7373

7474
with open(pretrained_path) as f:
75-
config = yaml.load(f, Loader=yaml.SafeLoader)
75+
config = yaml.load(f, Loader=yaml.Loader)
7676

7777
try:
7878
model_type = config["model_type"]
7979
config_class = CONFIG_MAPPING[model_type]
8080
config_class = config_class(**config[model_type + "_params"], **kwargs)
81-
config_class.set_pretrained_config(config)
81+
config_class.set_config_params(config)
8282
return config_class
8383
except Exception:
8484
raise ValueError(

tensorflow_tts/inference/auto_model.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
import warnings
1919
import os
20+
import copy
2021

2122
from collections import OrderedDict
2223

@@ -67,9 +68,7 @@ def __init__(self):
6768
raise EnvironmentError("Cannot be instantiated using `__init__()`")
6869

6970
@classmethod
70-
def from_pretrained(cls, config=None, pretrained_path=None, **kwargs):
71-
is_build = kwargs.pop("is_build", True)
72-
71+
def from_pretrained(cls, pretrained_path=None, config=None, **kwargs):
7372
# load weights from hf hub
7473
if pretrained_path is not None:
7574
if not os.path.isfile(pretrained_path):
@@ -101,8 +100,8 @@ def from_pretrained(cls, config=None, pretrained_path=None, **kwargs):
101100
config
102101
):
103102
model = model_class(config=config, **kwargs)
104-
if is_build:
105-
model._build()
103+
model.set_config(config)
104+
model._build()
106105
if pretrained_path is not None and ".h5" in pretrained_path:
107106
try:
108107
model.load_weights(pretrained_path)

tensorflow_tts/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from tensorflow_tts.models.base_model import BaseModel
12
from tensorflow_tts.models.fastspeech import TFFastSpeech
23
from tensorflow_tts.models.fastspeech2 import TFFastSpeech2
34
from tensorflow_tts.models.melgan import (

tensorflow_tts/models/base_model.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2020 TensorFlowTTS Team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Base Model for all model."""
16+
17+
import tensorflow as tf
18+
import yaml
19+
import os
20+
import numpy as np
21+
22+
from tensorflow_tts.utils.utils import MODEL_FILE_NAME, CONFIG_FILE_NAME
23+
24+
25+
class BaseModel(tf.keras.Model):
26+
def set_config(self, config):
27+
self.config = config
28+
29+
def save_pretrained(self, saved_path):
30+
"""Save config and weights to file"""
31+
os.makedirs(saved_path, exist_ok=True)
32+
self.config.save_pretrained(saved_path)
33+
self.save_weights(os.path.join(saved_path, MODEL_FILE_NAME))

tensorflow_tts/models/fastspeech.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import numpy as np
1818
import tensorflow as tf
1919

20+
from tensorflow_tts.models import BaseModel
21+
2022

2123
def get_initializer(initializer_range=0.02):
2224
"""Creates a `tf.initializers.truncated_normal` with the given range.
@@ -746,7 +748,7 @@ def body(
746748
return outputs, encoder_masks
747749

748750

749-
class TFFastSpeech(tf.keras.Model):
751+
class TFFastSpeech(BaseModel):
750752
"""TF Fastspeech module."""
751753

752754
def __init__(self, config, **kwargs):

tensorflow_tts/models/hifigan.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow_tts.utils import GroupConv1D
2424
from tensorflow_tts.utils import WeightNormalization
2525

26+
from tensorflow_tts.models import BaseModel
2627
from tensorflow_tts.models import TFMelGANGenerator
2728

2829

@@ -133,7 +134,7 @@ def call(self, x, training=False):
133134
return xs / len(self.list_resblock)
134135

135136

136-
class TFHifiGANGenerator(tf.keras.Model):
137+
class TFHifiGANGenerator(BaseModel):
137138
def __init__(self, config, **kwargs):
138139
super().__init__(**kwargs)
139140
# check hyper parameter is valid or not
@@ -338,7 +339,7 @@ def _apply_weightnorm(self, list_layers):
338339
pass
339340

340341

341-
class TFHifiGANMultiPeriodDiscriminator(tf.keras.Model):
342+
class TFHifiGANMultiPeriodDiscriminator(BaseModel):
342343
"""Tensorflow Hifigan Multi Period discriminator module."""
343344

344345
def __init__(self, config, **kwargs):

tensorflow_tts/models/mb_melgan.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tensorflow as tf
2222
from scipy.signal import kaiser
2323

24+
from tensorflow_tts.models import BaseModel
2425
from tensorflow_tts.models import TFMelGANGenerator
2526

2627

tensorflow_tts/models/melgan.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import tensorflow as tf
1919

20+
from tensorflow_tts.models import BaseModel
2021
from tensorflow_tts.utils import GroupConv1D, WeightNormalization
2122

2223

@@ -186,7 +187,7 @@ def _apply_weightnorm(self, list_layers):
186187
pass
187188

188189

189-
class TFMelGANGenerator(tf.keras.Model):
190+
class TFMelGANGenerator(BaseModel):
190191
"""Tensorflow MelGAN generator module."""
191192

192193
def __init__(self, config, **kwargs):
@@ -450,7 +451,7 @@ def _apply_weightnorm(self, list_layers):
450451
pass
451452

452453

453-
class TFMelGANMultiScaleDiscriminator(tf.keras.Model):
454+
class TFMelGANMultiScaleDiscriminator(BaseModel):
454455
"""MelGAN multi-scale discriminator module."""
455456

456457
def __init__(self, config, **kwargs):

tensorflow_tts/models/parallel_wavegan.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import tensorflow as tf
1919

20+
from tensorflow_tts.models import BaseModel
21+
2022

2123
def get_initializer(initializer_seed=42):
2224
"""Creates a `tf.initializers.he_normal` with the given seed.
@@ -345,7 +347,7 @@ def call(self, c):
345347
return self.upsample(c_)
346348

347349

348-
class TFParallelWaveGANGenerator(tf.keras.Model):
350+
class TFParallelWaveGANGenerator(BaseModel):
349351
"""Parallel WaveGAN Generator module."""
350352

351353
def __init__(self, config, **kwargs):
@@ -491,7 +493,7 @@ def inference(self, mels):
491493
return x
492494

493495

494-
class TFParallelWaveGANDiscriminator(tf.keras.Model):
496+
class TFParallelWaveGANDiscriminator(BaseModel):
495497
"""Parallel WaveGAN Discriminator module."""
496498

497499
def __init__(self, config, **kwargs):

tensorflow_tts/models/tacotron2.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
from tensorflow_tts.utils import dynamic_decode
2929

30+
from tensorflow_tts.models import BaseModel
31+
3032

3133
def get_initializer(initializer_range=0.02):
3234
"""Creates a `tf.initializers.truncated_normal` with the given range.
@@ -737,7 +739,7 @@ def step(self, time, inputs, state, training=False):
737739
return (outputs, next_state, next_inputs, finished)
738740

739741

740-
class TFTacotron2(tf.keras.Model):
742+
class TFTacotron2(BaseModel):
741743
"""Tensorflow tacotron-2 model."""
742744

743745
def __init__(self, config, **kwargs):
@@ -760,10 +762,10 @@ def __init__(self, config, **kwargs):
760762
units=config.n_mels, name="residual_projection"
761763
)
762764

763-
self.config = config
764765
self.use_window_mask = False
765766
self.maximum_iterations = 4000
766767
self.enable_tflite_convertible = enable_tflite_convertible
768+
self.config = config
767769

768770
def setup_window(self, win_front, win_back):
769771
"""Call only for inference."""

test/test_auto.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
)
4343
def test_auto_processor(mapper_path):
4444
processor = AutoProcessor.from_pretrained(pretrained_path=mapper_path)
45+
processor.save_pretrained("./test_saved")
46+
processor = AutoProcessor.from_pretrained("./test_saved/processor.json")
4547

4648

4749
@pytest.mark.parametrize(
@@ -65,7 +67,12 @@ def test_auto_processor(mapper_path):
6567
)
6668
def test_auto_model(config_path):
6769
config = AutoConfig.from_pretrained(pretrained_path=config_path)
68-
model = TFAutoModel.from_pretrained(config=config, pretrained_path=None)
70+
model = TFAutoModel.from_pretrained(pretrained_path=None, config=config)
6971

7072
# test save_pretrained
71-
config.save_pretrained("./")
73+
config.save_pretrained("./test_saved")
74+
model.save_pretrained("./test_saved")
75+
76+
# test from_pretrained
77+
config = AutoConfig.from_pretrained("./test_saved/config.yml")
78+
model = TFAutoModel.from_pretrained("./test_saved/model.h5", config=config)

0 commit comments

Comments
 (0)