-
Notifications
You must be signed in to change notification settings - Fork 0
/
_test_AM.py
67 lines (51 loc) · 2.1 KB
/
_test_AM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import scipy.io.wavfile as wav
from scipy.fftpack import fft
from utils import decode_ctc,compute_fbank
import os
# 0. Build the data object whose acoustic vocabulary is needed for decoding.
from utils import get_data, data_hparams

data_args = data_hparams()
data_args.data_type = 'train'
# Only the self_wav source flag stays enabled; the four remaining source
# flags are switched off, in the same order as before.
data_args.self_wav = True
for _source_flag in ('thchs30', 'aishell', 'prime', 'stcmd'):
    setattr(data_args, _source_flag, False)
train_data = get_data(data_args)
# 1. Acoustic model -----------------------------------
from Model_Speech import Am, am_hparams

# Silence TensorFlow's native log output (level '3'; level '2' was the
# earlier, now disabled, setting).
# NOTE(review): this is assigned after the imports above, which presumably
# pull in TensorFlow -- confirm the variable is still honoured at that point.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": '3'})

am_args = am_hparams()
# Size the model's vocabulary from the acoustic vocabulary of train_data.
am_args.vocab_size = len(train_data.am_vocab)
am = Am(am_args)

print('loading acoustic model...')
# Restore the trained weights from the checkpoint file; the path is relative
# to the current working directory.
am.ctc_model.load_weights('model_speech/model_self.h5')
# 2. Run the acoustic model on one recording and print the decoded text ----
# NOTE: matplotlib is not used by the active code below; the import is kept
# so the script's import-time behaviour is unchanged and the features can
# still be plotted ad hoc, e.g. plt.imshow(fbank.T, origin='lower').
import matplotlib.pyplot as plt

filepath = 'test_wav/5_.wav'

# compute_fbank receives the file path itself, so the recording does not
# have to be read separately here.
fbank = compute_fbank(filepath)

# Zero-pad the first axis up to the next multiple of 8 that is strictly
# greater than the current length (a length that is already a multiple of 8
# still gains 8 extra rows).  "//" and "*" share precedence and evaluate
# left to right, so this is ((n // 8) * 8) + 8.
# presumably the network shrinks the time axis by a factor of 8 -- TODO confirm
padded_frames = fbank.shape[0] // 8 * 8 + 8
pad_fbank = np.zeros((padded_frames, fbank.shape[1]))
pad_fbank[:fbank.shape[0], :] = fbank

# Add a leading and a trailing singleton axis:
# (frames, features) -> (1, frames, features, 1).
# The feature width now follows the computed features instead of being
# hard-coded to 200.
new_wav_data_lst = pad_fbank[np.newaxis, :, :, np.newaxis]

# steps=1: run a single prediction step; predict returns the predictions as
# a NumPy array.
result = am.model.predict(new_wav_data_lst, steps=1)

# Map the network output to vocabulary entries (original note: num2pny).
_, text = decode_ctc(result, train_data.am_vocab)
# Join the decoded entries into one space-separated string.
text = ' '.join(text)
print('文本结果:', text)
# Entry-point guard.  Intentionally a no-op: all of the script's work is done
# by the module-level statements above.
if __name__ == '__main__':
    pass