Commit 5b0ac742 authored by 翟艳秋(20软)

1. [add] Add an icon to the window.

2. [modified] Add a recommended narration word count to the ASR-based output, and normalize start/end times to two decimal places.
3. [modified] Adjust the storage location of temporary audio files in the audio synthesis module.
4. [modified] Add automatic line wrapping to the output table.
parent 945a2b39
Subproject commit 081f7807a2ce0e12b98e6f0a0da0e650133f2d9e
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# DeepSpeech2 Speech Recognition
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)
![GitHub Repo stars](https://img.shields.io/github/stars/yeyupiaoling/PaddlePaddle-DeepSpeech?style=social)
This project is developed on top of PaddlePaddle's [DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech) project, with substantial modifications to make it easier to train on custom Chinese datasets and more convenient to test and use. DeepSpeech2 is an end-to-end automatic speech recognition (ASR) engine implemented with PaddlePaddle; its paper is [Baidu's Deep Speech 2 paper](http://proceedings.mlr.press/v48/amodei16.pdf). The project also supports various data augmentation methods to suit different use cases. Training and prediction are supported on Windows and Linux, and inference is supported on boards such as the Nvidia Jetson. This branch is the new version; to use the old version, see the [release/1.0 branch](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech/tree/release/1.0).
Environment used by this project:
- Python 3.7
- PaddlePaddle 2.2.0
- Windows or Ubuntu
## Changelog
- 2021.11.26: Fixed a beam search decoding bug.
- 2021.11.09: Added a script for preparing the WenetSpeech dataset.
- 2021.09.05: Added GUI-based recognition deployment.
- 2021.09.04: Released pre-trained models for three public datasets.
- 2021.08.30: Added conversion of Chinese numerals to Arabic numerals; see the [inference docs](docs/infer.md).
- 2021.08.29: Completed the training and inference code and the related documentation.
- 2021.08.07: Added exporting of inference models and inference with the exported model. Implemented long-audio recognition with the webrtcvad tool.
- 2021.08.06: Migrated most of the project code to the new PaddlePaddle 2.0+ APIs.
## Model Downloads
| Dataset | Conv layers | RNN layers | RNN size | Test CER | Download |
| :---: | :---: | :---: | :---: | :---: | :---: |
| aishell (179 hours) | 2 | 3 | 1024 | 0.084532 | [Download](https://download.csdn.net/download/qq_33200967/21773253) |
| free_st_chinese_mandarin_corpus (109 hours) | 2 | 3 | 1024 | 0.170260 | [Download](https://download.csdn.net/download/qq_33200967/21866900) |
| thchs_30 (34 hours) | 2 | 3 | 1024 | 0.026838 | [Download](https://download.csdn.net/download/qq_33200967/21774247) |
| Very large dataset (1600+ hours of real data + 1300+ hours of synthesized data) | 2 | 3 | 1024 | in training | [in training]() |
**Note:** The downloads above are training checkpoints. To use them for prediction you still need to [export the model](docs/export_model.md); the decoding method used is beam search.
>Questions are welcome in the [issues](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech/issues)
## Documentation
- [Quick installation](docs/install.md)
- [Data preparation](docs/dataset.md)
- [WenetSpeech dataset](docs/wenetspeech.md)
- [Synthesizing speech data](docs/generate_audio.md)
- [Data augmentation](docs/augment.md)
- [Training a model](docs/train.md)
- [Beam search decoding](docs/beam_search.md)
- [Running evaluation](docs/eval.md)
- [Exporting a model](docs/export_model.md)
- Prediction
  - [Local model](docs/infer.md)
  - [Long-audio model](docs/infer.md)
  - [Web deployment](docs/infer.md)
- [Nvidia Jetson deployment](docs/nvidia-jetson.md)
## Quick Prediction
- Download one of the models provided above or train your own, then [export the model](docs/export_model.md). Use `infer_path.py` to transcribe audio, passing the path of the audio to recognize via the `--wav_path` argument; see the [deployment docs](docs/infer.md) for details.
```shell script
python infer_path.py --wav_path=./dataset/test.wav
```
Output:
```
----------- Configuration Arguments -----------
alpha: 1.2
beam_size: 10
beta: 0.35
cutoff_prob: 1.0
cutoff_top_n: 40
decoding_method: ctc_greedy
enable_mkldnn: False
is_long_audio: False
lang_model_path: ./lm/zh_giga.no_cna_cmn.prune01244.klm
mean_std_path: ./dataset/mean_std.npz
model_dir: ./models/infer/
to_an: True
use_gpu: True
use_tensorrt: False
vocab_path: ./dataset/zh_vocab.txt
wav_path: ./dataset/test.wav
------------------------------------------------
Time consumed: 132, recognition result: 近几年不但我用书给女儿儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书, score: 94
```
- Long-audio prediction
```shell script
python infer_path.py --wav_path=./dataset/test_vad.wav --is_long_audio=True
```
- Web deployment
![Recording test page](docs/images/infer_server.jpg)
- GUI deployment
![GUI](docs/images/infer_gui.jpg)
## Related Projects
- Voiceprint recognition with PaddlePaddle: [VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)
- Speech recognition with the PaddlePaddle dynamic graph: [PPASR](https://github.com/yeyupiaoling/PPASR)
- Speech recognition with Pytorch: [MASR](https://github.com/yeyupiaoling/MASR)
[
{
"type": "noise",
"aug_type": "audio",
"params": {
"min_snr_dB": 10,
"max_snr_dB": 50,
"noise_manifest_path": "dataset/manifest.noise"
},
"prob": 0.5
},
{
"type": "speed",
"aug_type": "audio",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 1.0
},
{
"type": "shift",
"aug_type": "audio",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "volume",
"aug_type": "audio",
"params": {
"min_gain_dBFS": -15,
"max_gain_dBFS": 15
},
"prob": 1.0
},
{
"type": "specaug",
"aug_type": "feature",
"params": {
"W": 0,
"warp_mode": "PIL",
"F": 10,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}
]
import argparse
import functools
import json
import os
import wave
from collections import Counter
from zhconv import convert
import numpy as np
import soundfile
from tqdm import tqdm
from data_utils.normalizer import FeatureNormalizer
from utils.utility import add_arguments, print_arguments, read_manifest, change_rate
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('annotation_path', str, 'dataset/annotation/', 'Path of the annotation files; if annotation_path contains test.txt, all of its entries are used as test data')
add_arg('manifest_prefix', str, 'dataset/', 'Prefix for the data manifests, which list audio paths and annotations')
add_arg('is_change_frame_rate', bool, True, 'Whether to resample all audio to 16000Hz; this takes a long time')
add_arg('max_test_manifest', int, 10000, 'Maximum number of test samples')
add_arg('count_threshold', int, 2, 'Cut-off threshold for character counts; 0 means no limit')
add_arg('vocab_path', str, 'dataset/zh_vocab.txt', 'Path of the generated vocabulary file')
add_arg('num_workers', int, 8, 'Number of threads for reading data')
add_arg('manifest_paths', str, 'dataset/manifest.train', 'Path of the data manifest')
add_arg('num_samples', int, 1000000, 'Number of audio samples used to compute the mean and std; -1 uses all data')
add_arg('output_path', str, './dataset/mean_std.npz', 'Path of the numpy file holding the mean and std, with suffix (.npz)')
args = parser.parse_args()
# Create the data manifests
def create_manifest(annotation_path, manifest_path_prefix):
data_list = []
test_list = []
durations_all = []
duration_0_10 = 0
duration_10_20 = 0
duration_20 = 0
# Iterate over all annotation files
for annotation_text in os.listdir(annotation_path):
durations = []
print('Creating the manifest for %s, please wait ...' % annotation_text)
annotation_text_path = os.path.join(annotation_path, annotation_text)
# Read the annotation file
with open(annotation_text_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
for line in tqdm(lines):
audio_path = line.split('\t')[0]
try:
# Filter out invalid characters
text = is_ustr(line.split('\t')[1].replace('\n', '').replace('\r', ''))
# Make sure all text is simplified Chinese
text = convert(text, 'zh-cn')
# Resample the audio and save it back
if args.is_change_frame_rate:
change_rate(audio_path)
# Get the audio duration
f_wave = wave.open(audio_path, "rb")
duration = f_wave.getnframes() / f_wave.getframerate()
if duration <= 10:
duration_0_10 += 1
elif 10 < duration <= 20:
duration_10_20 += 1
else:
duration_20 += 1
durations.append(duration)
d = json.dumps(
{
'audio_filepath': audio_path.replace('\\', '/'),
'duration': duration,
'text': text
},
ensure_ascii=False)
if annotation_text == 'test.txt':
test_list.append(d)
else:
data_list.append(d)
except Exception as e:
print(e)
continue
durations_all.append(sum(durations))
print("%s contains a total of [%d] hours of data!" % (annotation_text, int(sum(durations) / 3600)))
print("Clips of 0-10s: %d, 10-20s: %d, over 20s: %d" % (duration_0_10, duration_10_20, duration_20))
# Write the audio paths, durations and labels to the manifests
f_train = open(os.path.join(manifest_path_prefix, 'manifest.train'), 'w', encoding='utf-8')
f_test = open(os.path.join(manifest_path_prefix, 'manifest.test'), 'w', encoding='utf-8')
for line in test_list:
f_test.write(line + '\n')
interval = 500
if len(data_list) / 500 > args.max_test_manifest:
interval = len(data_list) // args.max_test_manifest
for i, line in enumerate(data_list):
if i % interval == 0 and i != 0:
if len(test_list) == 0:
f_test.write(line + '\n')
else:
f_train.write(line + '\n')
else:
f_train.write(line + '\n')
f_train.close()
f_test.close()
print("创建数量列表完成,全部数据一共[%d]小时!" % int(sum(durations_all) / 3600))
# Strip non-CJK characters
def is_ustr(in_str):
out_str = ''
for i in range(len(in_str)):
if is_uchar(in_str[i]):
out_str = out_str + in_str[i]
else:
out_str = out_str + ' '
return ''.join(out_str.split())
# Return True only for CJK (Chinese) characters
def is_uchar(uchar):
if u'\u4e00' <= uchar <= u'\u9fa5':
return True
if u'\u0030' <= uchar <= u'\u0039':
return False
if (u'\u0041' <= uchar <= u'\u005a') or (u'\u0061' <= uchar <= u'\u007a'):
return False
if uchar in ('-', ',', '.', '>', '?'):
return False
return False
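# Quick sanity check of the two helpers above (illustrative value, added for
# documentation): only CJK characters survive; digits, Latin letters,
# whitespace and punctuation are dropped.
assert is_ustr('你好, ASR 2021!') == '你好'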
# Create the noise data manifest
def create_noise(path='dataset/audio/noise', min_duration=30):
if not os.path.exists(path):
print('No noise audio files found, skipped!')
return
json_lines = []
print('Creating the noise manifest, path: %s, please wait ...' % path)
for file in tqdm(os.listdir(path)):
audio_path = os.path.join(path, file)
try:
# Noise clips can be labeled with an empty transcript
text = ""
# Resample the audio and save it back
if args.is_change_frame_rate:
change_rate(audio_path)
f_wave = wave.open(audio_path, "rb")
duration = f_wave.getnframes() / f_wave.getframerate()
# Tile the audio until it reaches the minimum duration
if duration < min_duration:
wav = soundfile.read(audio_path)[0]
data = wav
for i in range(int(min_duration / duration) + 1):
data = np.hstack([data, wav])
soundfile.write(audio_path, data, samplerate=16000)
f_wave = wave.open(audio_path, "rb")
duration = f_wave.getnframes() / f_wave.getframerate()
json_lines.append(
json.dumps(
{
'audio_filepath': audio_path.replace('\\', '/'),
'duration': duration,
'text': text
},
ensure_ascii=False))
except Exception as e:
continue
with open(os.path.join(args.manifest_prefix, 'manifest.noise'), 'w', encoding='utf-8') as f_noise:
for json_line in json_lines:
f_noise.write(json_line + '\n')
# Collect all characters
def count_manifest(counter, manifest_path):
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
for char in line_json['text']:
counter.update(char)
# Compute the dataset mean and std
def compute_mean_std(manifest_path, num_samples, output_path):
# Randomly sample the specified number of utterances to compute the normalization stats
normalizer = FeatureNormalizer(mean_std_filepath=None,
manifest_path=manifest_path,
num_samples=num_samples,
num_workers=args.num_workers)
# Save the computed results to a file
normalizer.write_to_file(output_path)
print('The computed mean and std have been saved to %s!' % output_path)
def main():
print_arguments(args)
print('Start creating the data manifests...')
create_manifest(annotation_path=args.annotation_path,
manifest_path_prefix=args.manifest_prefix)
print('='*70)
print('Start creating the noise manifest...')
create_noise(path='dataset/audio/noise')
print('='*70)
print('Start building the vocabulary...')
counter = Counter()
# Collect the label characters from all manifests
count_manifest(counter, args.manifest_paths)
# Assign an ID to every character
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
with open(args.vocab_path, 'w', encoding='utf-8') as fout:
fout.write('<blank>\t-1\n')
for char, count in count_sorted:
# The list is sorted by count, so stop once the count drops below the threshold; all remaining characters are ignored
if count < args.count_threshold: break
fout.write('%s\t%d\n' % (char, count))
print('The vocabulary has been generated and saved to: %s' % args.vocab_path)
print('='*70)
print('Start sampling %s utterances to compute the mean and std...' % args.num_samples)
compute_mean_std(args.manifest_paths, args.num_samples, args.output_path)
print('='*70)
if __name__ == '__main__':
main()
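# Usage sketch (assumed script filename and invocation; every flag is defined by
# the add_arg calls at the top of this script and falls back to its default):
#   python create_data.py --annotation_path=dataset/annotation/
# This produces dataset/manifest.train, dataset/manifest.test,
# dataset/manifest.noise, dataset/zh_vocab.txt and dataset/mean_std.npz.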
"""Contains the audio segment class."""
import copy
import io
import random
import re
import struct
import numpy as np
import resampy
import soundfile
from scipy import signal
class AudioSegment(object):
"""Monaural audio segment abstraction.
:param samples: Audio samples [num_samples x num_channels].
:type samples: ndarray.float32
:param sample_rate: Audio sample rate.
:type sample_rate: int
:raises TypeError: If the sample data type is not float or int.
"""
def __init__(self, samples, sample_rate):
"""Create audio segment from samples.
Samples are converted to float32 internally, with ints scaled to [-1, 1].
"""
self._samples = self._convert_samples_to_float32(samples)
self._sample_rate = sample_rate
if self._samples.ndim >= 2:
self._samples = np.mean(self._samples, 1)
def __eq__(self, other):
"""返回两个对象是否相等"""
if type(other) is not type(self):
return False
if self._sample_rate != other._sample_rate:
return False
if self._samples.shape != other._samples.shape:
return False
if np.any(self.samples != other._samples):
return False
return True
def __ne__(self, other):
"""返回两个对象是否不相等"""
return not self.__eq__(other)
def __str__(self):
"""返回该音频的信息"""
return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
"rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
self.duration, self.rms_db))
@classmethod
def from_file(cls, file):
"""从音频文件创建音频段
:param filepath: 文件路径或文件对象
:type filepath: str|file
:return: 音频片段实例
:rtype: AudioSegment
"""
if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
return cls.from_sequence_file(file)
else:
samples, sample_rate = soundfile.read(file, dtype='float32')
return cls(samples, sample_rate)
@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""只加载一小段音频,而不需要将整个文件加载到内存中,这是非常浪费的。
:param file: 输入音频文件路径或文件对象
:type file: str|file
:param start: 开始时间,单位为秒。如果start是负的,则它从末尾开始计算。如果没有提供,这个函数将从最开始读取。
:type start: float
:param end: 结束时间,单位为秒。如果end是负的,则它从末尾开始计算。如果没有提供,默认的行为是读取到文件的末尾。
:type end: float
:return: AudioSegment输入音频文件的指定片的实例。
:rtype: AudioSegment
:raise ValueError: 如开始或结束的设定不正确,例如时间不允许。
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = duration if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("Slice start position (%f s) out of bounds" % start)
if end < 0.0:
raise ValueError("Slice end position (%f s) out of bounds" % end)
if start > end:
raise ValueError("Slice start position (%f s) is later than the slice end position (%f s)" % (start, end))
if end > duration:
raise ValueError("Slice end position (%f s) out of bounds (> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)
@classmethod
def from_sequence_file(cls, filepath):
"""从序列文件创建音频段。序列文件是一个二进制文件,
包含多个音频文件的集合,头部中的几个头字节指示每个音频字节数据块的偏移量
The format is:
4 bytes (int, version),
4 bytes (int, num of utterance),
4 bytes (int, bytes per header),
[bytes_per_header*(num_utterance+1)] bytes (offsets for each audio),
audio_bytes_data_of_1st_utterance,
audio_bytes_data_of_2nd_utterance,
......
Sequence file name must end with ".seqbin". And the filename of the 5th
utterance's audio file in sequence file "xxx.seqbin" must be
"xxx.seqbin_5", with "5" indicating the utterance index within this
sequence file (starting from 1).
:param filepath: Filepath of sequence file.
:type filepath: str
:return: Audio segment instance.
:rtype: AudioSegment
"""
# parse filepath
matches = re.match(r"(.+\.seqbin)_(\d+)", filepath)
if matches is None:
raise IOError("File type of %s is not supported" % filepath)
filename = matches.group(1)
fileno = int(matches.group(2))
# read headers
f = io.open(filename, mode='rb')
version = f.read(4)
num_utterances = struct.unpack("i", f.read(4))[0]
bytes_per_header = struct.unpack("i", f.read(4))[0]
header_bytes = f.read(bytes_per_header * (num_utterances + 1))
header = [
struct.unpack("i", header_bytes[bytes_per_header * i:
bytes_per_header * (i + 1)])[0]
for i in range(num_utterances + 1)
]
# read audio bytes
f.seek(header[fileno - 1])
audio_bytes = f.read(header[fileno] - header[fileno - 1])
f.close()
# create audio segment
try:
return cls.from_bytes(audio_bytes)
except Exception as e:
samples = np.frombuffer(audio_bytes, dtype='int16')
return cls(samples=samples, sample_rate=8000)
@classmethod
def from_bytes(cls, bytes):
"""从包含音频样本的字节字符串创建音频段
:param bytes: Byte string containing audio samples.
:type bytes: str
:return: Audio segment instance.
:rtype: AudioSegment
"""
samples, sample_rate = soundfile.read(io.BytesIO(bytes), dtype='float32')
return cls(samples, sample_rate)
@classmethod
def concatenate(cls, *segments):
"""将任意数量的音频片段连接在一起
:param *segments: Input audio segments to be concatenated.
:type *segments: tuple of AudioSegment
:return: Audio segment instance as concatenating results.
:rtype: AudioSegment
:raises ValueError: If the number of segments is zero, or if the
sample_rate of any segments does not match.
:raises TypeError: If any segment is not AudioSegment instance.
"""
# Perform basic sanity-checks.
if len(segments) == 0:
raise ValueError("没有音频片段被给予连接")
sample_rate = segments[0]._sample_rate
for seg in segments:
if sample_rate != seg._sample_rate:
raise ValueError("能用不同的采样率连接片段")
if type(seg) is not cls:
raise TypeError("只有相同类型的音频片段可以连接")
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate)
@classmethod
def make_silence(cls, duration, sample_rate):
"""创建给定持续时间和采样率的静音音频段
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)
def to_wav_file(self, filepath, dtype='float32'):
"""保存音频段到磁盘为wav文件
:param filepath: WAV filepath or file object to save the
audio segment.
:type filepath: str|file
:param dtype: Subtype for audio file. Options: 'int16', 'int32',
'float32', 'float64'. Default is 'float32'.
:type dtype: str
:raises TypeError: If dtype is not supported.
"""
samples = self._convert_samples_from_float32(self._samples, dtype)
subtype_map = {
'int16': 'PCM_16',
'int32': 'PCM_32',
'float32': 'FLOAT',
'float64': 'DOUBLE'
}
soundfile.write(
filepath,
samples,
self._sample_rate,
format='WAV',
subtype=subtype_map[dtype])
def superimpose(self, other):
"""将另一个段的样本添加到这个段的样本中(以样本方式添加,而不是段连接)。
Note that this is an in-place transformation.
:param other: Segment containing samples to be added in.
:type other: AudioSegments
:raise TypeError: If type of two segments don't match.
:raise ValueError: If the sample rates of the two segments are not
equal, or if the lengths of segments don't match.
"""
if not isinstance(other, type(self)):
raise TypeError("Cannot add segments of different types: %s and %s" % (type(self), type(other)))
if self._sample_rate != other._sample_rate:
raise ValueError("采样率必须匹配才能添加片段")
if len(self._samples) != len(other._samples):
raise ValueError("段长度必须匹配才能添加段")
self._samples += other._samples
def to_bytes(self, dtype='float32'):
"""创建包含音频内容的字节字符串
:param dtype: Data type for export samples. Options: 'int16', 'int32',
'float32', 'float64'. Default is 'float32'.
:type dtype: str
:return: Byte string containing audio content.
:rtype: str
"""
samples = self._convert_samples_from_float32(self._samples, dtype)
return samples.tobytes()
def gain_db(self, gain):
"""对音频施加分贝增益。
Note that this is an in-place transformation.
:param gain: Gain in decibels to apply to samples.
:type gain: float|1darray
"""
self._samples *= 10.**(gain / 20.)
def change_speed(self, speed_rate):
"""通过线性插值改变音频速度
Note that this is an in-place transformation.
:param speed_rate: Rate of speed change:
speed_rate > 1.0, speed up the audio;
speed_rate = 1.0, unchanged;
speed_rate < 1.0, slow down the audio;
speed_rate <= 0.0, not allowed, raise ValueError.
:type speed_rate: float
:raises ValueError: If speed_rate <= 0.0.
"""
if speed_rate <= 0:
raise ValueError("速度速率应大于零")
old_length = self._samples.shape[0]
new_length = int(old_length / speed_rate)
old_indices = np.arange(old_length)
new_indices = np.linspace(start=0, stop=old_length, num=new_length)
self._samples = np.interp(new_indices, old_indices, self._samples)
def normalize(self, target_db=-20, max_gain_db=300.0):
"""将音频归一化,使其具有所需的有效值(以分贝为单位)
Note that this is an in-place transformation.
:param target_db: Target RMS value in decibels. This value should be
less than 0.0 as 0.0 is full-scale audio.
:type target_db: float
:param max_gain_db: Max amount of gain in dB that can be applied for
normalization. This is to prevent nans when
attempting to normalize a signal consisting of
all zeros.
:type max_gain_db: float
:raises ValueError: If the required gain to normalize the segment to
the target_db value exceeds max_gain_db.
"""
gain = target_db - self.rms_db
if gain > max_gain_db:
raise ValueError(
"Unable to normalize segment to %f dB because the required gain exceeds max_gain_db (%f dB)" % (target_db, max_gain_db))
self.gain_db(gain)
def resample(self, target_sample_rate, filter='kaiser_best'):
"""按目标采样率重新采样音频
Note that this is an in-place transformation.
:param target_sample_rate: Target sample rate.
:type target_sample_rate: int
:param filter: The resampling filter to use, one of {'kaiser_best',
'kaiser_fast'}.
:type filter: str
"""
self._samples = resampy.resample(self.samples, self.sample_rate, target_sample_rate, filter=filter)
self._sample_rate = target_sample_rate
def pad_silence(self, duration, sides='both'):
"""在这个音频样本上加一段静音
Note that this is an in-place transformation.
:param duration: Length of silence in seconds to pad.
:type duration: float
:param sides: Position for padding:
'beginning' - adds silence in the beginning;
'end' - adds silence in the end;
'both' - adds silence in both the beginning and the end.
:type sides: str
:raises ValueError: If sides is not supported.
"""
if duration == 0.0:
return self
cls = type(self)
silence = self.make_silence(duration, self._sample_rate)
if sides == "beginning":
padded = cls.concatenate(silence, self)
elif sides == "end":
padded = cls.concatenate(self, silence)
elif sides == "both":
padded = cls.concatenate(silence, self, silence)
else:
raise ValueError("Unknown value for the sides %s" % sides)
self._samples = padded._samples
def shift(self, shift_ms):
"""音频偏移。如果shift_ms为正,则随时间提前移位;如果为负,则随时间延迟移位。填补静音以保持持续时间不变。
Note that this is an in-place transformation.
:param shift_ms: Shift time in millseconds. If positive, shift with
time advance; if negative; shift with time delay.
:type shift_ms: float
:raises ValueError: If shift_ms is longer than audio duration.
"""
if abs(shift_ms) / 1000.0 > self.duration:
raise ValueError("shift_ms的绝对值应该小于音频持续时间")
shift_samples = int(shift_ms * self._sample_rate / 1000)
if shift_samples > 0:
# time advance
self._samples[:-shift_samples] = self._samples[shift_samples:]
self._samples[-shift_samples:] = 0
elif shift_samples < 0:
# time delay
self._samples[-shift_samples:] = self._samples[:shift_samples]
self._samples[:-shift_samples] = 0
def subsegment(self, start_sec=None, end_sec=None):
"""在给定的边界之间切割音频片段
Note that this is an in-place transformation.
:param start_sec: Beginning of subsegment in seconds.
:type start_sec: float
:param end_sec: End of subsegment in seconds.
:type end_sec: float
:raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
of bounds in time.
"""
start_sec = 0.0 if start_sec is None else start_sec
end_sec = self.duration if end_sec is None else end_sec
if start_sec < 0.0:
start_sec = self.duration + start_sec
if end_sec < 0.0:
end_sec = self.duration + end_sec
if start_sec < 0.0:
raise ValueError("切片起始位置(%f s)越界" % start_sec)
if end_sec < 0.0:
raise ValueError("切片结束位置(%f s)越界" % end_sec)
if start_sec > end_sec:
raise ValueError("切片的起始位置(%f s)晚于结束位置(%f s)" % (start_sec, end_sec))
if end_sec > self.duration:
raise ValueError("切片结束位置(%f s)越界(> %f s)" % (end_sec, self.duration))
start_sample = int(round(start_sec * self._sample_rate))
end_sample = int(round(end_sec * self._sample_rate))
self._samples = self._samples[start_sample:end_sample]
def random_subsegment(self, subsegment_length, rng=None):
"""随机剪切指定长度的音频片段
Note that this is an in-place transformation.
:param subsegment_length: Subsegment length in seconds.
:type subsegment_length: float
:param rng: Random number generator state.
:type rng: random.Random
:raises ValueError: If the length of the subsegment is greater than
the original segment.
"""
rng = random.Random() if rng is None else rng
if subsegment_length > self.duration:
raise ValueError("Length of subsegment must not be greater "
"than original segment.")
start_time = rng.uniform(0.0, self.duration - subsegment_length)
self.subsegment(start_time, start_time + subsegment_length)
def convolve(self, impulse_segment, allow_resample=False):
"""将这个音频段与给定的脉冲段进行卷积
Note that this is an in-place transformation.
:param impulse_segment: Impulse response segments.
:type impulse_segment: AudioSegment
:param allow_resample: Indicates whether resampling is allowed when
the impulse_segment has a different sample
rate from this signal.
:type allow_resample: bool
:raises ValueError: If the sample rates of the two audio segments do
not match when resampling is not allowed.
"""
if allow_resample and self.sample_rate != impulse_segment.sample_rate:
impulse_segment.resample(self.sample_rate)
if self.sample_rate != impulse_segment.sample_rate:
raise ValueError("脉冲段采样率(%d Hz)不等于基信号采样率(%d Hz)" %
(impulse_segment.sample_rate, self.sample_rate))
samples = signal.fftconvolve(self.samples, impulse_segment.samples,
"full")
self._samples = samples
def convolve_and_normalize(self, impulse_segment, allow_resample=False):
"""对所产生的音频段进行卷积并归一化,使其具有与输入信号相同的平均功率
Note that this is an in-place transformation.
:param impulse_segment: Impulse response segments.
:type impulse_segment: AudioSegment
:param allow_resample: Indicates whether resampling is allowed when
the impulse_segment has a different sample
rate from this signal.
:type allow_resample: bool
"""
target_db = self.rms_db
self.convolve(impulse_segment, allow_resample=allow_resample)
self.normalize(target_db)
def add_noise(self,
noise,
snr_dB,
allow_downsampling=False,
max_gain_db=300.0,
rng=None):
"""以特定的信噪比添加给定的噪声段。如果噪声段比该噪声段长,则从该噪声段中采样匹配长度的随机子段。
Note that this is an in-place transformation.
:param noise: Noise signal to add.
:type noise: AudioSegment
:param snr_dB: Signal-to-Noise Ratio, in decibels.
:type snr_dB: float
:param allow_downsampling: Whether to allow the noise signal to be
downsampled to match the base signal sample
rate.
:type allow_downsampling: bool
:param max_gain_db: Maximum amount of gain to apply to noise signal
before adding it in. This is to prevent attempting
to apply infinite gain to a zero signal.
:type max_gain_db: float
:param rng: Random number generator state.
:type rng: None|random.Random
:raises ValueError: If the sample rate does not match between the two
audio segments when downsampling is not allowed, or
if the duration of noise segments is shorter than
original audio segments.
"""
rng = random.Random() if rng is None else rng
if allow_downsampling and noise.sample_rate > self.sample_rate:
# resample() is in-place and returns None, so resample a copy instead of reassigning its result
noise = copy.deepcopy(noise)
noise.resample(self.sample_rate)
if noise.sample_rate != self.sample_rate:
raise ValueError("Noise sample rate (%d Hz) is not equal to the base signal sample rate (%d Hz)" % (noise.sample_rate, self.sample_rate))
if noise.duration < self.duration:
raise ValueError("The noise signal (%f s) must be at least as long as the base signal (%f s)" % (noise.duration, self.duration))
noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
noise_new = copy.deepcopy(noise)
noise_new.random_subsegment(self.duration, rng=rng)
noise_new.gain_db(noise_gain_db)
self.superimpose(noise_new)
@property
def samples(self):
"""返回音频样本
:return: Audio samples.
:rtype: ndarray
"""
return self._samples.copy()
@property
def sample_rate(self):
"""返回音频采样率
:return: Audio sample rate.
:rtype: int
"""
return self._sample_rate
@property
def num_samples(self):
"""返回样品数量
:return: Number of samples.
:rtype: int
"""
return self._samples.shape[0]
@property
def duration(self):
"""返回音频持续时间
:return: Audio duration in seconds.
:rtype: float
"""
return self._samples.shape[0] / float(self._sample_rate)
@property
def rms_db(self):
"""返回以分贝为单位的音频均方根能量
:return: Root mean square energy in decibels.
:rtype: float
"""
# square root => multiply by 10 instead of 20 for dBs
mean_square = np.mean(self._samples ** 2)
return 10 * np.log10(mean_square)
def _convert_samples_to_float32(self, samples):
"""Convert sample type to float32.
Audio sample type is usually integer or float-point.
Integers will be scaled to [-1, 1] in float32.
"""
float32_samples = samples.astype('float32')
if samples.dtype in np.sctypes['int']:
bits = np.iinfo(samples.dtype).bits
float32_samples *= (1. / 2 ** (bits - 1))
elif samples.dtype in np.sctypes['float']:
pass
else:
raise TypeError("Unsupported sample type: %s." % samples.dtype)
return float32_samples
def _convert_samples_from_float32(self, samples, dtype):
"""Convert sample type from float32 to dtype.
Audio sample type is usually integer or float-point. For integer
type, float32 will be rescaled from [-1, 1] to the maximum range
supported by the integer type.
This is for writing an audio file.
"""
dtype = np.dtype(dtype)
output_samples = samples.copy()
if dtype in np.sctypes['int']:
bits = np.iinfo(dtype).bits
output_samples *= (2 ** (bits - 1) / 1.)
min_val = np.iinfo(dtype).min
max_val = np.iinfo(dtype).max
output_samples[output_samples > max_val] = max_val
output_samples[output_samples < min_val] = min_val
elif dtype in np.sctypes['float']:
min_val = np.finfo(dtype).min
max_val = np.finfo(dtype).max
output_samples[output_samples > max_val] = max_val
output_samples[output_samples < min_val] = min_val
else:
raise TypeError("Unsupported sample type: %s." % samples.dtype)
return output_samples.astype(dtype)
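# Usage sketch for AudioSegment (assumption: a mono wav file exists at
# dataset/test.wav). Every transformation below is in-place, as documented above.
if __name__ == '__main__':
    seg = AudioSegment.from_file('dataset/test.wav')
    print(seg)                    # num_samples, sample_rate, duration, rms
    seg.normalize(target_db=-20)  # scale the RMS level to -20 dB
    seg.resample(16000)           # resample to 16 kHz with resampy
    seg.to_wav_file('dataset/test_norm.wav', dtype='int16')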
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.normalizer import FeatureNormalizer
from data_utils.speech import SpeechSegment
class AudioInferProcess(object):
"""
Audio preprocessing helper used by the recognition program.
:param vocab_filepath: Path of the vocabulary file.
:type vocab_filepath: str
:param mean_std_filepath: Path of the file holding the mean and standard deviation.
:type mean_std_filepath: str
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param use_dB_normalization: Whether to normalize the audio to -20 dB before extracting features.
:type use_dB_normalization: bool
"""
def __init__(self,
vocab_filepath,
mean_std_filepath,
stride_ms=10.0,
window_ms=20.0,
use_dB_normalization=True):
self._normalizer = FeatureNormalizer(mean_std_filepath)
self._speech_featurizer = SpeechFeaturizer(vocab_filepath=vocab_filepath,
stride_ms=stride_ms,
window_ms=window_ms,
use_dB_normalization=use_dB_normalization)
def process_utterance(self, audio_file):
"""对语音数据加载、预处理
:param audio_file: 音频文件的文件路径或文件对象
:type audio_file: str | file
:return: 预处理的音频数据
:rtype: 2darray
"""
speech_segment = SpeechSegment.from_file(audio_file, "")
specgram, _ = self._speech_featurizer.featurize(speech_segment, False)
specgram = self._normalizer.apply(specgram)
return specgram
@property
def vocab_size(self):
"""返回词汇表大小
:return: 词汇表大小
:rtype: int
"""
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
"""返回词汇表列表
:return: 词汇表列表
:rtype: list
"""
return self._speech_featurizer.vocab_list
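# Usage sketch (assumed paths: the vocabulary and mean/std files produced by the
# data-preparation script earlier in this commit).
if __name__ == '__main__':
    processor = AudioInferProcess(vocab_filepath='dataset/zh_vocab.txt',
                                  mean_std_filepath='dataset/mean_std.npz')
    feature = processor.process_utterance('dataset/test.wav')
    print(feature.shape, processor.vocab_size)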
"""Contains the data augmentation pipeline."""
import json
import os
import random
import sys
from datetime import datetime
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
from data_utils.augmentor.spec_augment import SpecAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
class AugmentationPipeline(object):
"""Build a pre-processing pipeline with various augmentation models.Such a
data augmentation pipeline is oftern leveraged to augment the training
samples to make the model invariant to certain types of perturbations in the
real world, improving model's generalization ability.
The pipeline is built according the the augmentation configuration in json
string, e.g.
.. code-block::
[
{
"type": "noise",
"params": {
"min_snr_dB": 10,
"max_snr_dB": 50,
"noise_manifest_path": "dataset/manifest.noise"
},
"prob": 0.5
},
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 1.0
},
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "volume",
"params": {
"min_gain_dBFS": -15,
"max_gain_dBFS": 15
},
"prob": 1.0
},
{
"type": "specaug",
"params": {
"W": 0,
"warp_mode": "PIL",
"F": 10,
"n_freq_masks": 2,
"T": 50,
"n_time_masks": 2,
"p": 1.0,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20,
"replace_with_zero": true
},
"prob": 1.0
}
]
This example configuration inserts five augmentation models into the
pipeline, including the NoisePerturbAugmentor and the SpeedPerturbAugmentor.
"prob" indicates the probability of the current augmentor taking effect.
If "prob" is zero, the augmentor does not take effect.
:param augmentation_config: Augmentation configuration in json string.
:type augmentation_config: str
:param random_seed: Random seed.
:type random_seed: int
:raises ValueError: If the augmentation JSON config is in an incorrect format.
"""
def __init__(self, augmentation_config, random_seed=0):
self._rng = random.Random(random_seed)
self._augmentors, self._rates = self._parse_pipeline_from(augmentation_config, aug_type='audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(augmentation_config, aug_type='feature')
def transform_audio(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegment|SpeechSegment
"""
for augmentor, rate in zip(self._augmentors, self._rates):
if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment)
def transform_feature(self, spec_segment):
"""spectrogram augmentation.
Args:
spec_segment (np.ndarray): audio feature, (D, T).
"""
for augmentor, rate in zip(self._spec_augmentors, self._spec_rates):
if self._rng.uniform(0., 1.) < rate:
spec_segment = augmentor.transform_feature(spec_segment)
return spec_segment
def _parse_pipeline_from(self, config_json, aug_type):
"""Parse the config json to build a augmentation pipelien."""
try:
configs = []
configs_temp = json.loads(config_json)
for config in configs_temp:
if config['aug_type'] != aug_type: continue
if config['type'] == 'noise' and not os.path.exists(config['params']['noise_manifest_path']):
print('%s does not exist, the noise augmentation has been skipped!' % config['params']['noise_manifest_path'], file=sys.stderr)
continue
print('[%s] Data augmentation config: %s' % (datetime.now(), config))
configs.append(config)
augmentors = [self._get_augmentor(config["type"], config["params"]) for config in configs]
rates = [config["prob"] for config in configs]
except Exception as e:
raise ValueError("Failed to parse the augmentation config json: %s" % str(e))
return augmentors, rates
def _get_augmentor(self, augmentor_type, params):
"""Return an augmentation model by the type name, and pass in params."""
if augmentor_type == "volume":
return VolumePerturbAugmentor(self._rng, **params)
elif augmentor_type == "shift":
return ShiftPerturbAugmentor(self._rng, **params)
elif augmentor_type == "speed":
return SpeedPerturbAugmentor(self._rng, **params)
elif augmentor_type == "resample":
return ResampleAugmentor(self._rng, **params)
elif augmentor_type == "noise":
return NoisePerturbAugmentor(self._rng, **params)
elif augmentor_type == "specaug":
return SpecAugmentor(self._rng, **params)
else:
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
"""Contains the abstract base class for augmentation models."""
from abc import ABCMeta, abstractmethod
class AugmentorBase(object):
"""Abstract base class for augmentation model (augmentor) class.
All augmentor classes should inherit from this class, and implement the
following abstract methods.
"""
__metaclass__ = ABCMeta
@abstractmethod
def __init__(self):
pass
@abstractmethod
def transform_audio(self, audio_segment):
"""Adds various effects to the input audio segment. Such effects
will augment the training data to make the model invariant to certain
types of perturbations in the real world, improving model's
generalization ability.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
pass
"""Contains the noise perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
class NoisePerturbAugmentor(AugmentorBase):
"""用于添加背景噪声的增强模型
:param rng: Random generator object.
:type rng: random.Random
:param min_snr_dB: Minimal signal noise ratio, in decibels.
:type min_snr_dB: float
:param max_snr_dB: Maximal signal noise ratio, in decibels.
:type max_snr_dB: float
:param noise_manifest_path: Manifest path for noise audio data.
:type noise_manifest_path: str
"""
def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path):
self._min_snr_dB = min_snr_dB
self._max_snr_dB = max_snr_dB
self._rng = rng
self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)
def transform_audio(self, audio_segment):
"""Add background noise audio.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
noise_json = self._rng.sample(self._noise_manifest, 1)[0]
if noise_json['duration'] >= audio_segment.duration:
diff_duration = noise_json['duration'] - audio_segment.duration
start = self._rng.uniform(0, diff_duration)
end = start + audio_segment.duration
noise_segment = AudioSegment.slice_from_file(noise_json['audio_filepath'], start=start, end=end)
snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB)
audio_segment.add_noise(noise_segment, snr_dB, allow_downsampling=True, rng=self._rng)
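# Usage sketch (assumption: dataset/manifest.noise exists and lists noise clips at
# least as long as the target audio; shorter clips leave the segment unchanged):
#   import random
#   augmentor = NoisePerturbAugmentor(random.Random(0), min_snr_dB=10, max_snr_dB=50,
#                                     noise_manifest_path='dataset/manifest.noise')
#   augmentor.transform_audio(speech_segment)  # in-place, SNR drawn from [10, 50] dB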
"""Contain the resample augmentation model."""
from data_utils.augmentor.base import AugmentorBase
class ResampleAugmentor(AugmentorBase):
"""重采样的增强模型
See more info here:
https://ccrma.stanford.edu/~jos/resample/index.html
:param rng: Random generator object.
:type rng: random.Random
:param new_sample_rate: New sample rate in Hz.
:type new_sample_rate: int
"""
def __init__(self, rng, new_sample_rate):
self._new_sample_rate = new_sample_rate
self._rng = rng
def transform_audio(self, audio_segment):
"""Resamples the input audio to a target sample rate.
Note that this is an in-place transformation.
:param audio: Audio segment to add effects to.
:type audio: AudioSegment|SpeechSegment
"""
audio_segment.resample(self._new_sample_rate)
"""Contains the volume perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase
class ShiftPerturbAugmentor(AugmentorBase):
"""添加随机位移扰动的增强模型
:param rng: Random generator object.
:type rng: random.Random
:param min_shift_ms: Minimal shift in milliseconds.
:type min_shift_ms: float
:param max_shift_ms: Maximal shift in milliseconds.
:type max_shift_ms: float
"""
def __init__(self, rng, min_shift_ms, max_shift_ms):
self._min_shift_ms = min_shift_ms
self._max_shift_ms = max_shift_ms
self._rng = rng
def transform_audio(self, audio_segment):
"""Shift audio.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
audio_segment.shift(shift_ms)
import random
import numpy as np
from PIL import Image
from PIL.Image import BICUBIC
from data_utils.augmentor.base import AugmentorBase
class SpecAugmentor(AugmentorBase):
"""Augmentation model for Time warping, Frequency masking, Time masking.
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
https://arxiv.org/abs/1904.08779
SpecAugment on Large Scale Datasets
https://arxiv.org/abs/1912.05533
"""
def __init__(self,
rng,
F,
T,
n_freq_masks,
n_time_masks,
p=1.0,
W=40,
adaptive_number_ratio=0,
adaptive_size_ratio=0,
max_n_time_masks=20,
replace_with_zero=True,
warp_mode='PIL'):
"""SpecAugment class.
Args:
rng (random.Random): random generator object.
F (int): parameter for frequency masking
T (int): parameter for time masking
n_freq_masks (int): number of frequency masks
n_time_masks (int): number of time masks
p (float): parameter for upperbound of the time mask
W (int): parameter for time warping
adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
adaptive_size_ratio (float): adaptive size ratio for time masking
max_n_time_masks (int): maximum number of time masking
replace_with_zero (bool): pad zero on mask if true else use mean
warp_mode (str): "PIL" (default, fast, not differentiable)
or "sparse_image_warp" (slow, differentiable)
"""
super().__init__()
self._rng = rng
self.inplace = True
self.replace_with_zero = replace_with_zero
self.mode = warp_mode
self.W = W
self.F = F
self.T = T
self.n_freq_masks = n_freq_masks
self.n_time_masks = n_time_masks
self.p = p
# adaptive SpecAugment
self.adaptive_number_ratio = adaptive_number_ratio
self.adaptive_size_ratio = adaptive_size_ratio
self.max_n_time_masks = max_n_time_masks
if adaptive_number_ratio > 0:
self.n_time_masks = 0
if adaptive_size_ratio > 0:
self.T = 0
self._freq_mask = None
self._time_mask = None
@property
def freq_mask(self):
return self._freq_mask
@property
def time_mask(self):
return self._time_mask
def __repr__(self):
return f"specaug: F-{self.F}, T-{self.T}, F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}"
def time_warp(self, x, mode='PIL'):
"""time warp for spec augment
move random center frame by the random width ~ uniform(-window, window)
Args:
x (np.ndarray): spectrogram (time, freq)
mode (str): PIL or sparse_image_warp
Raises:
NotImplementedError: If mode is "sparse_image_warp".
NotImplementedError: If mode is not one of "PIL" or "sparse_image_warp".
Returns:
np.ndarray: time warped spectrogram (time, freq)
"""
window = max_time_warp = self.W
if window == 0:
return x
if mode == "PIL":
t = x.shape[0]
if t - window <= window:
return x
# NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
center = random.randrange(window, t - window)
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
if self.inplace:
x[:warped] = left
x[warped:] = right
return x
return np.concatenate((left, right), 0)
elif mode == "sparse_image_warp":
raise NotImplementedError('sparse_image_warp')
else:
raise NotImplementedError(
"unknown resize mode: " + mode +
", choose one from (PIL, sparse_image_warp).")
def mask_freq(self, x, replace_with_zero=False):
"""freq mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: freq mask spectrogram (time, freq)
"""
n_bins = x.shape[1]
for i in range(0, self.n_freq_masks):
f = int(self._rng.uniform(a=0, b=self.F))
f_0 = int(self._rng.uniform(a=0, b=n_bins - f))
assert f_0 <= f_0 + f
if replace_with_zero:
x[:, f_0:f_0 + f] = 0
else:
x[:, f_0:f_0 + f] = x.mean()
self._freq_mask = (f_0, f_0 + f)
return x
def mask_time(self, x, replace_with_zero=False):
"""time mask
Args:
x (np.ndarray): spectrogram (time, freq)
replace_with_zero (bool, optional): Defaults to False.
Returns:
np.ndarray: time mask spectrogram (time, freq)
"""
n_frames = x.shape[0]
if self.adaptive_number_ratio > 0:
n_masks = int(n_frames * self.adaptive_number_ratio)
n_masks = min(n_masks, self.max_n_time_masks)
else:
n_masks = self.n_time_masks
if self.adaptive_size_ratio > 0:
T = self.adaptive_size_ratio * n_frames
else:
T = self.T
for i in range(n_masks):
t = int(self._rng.uniform(a=0, b=T))
t = min(t, int(n_frames * self.p))
t_0 = int(self._rng.uniform(a=0, b=n_frames - t))
assert t_0 <= t_0 + t
if replace_with_zero:
x[t_0:t_0 + t, :] = 0
else:
x[t_0:t_0 + t, :] = x.mean()
self._time_mask = (t_0, t_0 + t)
return x
def __call__(self, x, train=True):
if not train:
return x
return self.transform_feature(x)
def transform_feature(self, x: np.ndarray):
"""
Args:
x (np.ndarray): `[T, F]`
Returns:
x (np.ndarray): `[T, F]`
"""
assert isinstance(x, np.ndarray)
assert x.ndim == 2
x = self.time_warp(x, self.mode)
x = self.mask_freq(x, self.replace_with_zero)
x = self.mask_time(x, self.replace_with_zero)
return x
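# Usage sketch (assumption: a dummy (time, freq) spectrogram; W=0 disables time
# warping, so the output shape matches the input exactly).
if __name__ == '__main__':
    specaug = SpecAugmentor(random.Random(0), F=10, T=50, n_freq_masks=2,
                            n_time_masks=2, p=1.0, W=0, replace_with_zero=True)
    spec = np.random.randn(200, 80).astype('float32')
    out = specaug(spec, train=True)  # applies frequency and time masking
    print(out.shape)                 # (200, 80)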
"""Contain the speech perturbation augmentation model."""
import numpy as np
from data_utils.augmentor.base import AugmentorBase
class SpeedPerturbAugmentor(AugmentorBase):
"""添加速度扰动的增强模型
See reference paper here:
http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf
:param rng: Random generator object.
:type rng: random.Random
:param min_speed_rate: Lower bound of new speed rate to sample and should
not be smaller than 0.9.
:type min_speed_rate: float
:param max_speed_rate: Upper bound of new speed rate to sample and should
not be larger than 1.1.
:type max_speed_rate: float
"""
def __init__(self, rng, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=3):
if min_speed_rate < 0.9:
raise ValueError("Sampling speed below 0.9 can cause unnatural effects")
if max_speed_rate > 1.1:
raise ValueError("Sampling speed above 1.1 can cause unnatural effects")
self._min_speed_rate = min_speed_rate
self._max_speed_rate = max_speed_rate
self._rng = rng
self._num_rates = num_rates
if num_rates > 0:
self._rates = np.linspace(self._min_speed_rate, self._max_speed_rate, self._num_rates, endpoint=True)
def transform_audio(self, audio_segment):
"""Sample a new speed rate from the given range and
changes the speed of the given audio clip.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
if self._num_rates < 0:
speed_rate = self._rng.uniform(self._min_speed_rate, self._max_speed_rate)
else:
speed_rate = self._rng.choice(self._rates)
if speed_rate == 1.0: return
audio_segment.change_speed(speed_rate)
"""Contains the volume perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase
class VolumePerturbAugmentor(AugmentorBase):
"""添加随机体积扰动的增强模型
This is used for multi-loudness training of PCEN. See
https://arxiv.org/pdf/1607.05666v1.pdf
for more details.
:param rng: Random generator object.
:type rng: random.Random
:param min_gain_dBFS: Minimal gain in dBFS.
:type min_gain_dBFS: float
:param max_gain_dBFS: Maximal gain in dBFS.
:type max_gain_dBFS: float
"""
def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
self._min_gain_dBFS = min_gain_dBFS
self._max_gain_dBFS = max_gain_dBFS
self._rng = rng
def transform_audio(self, audio_segment):
"""Change audio loadness.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
"""
gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
audio_segment.gain_db(gain)
"""Contains data generator for orgnaizing various audio data preprocessing
pipeline and offering data reader interface of PaddlePaddle_DeepSpeech2 requirements.
"""
import random
import numpy as np
import paddle
import paddle.fluid as fluid
from threading import local
from data_utils.utility import read_manifest
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment
from data_utils.normalizer import FeatureNormalizer
class DataGenerator(object):
"""
DataGenerator provides a basic audio data preprocessing pipeline and offers
data reader interfaces that meet the requirements of PaddlePaddle_DeepSpeech2.
:param vocab_filepath: Vocabulary filepath for indexing tokenized
transcripts.
:type vocab_filepath: str
:param mean_std_filepath: File containing the pre-computed mean and stddev.
:type mean_std_filepath: None|str
:param augmentation_config: Augmentation configuration in json string.
For details, see AugmentationPipeline.__doc__.
:type augmentation_config: str
:param max_duration: Audio with duration (in seconds) greater than
this will be discarded.
:type max_duration: float
:param min_duration: Audio with duration (in seconds) smaller than
this will be discarded.
:type min_duration: float
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param use_dB_normalization: Whether to normalize the audio to -20 dB
before extracting the features.
:type use_dB_normalization: bool
:param random_seed: Random seed.
:type random_seed: int
:param keep_transcription_text: If set to True, transcription text will
be passed forward directly without
converting to index sequence.
:type keep_transcription_text: bool
:param place: The place to run the program.
:type place: CPUPlace or CUDAPlace
:param is_training: If set to True, generate text data for training,
otherwise, generate text data for inference.
:type is_training: bool
"""
def __init__(self,
vocab_filepath,
mean_std_filepath,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0,
window_ms=20.0,
use_dB_normalization=True,
random_seed=0,
keep_transcription_text=False,
place=paddle.CPUPlace(),
is_training=True):
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
self._augmentation_pipeline = AugmentationPipeline(augmentation_config=augmentation_config,
random_seed=random_seed)
self._speech_featurizer = SpeechFeaturizer(vocab_filepath=vocab_filepath,
stride_ms=stride_ms,
window_ms=window_ms,
use_dB_normalization=use_dB_normalization)
self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text
self.epoch = 0
self._is_training = is_training
# for caching tar files info
self._local_data = local()
self._local_data.tar2info = {}
self._local_data.tar2object = {}
self._place = place
def process_utterance(self, audio_file, transcript):
"""对语音数据加载、扩充、特征化和归一化
:param audio_file: 音频文件的文件路径或文件对象
:type audio_file: str | file
:param transcript: 音频对应的文本
:type transcript: str
:return: 经过归一化等预处理的音频数据,音频文件对应文本的ID
:rtype: tuple of (2darray, list)
"""
speech_segment = SpeechSegment.from_file(audio_file, transcript)
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, transcript_part = self._speech_featurizer.featurize(speech_segment, self._keep_transcription_text)
specgram = self._normalizer.apply(specgram)
specgram = self._augmentation_pipeline.transform_feature(specgram)
return specgram, transcript_part
def batch_reader_creator(self,
manifest_path,
batch_size,
padding_to=-1,
flatten=False,
shuffle_method="batch_shuffle"):
"""
Batch data reader creator for audio data. Return a callable generator
function to produce batches of data.
Audio features within one batch will be padded with zeros to have the
same shape, or a user-defined shape.
:param manifest_path: Filepath of manifest for audio files.
:type manifest_path: str
:param batch_size: Number of instances in a batch.
:type batch_size: int
:param padding_to: If set to -1, the maximum shape in the batch
will be used as the target shape for padding.
Otherwise, `padding_to` will be the target shape.
:type padding_to: int
:param flatten: If set to True, audio features will be flattened to a 1darray.
:type flatten: bool
:param shuffle_method: Shuffle method. Options:
'' or None: no shuffle.
'instance_shuffle': instance-wise shuffle.
'batch_shuffle': similarly-sized instances are
put into batches, and then
batch-wise shuffle the batches.
For more details, please see
``_batch_shuffle.__doc__``.
'batch_shuffle_clipped': 'batch_shuffle' with
head shift and tail
clipping. For more
details, please see
``_batch_shuffle``.
If sortagrad is True, shuffle is disabled
for the first epoch.
:type shuffle_method: None|str
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""
def batch_reader():
# Read the data manifest
manifest = read_manifest(manifest_path=manifest_path,
max_duration=self._max_duration,
min_duration=self._min_duration)
# For the first epoch, sort the manifest by duration from short to long (SortaGrad)
if self.epoch == 0:
manifest.sort(key=lambda x: x["duration"], reverse=False)
else:
if shuffle_method == "batch_shuffle":
manifest = self._batch_shuffle(manifest, batch_size, clipped=False)
elif shuffle_method == "batch_shuffle_clipped":
manifest = self._batch_shuffle(manifest, batch_size, clipped=True)
elif shuffle_method == "instance_shuffle":
self._rng.shuffle(manifest)
elif shuffle_method is None:
pass
else:
raise ValueError("Unknown shuffle method %s." % shuffle_method)
# Assemble batches
batch = []
instance_reader = self._instance_reader_creator(manifest)
for instance in instance_reader():
batch.append(instance)
if len(batch) == batch_size:
yield self._padding_batch(batch, padding_to, flatten)
batch = []
if len(batch) >= 1:
yield self._padding_batch(batch, padding_to, flatten)
self.epoch += 1
return batch_reader
@property
def feeding(self):
"""返回数据读取器的exe读取字典
:return: 数据读取字典
:rtype: dict
"""
feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
return feeding_dict
@property
def vocab_size(self):
"""返回词汇表大小
:return: 词汇表大小
:rtype: int
"""
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
"""返回词汇表列表
:return: 词汇表列表
:rtype: list
"""
return self._speech_featurizer.vocab_list
def _instance_reader_creator(self, manifest):
"""
Create an instance-level reader.
Instance: each item the reader yields is a tuple of the preprocessed audio data and the token IDs of the corresponding transcript.
"""
def reader():
for instance in manifest:
inst = self.process_utterance(instance["audio_filepath"], instance["text"])
yield inst
return reader
def _padding_batch(self, batch, padding_to=-1, flatten=False):
"""
Pad audio features with zeros so that they share the same shape (or a user-defined shape) within one batch.
If padding_to is -1, the maximum shape in the batch is used as the padding target.
Otherwise, `padding_to` is the target shape (second axis only).
If `flatten` is True, the features are flattened into 1-D data.
"""
# Determine the target shape
max_length = max([audio.shape[1] for audio, text in batch])
if padding_to != -1:
if padding_to < max_length:
raise ValueError("如果padding_to不是-1,它应该大于批处理中任何实例的形状")
max_length = padding_to
# Pad
padded_audios = []
texts, text_lens = [], []
audio_lens = []
masks = []
for audio, text in batch:
padded_audio = np.zeros([audio.shape[0], max_length])
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
padded_audios.append(padded_audio)
if self._is_training:
texts += text
else:
texts.append(text)
text_lens.append(len(text))
audio_lens.append(audio.shape[1])
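# The mask presumably mirrors the model's first convolution output: the frequency
# axis subsampled by 2, the time axis by 3, with 32 output channels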
mask_shape0 = (audio.shape[0] - 1) // 2 + 1
mask_shape1 = (audio.shape[1] - 1) // 3 + 1
mask_max_len = (max_length - 1) // 3 + 1
mask_ones = np.ones((mask_shape0, mask_shape1))
mask_zeros = np.zeros((mask_shape0, mask_max_len - mask_shape1))
mask = np.repeat(
np.reshape(np.concatenate((mask_ones, mask_zeros), axis=1),
(1, mask_shape0, mask_max_len)), 32, axis=0)
masks.append(mask)
padded_audios = np.array(padded_audios).astype('float32')
if self._is_training:
texts = np.expand_dims(np.array(texts).astype('int32'), axis=-1)
texts = fluid.create_lod_tensor(texts, recursive_seq_lens=[text_lens], place=self._place)
audio_lens = np.array(audio_lens).astype('int64').reshape([-1, 1])
masks = np.array(masks).astype('float32')
return padded_audios, texts, audio_lens, masks
def _batch_shuffle(self, manifest, batch_size, clipped=False):
"""将大小相似的实例放入小批量中可以提高效率,并进行批量打乱
1. 按持续时间对音频剪辑进行排序
2. 生成一个随机数k, k的范围[0,batch_size)
3. 随机移动k实例,为不同的epoch训练创建不同的批次
4. 打乱minibatches.
:param manifest: 数据列表
:type manifest: list
:param batch_size: 批量大小。这个大小还用于为批量洗牌生成一个随机数。
:type batch_size: int
:param clipped: 是否剪辑头部(小移位)和尾部(不完整批处理)实例。
:type clipped: bool
:return: Batch shuffled mainifest.
:rtype: list
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self._rng.randint(0, batch_size - 1)
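# zip(*[iter(seq)] * batch_size) groups consecutive items into batch-sized tuples, dropping any incomplete tail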
batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
self._rng.shuffle(batch_manifest)
batch_manifest = [item for batch in batch_manifest for item in batch]
if not clipped:
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
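A minimal sketch of driving this reader end to end, assuming the `./dataset/` files produced by `create_data.py`:
```python
import paddle
from data_utils.data import DataGenerator

generator = DataGenerator(vocab_filepath='./dataset/zh_vocab.txt',
                          mean_std_filepath='./dataset/mean_std.npz',
                          place=paddle.CPUPlace())
batch_reader = generator.batch_reader_creator(manifest_path='./dataset/manifest.train',
                                              batch_size=32,
                                              shuffle_method='batch_shuffle_clipped')
for padded_audios, texts, audio_lens, masks in batch_reader():
    break  # one batch: padded spectrograms, transcript LoD tensor, lengths, conv masks
```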
"""Contains the audio featurizer class."""
import numpy as np
from data_utils.audio import AudioSegment
class AudioFeaturizer(object):
"""音频特征器,用于从AudioSegment或SpeechSegment内容中提取特性。
Currently, it supports feature types of linear spectrogram and mfcc.
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param target_sample_rate: Audio is resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: int
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
stride_ms=10.0,
window_ms=20.0,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._stride_ms = stride_ms
self._window_ms = window_ms
self._target_sample_rate = target_sample_rate
self._use_dB_normalization = use_dB_normalization
self._target_dB = target_dB
def featurize(self, audio_segment, allow_downsampling=True, allow_upsampling=True):
"""从AudioSegment或SpeechSegment中提取音频特征
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:param allow_downsampling: Whether to allow audio downsampling before featurizing.
:type allow_downsampling: bool
:param allow_upsampling: Whether to allow audio upsampling before featurizing.
:type allow_upsampling: bool
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
:raises ValueError: If audio sample rate is not supported.
"""
# upsampling or downsampling
if ((audio_segment.sample_rate > self._target_sample_rate and
allow_downsampling) or
(audio_segment.sample_rate < self._target_sample_rate and
allow_upsampling)):
audio_segment.resample(self._target_sample_rate)
if audio_segment.sample_rate != self._target_sample_rate:
raise ValueError("Audio sample rate is not supported. "
"Turn allow_downsampling or allow up_sampling on.")
# decibel normalization
if self._use_dB_normalization:
audio_segment.normalize(target_db=self._target_dB)
# extract spectrogram
return self._compute_linear_specgram(audio_segment.samples, audio_segment.sample_rate,
stride_ms=self._stride_ms, window_ms=self._window_ms)
# Compute the linear spectrogram with the fast Fourier transform
@staticmethod
def _compute_linear_specgram(samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
eps=1e-14):
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
truncate_size = (len(samples) - window_size) % stride_size
samples = samples[:len(samples) - truncate_size]
nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
nstrides = (samples.strides[0], samples.strides[0] * stride_size)
windows = np.lib.stride_tricks.as_strided(samples, shape=nshape, strides=nstrides)
assert np.all(windows[:, 1] == samples[stride_size:(stride_size + window_size)])
# Fast Fourier transform
weighting = np.hanning(window_size)[:, None]
fft = np.fft.rfft(windows * weighting, n=None, axis=0)
fft = np.absolute(fft)
fft = fft ** 2
scale = np.sum(weighting ** 2) * sample_rate
fft[1:-1, :] *= (2.0 / scale)
fft[(0, -1), :] /= scale
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
ind = np.where(freqs <= (sample_rate / 2))[0][-1] + 1
return np.log(fft[:ind, :] + eps)
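To make the frame and bin bookkeeping above concrete, here is a minimal sketch on synthetic audio; the shape comment assumes the default 10 ms stride, 20 ms window, and a 16000 Hz sample rate:
```python
import numpy as np

samples = np.random.randn(16000).astype('float32')  # one second of synthetic audio
spec = AudioFeaturizer._compute_linear_specgram(samples, sample_rate=16000)
print(spec.shape)  # (161, 99): 161 frequency bins x 99 frames
```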
"""Contains the speech featurizer class."""
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.featurizer.text_featurizer import TextFeaturizer
class SpeechFeaturizer(object):
"""Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment.
Currently, for audio parts, it supports feature types of linear
spectrogram and mfcc; for transcript parts, it only supports char-level
tokenizing and conversion into a list of token indices. Note that the
token indexing order follows the given vocabulary file.
:param vocab_filepath: Filepath to load vocabulary for token indices
conversion.
:type vocab_filepath: str
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param target_sample_rate: Speech is resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: int
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
vocab_filepath,
stride_ms=10.0,
window_ms=20.0,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20):
self._audio_featurizer = AudioFeaturizer(stride_ms=stride_ms,
window_ms=window_ms,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
self._text_featurizer = TextFeaturizer(vocab_filepath)
def featurize(self, speech_segment, keep_transcription_text):
"""提取语音片段的特征
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
:param speech_segment: Speech segment to extract features from.
:type speech_segment: SpeechSegment
:return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of
char-level token indices.
:rtype: tuple
"""
audio_feature = self._audio_featurizer.featurize(speech_segment)
if keep_transcription_text:
return audio_feature, speech_segment.transcript
text_ids = self._text_featurizer.featurize(speech_segment.transcript)
return audio_feature, text_ids
@property
def vocab_size(self):
"""返回词汇表大小
:return: Vocabulary size.
:rtype: int
"""
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
"""返回词汇表的list
:return: Vocabulary in list.
:rtype: list
"""
return self._text_featurizer.vocab_list
class TextFeaturizer(object):
"""文本特征器,用于处理或从文本中提取特征。支持字符级的令牌化和转换为令牌索引列表
:param vocab_filepath: 令牌索引转换词汇表的文件路径
:type vocab_filepath: str
"""
def __init__(self, vocab_filepath):
self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
vocab_filepath)
def featurize(self, text):
"""将文本字符串转换为字符级的令牌索引列表
:param text: 文本
:type text: str
:return:字符级令牌索引列表
:rtype: list
"""
tokens = self._char_tokenize(text)
token_indices = []
for token in tokens:
# Skip characters that are not in the vocabulary
if token not in self._vocab_list: continue
token_indices.append(self._vocab_dict[token])
return token_indices
@property
def vocab_size(self):
"""返回词汇表大小
:return: Vocabulary size.
:rtype: int
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""返回词汇表的列表
:return: Vocabulary in list.
:rtype: list
"""
return self._vocab_list
def _char_tokenize(self, text):
"""Character tokenizer."""
return list(text.strip())
def _load_vocabulary_from_file(self, vocab_filepath):
"""Load vocabulary from file."""
vocab_lines = []
with open(vocab_filepath, 'r', encoding='utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line.split('\t')[0].replace('\n', '') for line in vocab_lines]
vocab_dict = dict(
[(token, id) for (id, token) in enumerate(vocab_list)])
return vocab_dict, vocab_list
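A minimal usage sketch, assuming the `zh_vocab.txt` built by `create_data.py` (one character per line, with `<blank>` first):
```python
featurizer = TextFeaturizer('./dataset/zh_vocab.txt')
ids = featurizer.featurize('出彩中国人')  # characters missing from the vocabulary are skipped
print(featurizer.vocab_size, ids)
```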
"""特征归一化"""
import math
import numpy as np
import random
from tqdm import tqdm
from paddle.io import Dataset, DataLoader
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
class FeatureNormalizer(object):
"""音频特征归一化类
如果mean_std_filepath不是None,则normalizer将直接从文件初始化。否则,使用manifest_path应该给特征mean和stddev计算
:param mean_std_filepath: 均值和标准值的文件路径
:type mean_std_filepath: None|str
:param manifest_path: 用于计算均值和标准值的数据列表,一般是训练的数据列表
:type meanifest_path: None|str
:param featurize_func:函数提取特征。它应该是可调用的``featurize_func(audio_segment)``
:type featurize_func: None|callable
:param num_samples: 用于计算均值和标准值的音频数量
:type num_samples: int
:param random_seed: 随机种子
:type random_seed: int
:raises ValueError: 如果mean_std_filepath和manifest_path(或mean_std_filepath和featurize_func)都为None
"""
def __init__(self,
mean_std_filepath,
manifest_path=None,
num_workers=4,
num_samples=5000,
random_seed=0):
if not mean_std_filepath:
if not manifest_path:
raise ValueError("如果mean_std_filepath是None,那么meanifest_path和featurize_func不应该是None")
self._rng = random.Random(random_seed)
self._compute_mean_std(manifest_path, num_samples, num_workers)
else:
self._read_mean_std_from_file(mean_std_filepath)
def apply(self, features, eps=1e-20):
"""使用均值和标准值计算音频特征的归一化值
:param features: 需要归一化的音频
:type features: ndarray
:param eps: 添加到标准值以提供数值稳定性
:type eps: float
:return: 已经归一化的数据
:rtype: ndarray
"""
return (features - self._mean) / (self._std + eps)
def write_to_file(self, filepath):
"""将计算得到的均值和标准值写入到文件中
:param filepath: 均值和标准值写入的文件路径
:type filepath: str
"""
np.savez(filepath, mean=self._mean, std=self._std)
def _read_mean_std_from_file(self, filepath):
"""从文件中加载均值和标准值"""
npzfile = np.load(filepath)
self._mean = npzfile["mean"]
self._std = npzfile["std"]
def _compute_mean_std(self, manifest_path, num_samples, num_workers):
"""从随机抽样的实例中计算均值和标准值"""
manifest = read_manifest(manifest_path)
if num_samples < 0 or num_samples > len(manifest):
sampled_manifest = manifest
else:
sampled_manifest = self._rng.sample(manifest, num_samples)
dataset = NormalizerDataset(sampled_manifest)
test_loader = DataLoader(dataset=dataset, batch_size=64, collate_fn=collate_fn, num_workers=num_workers)
# Accumulate the sums
std, means = None, None
number = 0
for std1, means1, number1 in tqdm(test_loader()):
number += number1
if means is None:
means = means1
else:
means += means1
if std is None:
std = std1
else:
std += std1
# Turn the accumulated sums into per-bin mean and stddev
for i in range(len(means)):
means[i] /= number
std[i] = std[i] / number - means[i] * means[i]
if std[i] < 1.0e-20:
std[i] = 1.0e-20
std[i] = math.sqrt(std[i])
self._mean = means.reshape([-1, 1])
self._std = std.reshape([-1, 1])
class NormalizerDataset(Dataset):
def __init__(self, sampled_manifest):
super(NormalizerDataset, self).__init__()
self.audio_featurizer = AudioFeaturizer()
self.sampled_manifest = sampled_manifest
def __getitem__(self, idx):
instance = self.sampled_manifest[idx]
# Extract the audio feature
audio = AudioSegment.from_file(instance["audio_filepath"])
feature = self.audio_featurizer.featurize(audio)
return feature, 0
def __len__(self):
return len(self.sampled_manifest)
def collate_fn(features):
std, means = None, None
number = 0
for feature, _ in features:
number += feature.shape[1]
sums = np.sum(feature, axis=1)
if means is None:
means = sums
else:
means += sums
square_sums = np.sum(np.square(feature), axis=1)
if std is None:
std = square_sums
else:
std += square_sums
return std, means, number
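A sketch of the two ways to construct the normalizer: compute the statistics once from the training manifest and save them, then reload them for later runs (paths as used elsewhere in this project):
```python
# Compute the mean and stddev from the training manifest and save them
normalizer = FeatureNormalizer(mean_std_filepath=None,
                               manifest_path='./dataset/manifest.train',
                               num_samples=5000)
normalizer.write_to_file('./dataset/mean_std.npz')
# Later runs load the statistics directly from the file
normalizer = FeatureNormalizer(mean_std_filepath='./dataset/mean_std.npz')
```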
"""Contains the speech segment class."""
import numpy as np
from data_utils.audio import AudioSegment
class SpeechSegment(AudioSegment):
"""语音片段抽象是音频片段的一个子类,附加文字记录。
:param samples: Audio samples [num_samples x num_channels].
:type samples: ndarray.float32
:param sample_rate: 训练数据的采样率
:type sample_rate: int
:param transcript: 音频文件对应的文本
:type transript: str
:raises TypeError: If the sample data type is not float or int.
"""
def __init__(self, samples, sample_rate, transcript):
AudioSegment.__init__(self, samples, sample_rate)
self._transcript = transcript
def __eq__(self, other):
"""Return whether two objects are equal.
"""
if not AudioSegment.__eq__(self, other):
return False
if self._transcript != other._transcript:
return False
return True
def __ne__(self, other):
"""Return whether two objects are unequal."""
return not self.__eq__(other)
@classmethod
def from_file(cls, filepath, transcript):
"""从音频文件和相应的文本创建语音片段
:param filepath: 音频文件路径
:type filepath: str|file
:param transcript: 音频文件对应的文本
:type transript: str
:return: Speech segment instance.
:rtype: SpeechSegment
"""
audio = AudioSegment.from_file(filepath)
return cls(audio.samples, audio.sample_rate, transcript)
@classmethod
def from_bytes(cls, bytes, transcript):
"""从字节串和相应的文本创建语音片段
:param bytes: 包含音频样本的字节字符串
:type bytes: str
:param transcript: 音频文件对应的文本
:type transript: str
:return: Speech segment instance.
:rtype: Speech Segment
"""
audio = AudioSegment.from_bytes(bytes)
return cls(audio.samples, audio.sample_rate, transcript)
@classmethod
def concatenate(cls, *segments):
"""将任意数量的语音片段连接在一起,音频和文本都将被连接
:param *segments: 要连接的输入语音片段
:type *segments: tuple of SpeechSegment
:return: 返回SpeechSegment实例
:rtype: SpeechSegment
:raises ValueError: 不能用不同的抽样率连接片段
:raises TypeError: 只有相同类型SpeechSegment实例的语音片段可以连接
"""
if len(segments) == 0:
raise ValueError("音频片段为空")
sample_rate = segments[0]._sample_rate
transcripts = ""
for seg in segments:
if sample_rate != seg._sample_rate:
raise ValueError("不能用不同的抽样率连接片段")
if type(seg) is not cls:
raise TypeError("只有相同类型SpeechSegment实例的语音片段可以连接")
transcripts += seg._transcript
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate, transcripts)
@classmethod
def slice_from_file(cls, filepath, transcript, start=None, end=None):
"""只加载一小部分SpeechSegment,而不需要将整个文件加载到内存中,这是非常浪费的。
:param filepath:文件路径或文件对象到音频文件
:type filepath: str|file
:param start: 开始时间,单位为秒。如果start是负的,则它从末尾开始计算。如果没有提供,这个函数将从最开始读取。
:type start: float
:param end: 结束时间,单位为秒。如果end是负的,则它从末尾开始计算。如果没有提供,默认的行为是读取到文件的末尾。
:type end: float
:param transcript: 音频文件对应的文本,如果没有提供,默认值是一个空字符串。
:type transript: str
:return: SpeechSegment实例
:rtype: SpeechSegment
"""
audio = AudioSegment.slice_from_file(filepath, start, end)
return cls(audio.samples, audio.sample_rate, transcript)
@classmethod
def make_silence(cls, duration, sample_rate):
"""创建指定安静音频长度和采样率的SpeechSegment实例,音频文件对应的文本将为空字符串。
:param duration: 安静音频的时间,单位秒
:type duration: float
:param sample_rate: 音频采样率
:type sample_rate: float
:return: 安静音频SpeechSegment实例
:rtype: SpeechSegment
"""
audio = AudioSegment.make_silence(duration, sample_rate)
return cls(audio.samples, audio.sample_rate, "")
@property
def transcript(self):
"""返回音频文件对应的文本
:return: 音频文件对应的文本
:rtype: str
"""
return self._transcript
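A short sketch of the class-method constructors above; the sample file and its 16000 Hz rate are assumptions based on this project's defaults:
```python
silence = SpeechSegment.make_silence(duration=1.0, sample_rate=16000)
speech = SpeechSegment.from_file('dataset/test.wav', '出彩中国人')
joined = SpeechSegment.concatenate(silence, speech)
print(joined.transcript)  # '出彩中国人': the silence contributes an empty transcript
```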
"""数据工具函数"""
import json
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
"""解析数据列表
持续时间在[min_duration, max_duration]之外的实例将被过滤。
:param manifest_path: 数据列表的路径
:type manifest_path: str
:param max_duration: 过滤的最长音频长度
:type max_duration: float
:param min_duration: 过滤的最短音频长度
:type min_duration: float
:return: 数据列表,JSON格式
:rtype: list
:raises IOError: If failed to parse the manifest.
"""
manifest = []
for json_line in open(manifest_path, 'r', encoding='utf-8'):
try:
json_data = json.loads(json_line)
except Exception as e:
raise IOError("Error reading manifest: %s" % str(e))
if max_duration >= json_data["duration"] >= min_duration:
manifest.append(json_data)
return manifest
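Judging from how `read_manifest` and `DataGenerator` consume the records, one manifest line looks roughly like this (the field values here are illustrative):
```json
{"audio_filepath": "dataset/audio/wav/0175/H0175A0171.wav", "duration": 3.52, "text": "我需要把空调温度调到二十度"}
```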
<blank>
广
使
西
便
绿
穿
线
怀
访
亿
寿
耀
鸿
湿
齿
退
贿
宿
沿
饿
仿
鹿
稿
姿
屿
轿
from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_beam_search_decoder_batch, ctc_beam_search_decoder
class BeamSearchDecoder:
def __init__(self, beam_alpha, beam_beta, language_model_path, vocab_list):
if language_model_path != 'None' and language_model_path != '' and language_model_path is not None:
print("初始化解码器...")
self._ext_scorer = Scorer(beam_alpha, beam_beta, language_model_path, vocab_list)
lm_char_based = self._ext_scorer.is_character_based()
lm_max_order = self._ext_scorer.get_max_order()
lm_dict_size = self._ext_scorer.get_dict_size()
print('='*70)
print("language model: "
"is_character_based = %d," % lm_char_based +
" max_order = %d," % lm_max_order +
" dict_size = %d" % lm_dict_size)
print('='*70)
print("初始化解码器完成!")
else:
self._ext_scorer = None
print("没有语言模型,解码由纯集束搜索,解码速度慢!")
# Decode a single utterance
def decode_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
vocab_list, blank_id=0):
if self._ext_scorer is not None:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
# beam search decode
beam_search_result = ctc_beam_search_decoder(probs_seq=probs_split,
vocabulary=vocab_list,
beam_size=beam_size,
ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n,
blank_id=blank_id)
return beam_search_result[0]
# Decode a batch of utterances
def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
vocab_list, num_processes, blank_id=0):
if self._ext_scorer is not None:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
# beam search decode
num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch(probs_split=probs_split,
vocabulary=vocab_list,
beam_size=beam_size,
num_processes=num_processes,
ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n,
blank_id=blank_id)
results = [result[0][1] for result in beam_search_results]
return results
from itertools import groupby
import numpy as np
def greedy_decoder(probs_seq, vocabulary, blank_index=0):
"""CTC贪婪(最佳路径)解码器
由最可能的令牌组成的路径将被进一步后处理到去掉连续重复和所有空白
:param probs_seq: 每一条都是2D的概率表。每个元素都是浮点数概率的列表一个字符
:type probs_seq: numpy.ndarray
:param vocabulary: 词汇列表
:type vocabulary: list
:param blank_index 需要移除的空白索引
:type blank_index int
:return: 解码后得到的字符串
:rtype: baseline
"""
# Best index at each time step
max_index_list = list(np.array(probs_seq).argmax(axis=1))
max_prob_list = [probs_seq[i][max_index_list[i]] for i in range(len(max_index_list)) if max_index_list[i] != blank_index]
# Remove consecutive repeats and blank indices
index_list = [index_group[0] for index_group in groupby(max_index_list)]
index_list = [index for index in index_list if index != blank_index]
# Convert the index list into a string
text = ''.join([vocabulary[index] for index in index_list])
score = 0
if len(max_prob_list) > 0:
score = float(sum(max_prob_list) / len(max_prob_list)) * 100.0
return score, text
def greedy_decoder_batch(probs_split, vocabulary, blank_index=0):
"""CTC贪婪(最佳路径)解码器
:param probs_split: 一批包含2D的概率表
:type probs_split: list
:param vocabulary: 词汇列表
:type vocabulary: list
:param blank_index 需要移除的空白索引
:type blank_index int
:return: 字符串列表
:rtype: list
"""
results = []
for probs in probs_split:
output_transcription = greedy_decoder(probs, vocabulary, blank_index=blank_index)
results.append(output_transcription[1])
return results
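A tiny worked example of the collapse rule, using a toy three-token vocabulary whose index 0 is the blank (matching `zh_vocab.txt`, which lists `<blank>` first):
```python
import numpy as np

vocabulary = ['<blank>', '你', '好']
probs_seq = np.array([[0.6, 0.3, 0.1],   # blank
                      [0.2, 0.7, 0.1],   # 你
                      [0.1, 0.8, 0.1],   # 你 (repeat, collapsed)
                      [0.1, 0.2, 0.7]])  # 好
score, text = greedy_decoder(probs_seq, vocabulary, blank_index=0)
print(text, score)  # 你好, plus the mean probability of the non-blank picks as a percentage
```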
"""Wrapper for various CTC decoders in SWIG."""
import swig_decoders
class Scorer(swig_decoders.Scorer):
"""Wrapper for Scorer.
:param alpha: Language model weight; set alpha = 0 to disable the language model.
:type alpha: float
:param beta: Word count weight; set beta = 0 to disable word counting.
:type beta: float
:param model_path: Filepath of the language model.
:type model_path: str
"""
def __init__(self, alpha, beta, model_path, vocabulary):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""CTC贪婪(最佳路径)解码器
由最可能的令牌组成的路径将被进一步后处理到去掉连续重复和所有空白
:param probs_seq: 每一条都是2D的概率表。每个元素都是浮点数概率的列表一个字符
:type probs_seq: numpy.ndarray
:param vocabulary: 词汇列表
:type vocabulary: list
:param blank_index 需要移除的空白索引
:type blank_index int
:return: 解码后得到的字符串
:rtype: baseline
"""
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary, blank_id)
return result
def ctc_beam_search_decoder(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
blank_id=0,
ext_scoring_func=None):
"""集束搜索解码器
:param probs_seq: 单个2-D概率分布列表,每个元素是词汇表和空白上的标准化概率列表
:type probs_seq: 2-D list
:param vocabulary: 词汇列表
:type vocabulary: list
:param beam_size: 集束搜索宽度
:type beam_size: int
:param cutoff_prob: 剪枝中的截断概率,默认1.0,没有剪枝
:type cutoff_prob: float
:param cutoff_top_n: 剪枝时的截断数,仅在词汇表中具有最大probs的cutoff_top_n字符用于光束搜索,默认为40
:type cutoff_top_n: int
:param blank_id 空白索引
:type blank_id int
:param ext_scoring_func: 外部评分功能部分解码句子,如字计数或语言模型
:type external_scoring_func: callable
:return: 解码结果为log概率和句子的元组列表,按概率降序排列
:rtype: list
"""
beam_results = swig_decoders.ctc_beam_search_decoder(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, ext_scoring_func, blank_id)
beam_results = [(res[0], res[1]) for res in beam_results]
return beam_results
def ctc_beam_search_decoder_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
blank_id=0,
ext_scoring_func=None):
"""Wrapper for the batched CTC beam search decoder.
:param probs_split: 3-D list; each element is a 2-D probability list as consumed by ctc_beam_search_decoder().
:type probs_split: 3-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width of the beam search.
:type beam_size: int
:param num_processes: Number of parallel decoding processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability for pruning; the default 1.0 means no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff count for pruning; only the cutoff_top_n characters with the highest probabilities in the vocabulary are used in beam search. Default 40.
:type cutoff_top_n: int
:param blank_id: Index of the blank token.
:type blank_id: int
:param ext_scoring_func: External scoring function for partially decoded sentences, such as a word counter or language model.
:type ext_scoring_func: callable
:return: Decoding results as a list of lists of (log probability, sentence) tuples, each in descending probability order.
:rtype: list
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]
batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results
# Data Augmentation
Data augmentation is a very effective technique for improving deep learning performance. The speech data is augmented by adding small random perturbations (label-invariant transformations) to the original audio to obtain new audio. Developers do not need to synthesize anything themselves: augmentation is embedded in the data generator and happens on the fly, randomly synthesizing audio in every epoch of training.
Five optional augmentation components are currently provided to be selected, configured, and inserted into the processing pipeline:
- Noise perturbation (requires audio files of background noise)
- Speed perturbation
- Shift perturbation
- Volume perturbation
- SpecAugment
For the training module to know which augmentation components are needed and in what order they are applied, a JSON-format *augmentation configuration file* must be prepared in advance. For example:
```json
[
{
"type": "noise",
"params": {
"min_snr_dB": 10,
"max_snr_dB": 50,
"noise_manifest_path": "dataset/manifest.noise"
},
"prob": 0.5
},
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1
},
"prob": 0.5
},
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 0.5
},
{
"type": "volume",
"params": {
"min_gain_dBFS": -15,
"max_gain_dBFS": 15
},
"prob": 0.5
},
{
"type": "specaug",
"params": {
"F": 10,
"T": 50,
"n_freq_masks": 2,
"n_time_masks": 2,
"p": 1.0,
"W": 80,
"adaptive_number_ratio": 0,
"adaptive_size_ratio": 0,
"max_n_time_masks": 20
},
"prob": 1.0
}
]
```
When the `--augment_conf_file` argument of `train.py` is set to the path of the example configuration file above, every audio clip in every epoch is processed as follows. First, with 50% probability, the clip is speed-perturbed by a rate sampled uniformly between 0.9 and 1.1. Then, with 50% probability, the clip is shifted in time by an offset sampled between -5 ms and 5 ms. Finally, the newly synthesized clip is passed to the feature extractor for the rest of training.
Use data augmentation with care: inappropriate augmentation widens the difference between the training and test sets, hurting the trained model and increasing the gap between training and prediction.
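The configuration reaches the pipeline as a plain JSON string; a minimal sketch, reusing the `DataGenerator` from `data_utils/data.py`:
```python
import json
from data_utils.data import DataGenerator

# Speed perturbation only, applied to half of the training samples
augment_conf = json.dumps([{"type": "speed",
                            "params": {"min_speed_rate": 0.9, "max_speed_rate": 1.1},
                            "prob": 0.5}])
generator = DataGenerator(vocab_filepath='./dataset/zh_vocab.txt',
                          mean_std_filepath='./dataset/mean_std.npz',
                          augmentation_config=augment_conf)
```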
# Beam Search Decoding
This project currently supports two decoding methods: beam search (ctc_beam_search) and the greedy strategy (ctc_greedy). The project defaults to greedy decoding everywhere. Beam search decoding is only supported on Linux with Python 3.7. To use the beam search method, first install the `ctc_decoders` library by running the following command.
```shell
pip3 install paddlespeech-ctcdecoders==0.0.2a0
```
# Language Model
Beam search decoding requires a language model; download one and place it in the lm directory. The commands below download a small language model. If you have a machine with enough capacity, you can download the 70 GB extra-large language model instead, [Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm), which performs considerably better.
```shell script
cd PaddlePaddle_DeepSpeech2-DeepSpeech/
mkdir lm
cd lm
wget https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
```
# Finding the Optimal alpha and beta
This step can be skipped, since the default alpha and beta already work well. If you want to squeeze out the best performance, run the command below; it may be slow. When it finishes you will get the best-performing alpha and beta values.
```shell
python tools/tune.py --model_path=./models/param/50.pdparams
```
# Decoding with Beam Search
In any program that uses the decoder, such as evaluation or prediction, just set the `--decoding_method` argument to `ctc_beam_search`, e.g. `python eval.py --decoding_method=ctc_beam_search`. If the alpha and beta values have been tuned, update them accordingly.
# Data Preparation
1. The `download_data` directory contains the scripts for downloading the public datasets and building the training manifests and vocabulary. This project provides downloads for three public Mandarin speech datasets, Aishell, Free ST-Chinese-Mandarin-Corpus, and THCHS-30, more than 28 GB in total. Downloading all three only takes running the code below; if you want to train quickly, you can also download just one of them. **Note:** `noise.py` is optional; it is only used for noise augmentation during training, so if you do not want noise augmentation you can skip it.
```shell script
cd download_data/
python aishell.py
python free_st_chinese_mandarin_corpus.py
python thchs_30.py
python noise.py
```
**Note:** The code above only runs on Linux. On Windows, take the `DATA_URL` from the script and download it separately (a download manager such as Thunder is much faster), then replace the `download()` call with the absolute path of the downloaded file, as below. For example, after downloading the archive used by `aishell.py` manually, make the change below and re-run the script; it will unpack the archive and generate the manifest from the transcripts automatically.
```python
# Change this line
filepath = download(url, md5sum, target_dir)
# to
filepath = "D:\\Download\\data_aishell.tgz"
```
2. If you have your own dataset, you can train with it, optionally together with the datasets downloaded above. Custom speech data must follow the format below. For the audio sample rate, this project uses 16000 Hz by default; `create_data.py` can also resample all audio to a uniform 16000 Hz when the `is_change_frame_rate` argument is set to True.
1. Put the speech files under the `PaddlePaddle-DeepSpeech/dataset/audio/` directory. For example, if you have a `wav` folder full of speech files, place that folder in `PaddlePaddle-DeepSpeech/dataset/audio/`.
2. Then store the manifest files in the `PaddlePaddle-DeepSpeech/dataset/annotation/` directory; the program walks through all the manifest files in that directory. For example, the directory might hold a `my_audio.txt` with the format below: each line contains the relative path of a speech file and its Chinese transcript, separated by `\t`. Note that the transcript may contain only pure Chinese characters, with no punctuation, Arabic numerals, or English letters.
```shell script
dataset/audio/wav/0175/H0175A0171.wav 我需要把空调温度调到二十度
dataset/audio/wav/0175/H0175A0377.wav 出彩中国人
dataset/audio/wav/0175/H0175A0470.wav 据克而瑞研究中心监测
dataset/audio/wav/0175/H0175A0180.wav 把温度加大到十八
```
3. Finally, run the data processing script below. It generates three JSON-format manifests, `manifest.test`, `manifest.train`, and `manifest.noise`, then builds the vocabulary, storing every character that occurs in the `zh_vocab.txt` file, one character per line. Lastly it computes the mean and stddev for normalization, by default over all of the speech data, and saves the result in `mean_std.npz`. All the generated files are placed in the `PaddlePaddle-DeepSpeech/dataset/` directory.
```shell script
# Generate the manifests
python create_data.py
```
# Evaluation
Run the script below to evaluate the model, measuring its performance by character error rate.
```shell
python eval.py --resume_model=./models/param/50.pdparams
```
Output:
```
----------- Configuration Arguments -----------
alpha: 1.2
batch_size: 64
beam_size: 10
beta: 0.35
cutoff_prob: 1.0
cutoff_top_n: 40
decoding_method: ctc_greedy
error_rate_type: cer
lang_model_path: ./lm/zh_giga.no_cna_cmn.prune01244.klm
mean_std_path: ./dataset/mean_std.npz
resume_model: ./models/param/50.pdparams
num_conv_layers: 2
num_proc_bsearch: 8
num_rnn_layers: 3
rnn_layer_size: 1024
test_manifest: ./dataset/manifest.test
use_gpu: True
vocab_path: ./dataset/zh_vocab.txt
------------------------------------------------
W0318 16:38:49.200599 19032 device_context.cc:252] Please NOTE: device: 0, CUDA Capability: 75, Driver API Version: 11.0, Runtime API Version: 10.0
W0318 16:38:49.242089 19032 device_context.cc:260] device: 0, cuDNN Version: 7.6.
[INFO 2021-03-18 16:38:53,689 eval.py:83] 开始评估 ...
错误率:[cer] (64/284) = 0.077040
错误率:[cer] (128/284) = 0.062989
错误率:[cer] (192/284) = 0.055674
错误率:[cer] (256/284) = 0.054918
错误率:[cer] (284/284) = 0.055882
消耗时间:44526ms, 总错误率:[cer] (284/284) = 0.055882
[INFO 2021-03-18 16:39:38,215 eval.py:117] 完成评估!
```
# Exporting the Model
The checkpoints saved during training, and those provided by the author, contain only model parameters. We export them as an inference model so the model can be used directly without the model-definition code, and so the Inference API can be used to speed up prediction.
```shell
python export_model.py --resume_model=./models/param/50.pdparams
```
Output:
```
成功加载了预训练模型:./models/param/50.pdparams
----------- Configuration Arguments -----------
mean_std_path: ./dataset/mean_std.npz
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
resume_model: ./models/param/50.pdparams
save_model_path: ./models/infer/
use_gpu: True
vocab_path: ./dataset/zh_vocab.txt
------------------------------------------------
成功导出模型,模型保存在:./models/infer/
```
# LLVM Version Error
**If an LLVM version error occurs**, run the commands below and then re-run the installation command above; otherwise this step is not needed.
```shell
cd ~
wget https://releases.llvm.org/9.0.0/llvm-9.0.0.src.tar.xz
wget http://releases.llvm.org/9.0.0/cfe-9.0.0.src.tar.xz
wget http://releases.llvm.org/9.0.0/clang-tools-extra-9.0.0.src.tar.xz
tar xvf llvm-9.0.0.src.tar.xz
tar xvf cfe-9.0.0.src.tar.xz
tar xvf clang-tools-extra-9.0.0.src.tar.xz
mv llvm-9.0.0.src llvm-src
mv cfe-9.0.0.src llvm-src/tools/clang
mv clang-tools-extra-9.0.0.src llvm-src/tools/clang/tools/extra
sudo mkdir -p /usr/local/llvm
sudo mkdir -p llvm-src/build
cd llvm-src/build
sudo cmake -G "Unix Makefiles" -DLLVM_TARGETS_TO_BUILD=X86 -DCMAKE_BUILD_TYPE="Release" -DCMAKE_INSTALL_PREFIX="/usr/local/llvm" ..
sudo make -j8
sudo make install
export LLVM_CONFIG=/usr/local/llvm/bin/llvm-config
```
- Clone this project's source code with git:
```shell script
git clone https://github.com/yeyupiaoling/DeepSpeech.git
```
# Synthesizing Speech Data
1. To make up for the shortage of data, we synthesize a batch of speech for training, using PaddlePaddle's official Parakeet to synthesize Mandarin speech. First install Parakeet by running the commands below.
```shell
git clone https://github.com/PaddlePaddle/Parakeet
cd Parakeet
python setup.py install
```
2. Then download the model archives below and unpack them into the `tools/generate_audio/` directory.
```shell
https://download.csdn.net/download/qq_33200967/33826147
```
3. Put the target speaker's recordings in the `tools/generate_audio/speaker_audio` directory; the `dataset/test.wav` file can be used, and recordings of several speakers can be collected there as well. Developers can also put their own recordings into that directory, so the trained model recognizes their voice better; a 16000 Hz sample rate works best.
4. Then download a text corpus; if you have a better corpus you can substitute it. Unpack the archives in the `dgk_lost_conv/results` directory; Windows users can unpack them manually.
```shell
cd tools/generate_audio
git clone https://github.com/aceimnorstuvwxz/dgk_lost_conv.git
cd dgk_lost_conv/results
unzip dgk_shooter_z.conv.zip
unzip xiaohuangji50w_fenciA.conv.zip
unzip xiaohuangji50w_nofenci.conv.zip
```
5. Next, run the commands below to generate the Mandarin corpus; the generated corpus is stored in `tools/generate_audio/corpus.txt`.
```shell
cd tools/generate_audio/
python generate_corpus.py
```
6. Finally, run the commands below to synthesize the speech automatically. The synthesized audio is placed in `dataset/audio/generate` and the annotation file in `dataset/annotation/generate.txt`.
```shell
cd tools/generate_audio/
python generate_audio.py
```
# Local Prediction
We can use this script to run prediction with the model. If the model has not been exported yet, perform the [model export](export_model.md) step first to export the parameters as an inference model. Recognition runs on an audio file whose path is passed via the `--wav_path` argument. Converting Chinese numerals to Arabic numerals is supported by setting the `--to_an` argument to True, which is the default.
```shell script
python infer_path.py --wav_path=./dataset/test.wav
```
Output:
```
----------- Configuration Arguments -----------
alpha: 1.2
beam_size: 10
beta: 0.35
cutoff_prob: 1.0
cutoff_top_n: 40
decoding_method: ctc_greedy
enable_mkldnn: False
is_long_audio: False
lang_model_path: ./lm/zh_giga.no_cna_cmn.prune01244.klm
mean_std_path: ./dataset/mean_std.npz
model_dir: ./models/infer/
to_an: True
use_gpu: True
vocab_path: ./dataset/zh_vocab.txt
wav_path: ./dataset/test.wav
------------------------------------------------
消耗时间:132, 识别结果: 近几年不但我用书给女儿儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书, 得分: 94
```
## Long Audio Prediction
The `--is_long_audio` argument selects the long-audio recognition mode: the audio is split with VAD, the short clips are recognized, and the results are stitched together into the final long-audio transcription.
```shell script
python infer_path.py --wav_path=./dataset/test_vad.wav --is_long_audio=True
```
Output:
```
----------- Configuration Arguments -----------
alpha: 1.2
beam_size: 10
beta: 0.35
cutoff_prob: 1.0
cutoff_top_n: 40
decoding_method: ctc_greedy
enable_mkldnn: False
is_long_audio: 1
lang_model_path: ./lm/zh_giga.no_cna_cmn.prune01244.klm
mean_std_path: ./dataset/mean_std.npz
model_dir: ./models/infer/
to_an: True
use_gpu: True
vocab_path: ./dataset/zh_vocab.txt
wav_path: dataset/test_vad.wav
------------------------------------------------
第0个分割音频, 得分: 70, 识别结果: 记的12铺地补买上过了矛乱钻吃出满你都着们现上就只有1良解太穷了了臭力量紧不着还绑在大理达高的铁股上
第1个分割音频, 得分: 86, 识别结果: 我们都是骑自行说
第2个分割音频, 得分: 91, 识别结果: 他李达康知不知道党的组织原则
第3个分割音频, 得分: 71, 识别结果: 没是把就都路通着奖了李达方就是请他作现长件2着1把爽他作收记书就是发爽
第4个分割音频, 得分: 76, 识别结果: 那的当了熊掌我还得听她了哈哈他这太快还里生长还那得聊嘛安不乖怎么说
第5个分割音频, 得分: 97, 识别结果: 他老婆总是出事了嘛
第6个分割音频, 得分: 63, 识别结果: 就是前急次
第7个分割音频, 得分: 87, 识别结果: 欧阳箐是他前妻
第8个分割音频, 得分: 0, 识别结果:
第9个分割音频, 得分: 97, 识别结果: 我最后再说1句啊
第10个分割音频, 得分: 84, 识别结果: 能不能帮我个的小忙
第11个分割音频, 得分: 86, 识别结果: 说
第12个分割音频, 得分: 85, 识别结果: 她那陈清泉放了别再追究的
第13个分割音频, 得分: 93, 识别结果: 这陈清泉
第14个分割音频, 得分: 79, 识别结果: 跟你有生我来啊
第15个分割音频, 得分: 87, 识别结果: 我不认识个人
第16个分割音频, 得分: 81, 识别结果: 就是高小琴的人那你管这么宽干嘛啊
第17个分割音频, 得分: 94, 识别结果: 真以天下为己任了
第18个分割音频, 得分: 76, 识别结果: 你天下为竟人那是哪那耍我就是上在上晚上你们再山水张院的人让我照片和宁练个在我整么那不那板法
第19个分割音频, 得分: 67, 识别结果: 你就生涯真说晚啦是长微台过会来决定了
最终结果,消耗时间:1587, 得分: 79, 识别结果: ,记的12铺地补买上过了矛乱钻吃出满你都着们现上就只有1良解太穷了了臭力量紧不着还绑在大理达高的铁股上,我们都是骑自行说,他李达康知不知道党的组织原则,没是把就都路通着奖了李达方就是请他作现长件2着1把爽他作收记书就是发爽,那的当了熊掌我还得听她了哈哈他这太快还里生长还那得聊嘛安不乖怎么说,他老婆总是出事了嘛,就是前急次,欧阳箐是他前妻,,我最后再说1句啊,能不能帮我个的小忙,说,她那陈清泉放了别再追究的,这陈清泉,跟你有生我来啊,我不认识个人,就是高小琴的人那你管这么宽干嘛啊,真以天下为己任了,你天下为竟人那是哪那耍我就是上在上晚上你们再山水张院的人让我照片和宁练个在我整么那不那板法,你就生涯真说晚啦是长微台过会来决定了
```
## Web Deployment
Run the command below on the server to start a web service that exposes speech recognition over an HTTP interface. After the service starts, visit `http://localhost:5000` in a browser if it runs locally, or substitute the corresponding IP address otherwise. On the page you can upload long or short audio files, or record directly in the browser and upload the recording when done; playback is only supported for recorded audio. Converting Chinese numerals to Arabic numerals is supported by setting the `--to_an` argument to True, which is the default.
```shell script
python infer_server.py
```
The page looks like this:
![Recording test page](images/infer_server.jpg)
## GUI Deployment
In the window, choose a long or short audio file to recognize; recording-based recognition is also supported, and the recognized audio can be played back. Greedy decoding is used by default; to use the beam search method, specify it in the launch arguments.
```shell script
python infer_gui.py
```
The interface looks like this:
![GUI interface](images/infer_gui.jpg)
# Setting Up a Local Environment
I use a local environment with Anaconda and a Python 3.7 virtual environment, and recommend that readers do the same to make troubleshooting easier; if you hit installation problems, open an [issue](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech/issues) at any time. If you want to use Docker, see **Setting Up a Docker Environment**.
- First install the GPU version of PaddlePaddle 2.1.3; skip this if it is already installed.
```shell
conda install paddlepaddle-gpu==2.1.3 cudatoolkit=10.2 --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/
```
- Install the other dependencies.
```shell
python -m pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
```
**Note:** If an LLVM version error occurs, see [LLVM Version Error](faq.md) for the fix.
# Setting Up a Docker Environment
- Install the graphics driver in advance, then run the commands below.
```shell script
# Remove any docker shipped with the system
sudo apt-get remove docker docker-engine docker.io containerd runc
# Update the apt-get sources
sudo apt-get update
# Install docker's dependencies
sudo apt-get install \
apt-transport-https \
ca-certificates \
curl \
gnupg-agent \
software-properties-common
# Add Docker's official GPG key:
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
# Verify the key fingerprint
sudo apt-key fingerprint 0EBFCD88
# Set up the stable repository
sudo add-apt-repository \
"deb [arch=amd64] https://download.docker.com/linux/ubuntu \
$(lsb_release -cs) \
stable"
```
- Install Docker
```shell script
# Update the apt-get sources again
sudo apt-get update
# Install docker
sudo apt-get install docker-ce
# List the available docker-ce versions
sudo apt-cache madison docker-ce
# Verify that docker installed successfully
sudo docker run hello-world
```
- Install nvidia-docker
```shell script
# Set up the stable repository and GPG key
distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \
&& curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - \
&& curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
# Update the package list
sudo apt-get update
# Install the package
sudo apt-get install -y nvidia-docker2
# After setting the default runtime, restart the Docker daemon to complete the installation:
sudo systemctl restart docker
# Test
sudo docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
```
- Pull the PaddlePaddle 2.1.2 image.
```shell script
sudo nvidia-docker pull registry.baidubce.com/paddlepaddle/paddle:2.1.2-gpu-cuda10.2-cudnn7
```
- Clone this project's source code with git:
```shell script
git clone https://github.com/yeyupiaoling/DeepSpeech.git
```
- Run the PaddlePaddle speech recognition image, sharing the host's network (IP and ports):
```shell script
sudo nvidia-docker run -it --net=host -v $(pwd)/DeepSpeech:/DeepSpeech registry.baidubce.com/paddlepaddle/paddle:2.1.2-gpu-cuda10.2-cudnn7 /bin/bash
```
- Install the other dependencies.
```shell
python -m pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
```
# Nvidia Jetson Deployment
1. For Nvidia Jetson devices such as the Nano, NX, and AGX, the PaddlePaddle Inference library can be installed with the commands below.
```shell
wget https://paddle-inference-lib.bj.bcebos.com/2.1.1-nv-jetson-jetpack4.4-all/paddlepaddle_gpu-2.1.1-cp36-cp36m-linux_aarch64.whl
pip3 install paddlepaddle_gpu-2.1.1-cp36-cp36m-linux_aarch64.whl
```
2. Install the scikit-learn dependency.
```shell
git clone git://github.com/scikit-learn/scikit-learn.git
cd scikit-learn
pip3 install cython
git checkout 0.24.2
pip3 install --verbose --no-build-isolation --editable .
```
3. Install the other dependencies.
```shell
pip3 install -r requirements.txt
```
4. Run prediction using the prediction script in the project root.
```shell
python infer_path.py --wav_path=./dataset/test.wav
```
Taking the Nvidia AGX as an example, the output is as follows:
```
WARNING: AVX is not support on your machine. Hence, no_avx core will be imported, It has much worse preformance than avx core.
----------- Configuration Arguments -----------
alpha: 1.2
beam_size: 10
beta: 0.35
cutoff_prob: 1.0
cutoff_top_n: 40
decoding_method: ctc_greedy
enable_mkldnn: False
is_long_audio: False
lang_model_path: ./lm/zh_giga.no_cna_cmn.prune01244.klm
mean_std_path: ./dataset/mean_std.npz
model_dir: ./models/infer/
to_an: True
use_gpu: True
use_tensorrt: False
vocab_path: ./dataset/zh_vocab.txt
wav_path: ./dataset/test.wav
------------------------------------------------
消耗时间:416ms, 识别结果: 近几年不但我用书给女儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书, 得分: 97
```
# Training the Model
- Run the training script to start training the speech recognition model. A checkpoint is saved after every epoch and every 2000 batches, in the `PaddlePaddle-DeepSpeech/models/param/` directory. Data augmentation is used by default; if you do not want it, just set the `augment_conf_path` argument to `None`. For data augmentation, see the [data augmentation](faq.md) section. Unless testing is turned off, a test pass is run after every training epoch to compute the model's accuracy on the test set. On Linux, `CUDA_VISIBLE_DEVICES` can specify multiple GPUs for training.
```shell script
CUDA_VISIBLE_DEVICES=0,1 python train.py
```
The training output looks like this:
```
----------- Configuration Arguments -----------
augment_conf_path: ./conf/augmentation.json
batch_size: 32
learning_rate: 0.0001
max_duration: 20.0
mean_std_path: ./dataset/mean_std.npz
min_duration: 0.0
num_conv_layers: 2
num_epoch: 50
num_rnn_layers: 3
output_model_dir: ./models/param
pretrained_model: None
resume_model: None
rnn_layer_size: 1024
shuffle_method: batch_shuffle_clipped
test_manifest: ./dataset/manifest.test
test_off: False
train_manifest: ./dataset/manifest.train
use_gpu: True
vocab_path: ./dataset/zh_vocab.txt
------------------------------------------------
dataset/manifest.noise不存在,已经忽略噪声增强操作!
[2021-08-31 22:40:36.473431] 训练数据数量:102394
W0831 22:40:36.624647 4879 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.0, Runtime API Version: 10.2
W0831 22:40:36.626874 4879 device_context.cc:422] device: 0, cuDNN Version: 7.6.
W0831 22:40:37.996898 4879 parallel_executor.cc:601] Cannot enable P2P access from 0 to 1
W0831 22:40:37.996917 4879 parallel_executor.cc:601] Cannot enable P2P access from 1 to 0
W0831 22:40:39.725975 4879 fuse_all_reduce_op_pass.cc:76] Find all_reduce operators: 44. To make the speed faster, some all_reduce ops are fused during training, after fusion, the number of all_reduce ops is 23.
Train [2021-08-31 22:40:41.633553] epoch: [1/50], batch: [0/1599], learning rate: 0.00020000, train loss: 2053.378662, eta: 3 days, 8:51:04
Train [2021-08-31 22:41:43.713666] epoch: [1/50], batch: [100/1599], learning rate: 0.00020000, train loss: 86.532333, eta: 12:59:38
Train [2021-08-31 22:42:40.098206] epoch: [1/50], batch: [200/1599], learning rate: 0.00020000, train loss: 87.101303, eta: 12:07:12
Train [2021-08-31 22:43:33.444587] epoch: [1/50], batch: [300/1599], learning rate: 0.00020000, train loss: 84.560562, eta: 11:42:34
Train [2021-08-31 22:44:24.759048] epoch: [1/50], batch: [400/1599], learning rate: 0.00020000, train loss: 81.681633, eta: 11:18:40
Train [2021-08-31 22:45:14.196539] epoch: [1/50], batch: [500/1599], learning rate: 0.00020000, train loss: 72.275848, eta: 10:27:17
Train [2021-08-31 22:46:02.194968] epoch: [1/50], batch: [600/1599], learning rate: 0.00020000, train loss: 76.041451, eta: 9:51:43
```
- During training, the program records the training results with VisualDL; VisualDL can be started with the following command.
```shell
visualdl --logdir=log --host=0.0.0.0
```
- Then visit `http://localhost:8040` in a browser to view the results, shown below.
![Learning rate](https://img-blog.csdnimg.cn/20210318165719805.png)
![Test Cer](https://s3.ax1x.com/2021/03/01/6PJaZV.jpg)
![Train Loss](https://s3.ax1x.com/2021/03/01/6PJNq0.jpg)
# Resuming Training
If training was interrupted, the `resume_model` argument can specify a checkpoint to resume from; the model is loaded when training starts, and training continues from that epoch.
```shell script
CUDA_VISIBLE_DEVICES=0,1 python train.py --resume_model=models/param/23.pdparams
```
# Fine-tuning the Model
If you have already trained or downloaded a model and want to fine-tune it on your own dataset, then besides specifying the checkpoint with the `resume_model` argument you also need to raise the training `num_epoch`: the checkpoint was saved at the maximum `num_epoch`, so without the change training may stop immediately. Setting it to 60, for example, trains the model for another 10 epochs on top of the original. Your dataset needs to be merged with the original data and trained together.
```shell script
CUDA_VISIBLE_DEVICES=0,1 python train.py --resume_model=models/param/50.pdparams --num_epoch=60
```
# The WenetSpeech Dataset
10000+ hours of Mandarin speech data
![WenetSpeech dataset](images/wenetspeech.jpg)
The [WenetSpeech dataset](https://wenet-e2e.github.io/WenetSpeech/) contains 10000+ hours of Mandarin speech, all collected from YouTube and Podcast. Optical character recognition (OCR) and automatic speech recognition (ASR) are used to label the YouTube and Podcast recordings respectively. To improve the quality of the corpus, WenetSpeech applies a novel end-to-end label error detection method to further validate and filter the data.
- All the data is divided into 3 categories, as shown in the table below:
| Category | Hours | Confidence | Usable for |
|:---:|:---:|:---:|:---:|
| Strong label | 10005 | \>=0.95 | Supervised training |
| Weak label | 2478 | [0.6, 0.95] | Semi-supervised or noisy training |
| Unlabeled | 9952 | / | Unsupervised training or pre-training |
| Total | 22435 | / | / |
- The strongly labeled data is divided into 10 groups by domain, speaking style, and scenario, as shown below:
| Domain | Youtube (hours) | Podcast (hours) | Total (hours) |
|:---:|:---:|:---:|:---:|
| Audiobook | 0 | 250.9 | 250.9 |
| Live commentary | 112.6 | 135.7 | 248.3 |
| Documentary | 386.7 | 90.5 | 477.2 |
| Drama | 4338.2 | 0 | 4338.2 |
| Interview | 324.2 | 614 | 938.2 |
| News | 0 | 868 | 868 |
| Reading | 0 | 1110.2 | 1110.2 |
| Discussion | 204 | 90.7 | 294.7 |
| Variety show | 603.3 | 224.5 | 827.8 |
| Other | 144 | 507.5 | 651.5 |
| Total | 6113 | 3892 | 10005 |
- Three subsets, S, M, and L, are provided for building ASR systems at different data scales:
| Training data | Confidence | Hours |
|:---:|:---:|:---:|
| L | [0.95, 1.0] | 10005 |
| M | 1.0 | 1000 |
| S | 1.0 | 100 |
- Evaluation and test data:
| Evaluation set | Hours | Source | Description |
|:---:|:---:|:---:|:---:|
| DEV | 20 | Internet | Designed for speech tools that need cross-validation during training |
| TEST\_NET | 23 | Internet | Competition test set |
| TEST\_MEETING | 15 | Meetings | Far-field, conversational, spontaneous meeting data |
1. This tutorial describes how to train a speech recognition model with this dataset, using only the strongly labeled data, in three main steps. Download and unpack the WenetSpeech dataset: after filling in the form on the [official site](https://wenet-e2e.github.io/WenetSpeech/#download), you will receive an email; run the three commands from the email to download and unpack the dataset. Note that this requires 500 GB of disk space.
2. Then build the dataset. The downloaded raw data is not segmented, so the audio files must be cut and annotated according to the JSON annotation file. Run `create_wenetspeech_data.py` in the `tools` directory to build the dataset; note that this step requires 3 TB of disk space. The `--wenetspeech_json` argument specifies the path of the WenetSpeech annotation file; set it according to where you downloaded the data.
```shell
cd tools/
python create_wenetspeech_data.py --wenetspeech_json=/media/wenetspeech/WenetSpeech.json
```
3. Finally, create the training data. As in normal usage, running `create_data.py` in the project root generates the manifests, vocabulary, and mean/stddev file needed for training. After this step you can train the model; see [training the model](train.md).
```shell
python create_data.py
```
**Tip:** This dataset is huge and costs a lot of time and resources. Use it according to your own situation, and do not bite off more than you can chew.
import argparse
import os
import functools
from utility import download, unpack
from utility import add_arguments, print_arguments
DATA_URL = 'https://openslr.magicdatatech.com/resources/33/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
parser.add_argument("--target_dir",
default="../dataset/audio/",
type=str,
help="存放音频文件的目录 (默认: %(default)s)")
parser.add_argument("--annotation_text",
default="../dataset/annotation/",
type=str,
help="存放音频标注文件的目录 (默认: %(default)s)")
args = parser.parse_args()
def create_annotation_text(data_dir, annotation_path):
print('Create Aishell annotation text ...')
if not os.path.exists(annotation_path):
os.makedirs(annotation_path)
f_a = open(os.path.join(annotation_path, 'aishell.txt'), 'w', encoding='utf-8')
transcript_path = os.path.join(data_dir, 'transcript', 'aishell_transcript_v0.8.txt')
transcript_dict = {}
for line in open(transcript_path, 'r', encoding='utf-8'):
line = line.strip()
if line == '': continue
audio_id, text = line.split(' ', 1)
# remove space
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
audio_dir = os.path.join(data_dir, 'wav', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
# Skip audio files that have no transcription
if audio_id not in transcript_dict:
continue
text = transcript_dict[audio_id]
f_a.write(audio_path[3:] + '\t' + text + '\n')
f_a.close()
def prepare_dataset(url, md5sum, target_dir, annotation_path):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'wav')
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for ftar in filelist:
unpack(os.path.join(subfolder, ftar), subfolder, True)
os.remove(filepath)
else:
print("Skip downloading and unpacking. Aishell data already exists in %s." % target_dir)
create_annotation_text(data_dir, annotation_path)
def main():
print_arguments(args)
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
annotation_path=args.annotation_text)
if __name__ == '__main__':
main()
import argparse
import functools
import os
from utility import download, unpack
from utility import add_arguments, print_arguments
DATA_URL = 'https://openslr.magicdatatech.com/resources/38/ST-CMDS-20170001_1-OS.tar.gz'
MD5_DATA = 'c28ddfc8e4ebe48949bc79a0c23c5545'
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
parser.add_argument("--target_dir",
default="../dataset/audio/",
type=str,
help="存放音频文件的目录 (默认: %(default)s)")
parser.add_argument("--annotation_text",
default="../dataset/annotation/",
type=str,
help="存放音频标注文件的目录 (默认: %(default)s)")
args = parser.parse_args()
def create_annotation_text(data_dir, annotation_path):
print('Create Free ST-Chinese-Mandarin-Corpus annotation text ...')
f_a = open(os.path.join(annotation_path, 'free_st_chinese_mandarin_corpus.txt'), 'w', encoding='utf-8')
for subfolder, _, filelist in sorted(os.walk(data_dir)):
for file in filelist:
if '.wav' in file:
file = os.path.join(subfolder, file)
with open(file[:-4] + '.txt', 'r', encoding='utf-8') as f:
line = f.readline()
f_a.write(file[3:] + '\t' + line + '\n')
f_a.close()
def prepare_dataset(url, md5sum, target_dir, annotation_path):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'ST-CMDS-20170001_1-OS')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
os.remove(filepath)
else:
print("Skip downloading and unpacking. Free ST-Chinese-Mandarin-Corpus data already exists in %s." % target_dir)
create_annotation_text(data_dir, annotation_path)
def main():
print_arguments(args)
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
annotation_path=args.annotation_text)
if __name__ == '__main__':
main()
import argparse
import os
import functools
import shutil
from utility import download, unzip
from utility import add_arguments, print_arguments
DATA_URL = 'http://www.openslr.org/resources/28/rirs_noises.zip'
MD5_DATA = 'e6f48e257286e05de56413b4779d8ffb'
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
parser.add_argument("--target_dir",
default="../dataset/audio/",
type=str,
help="存放音频文件的目录 (默认: %(default)s)")
parser.add_argument("--noise_path",
default="../dataset/audio/noise/",
type=str,
help="存放噪声音频的目录 (默认: %(default)s)")
args = parser.parse_args()
def prepare_dataset(url, md5sum, target_dir, noise_path):
"""Download, unpack and move noise file."""
data_dir = os.path.join(target_dir, 'RIRS_NOISES')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unzip(filepath, target_dir)
os.remove(filepath)
else:
print("Skip downloading and unpacking. RIRS_NOISES data already exists in %s." % target_dir)
# Move the noise audio files into the target directory
if not os.path.exists(noise_path):
os.makedirs(noise_path)
data_types = [
'pointsource_noises', 'real_rirs_isotropic_noises', 'simulated_rirs'
]
for dtype in data_types:
audio_dir = os.path.join(data_dir, dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
if '.wav' not in fname: continue
audio_path = os.path.join(subfolder, fname)
shutil.move(audio_path, os.path.join(noise_path, fname))
shutil.rmtree(data_dir, ignore_errors=True)
def main():
print_arguments(args)
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
noise_path=args.noise_path)
if __name__ == '__main__':
main()
import argparse
import os
import functools
from utility import download, unpack
from utility import add_arguments, print_arguments
DATA_URL = 'https://openslr.magicdatatech.com/resources/18/data_thchs30.tgz'
MD5_DATA = '2d2252bde5c8429929e1841d4cb95e90'
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
parser.add_argument("--target_dir",
default="../dataset/audio/",
type=str,
help="存放音频文件的目录 (默认: %(default)s)")
parser.add_argument("--annotation_text",
default="../dataset/annotation/",
type=str,
help="存放音频标注文件的目录 (默认: %(default)s)")
args = parser.parse_args()
def create_annotation_text(data_dir, annotation_path):
if not os.path.exists(annotation_path):
os.makedirs(annotation_path)
print('Create THCHS-30 annotation text ...')
f_a = open(os.path.join(annotation_path, 'thchs_30.txt'), 'w', encoding='utf-8')
data_path = 'data'
for file in os.listdir(os.path.join(data_dir, data_path)):
if '.trn' in file:
file = os.path.join(data_dir, data_path, file)
with open(file, 'r', encoding='utf-8') as f:
line = f.readline()
line = ''.join(line.split())
f_a.write(file[3:-4] + '\t' + line + '\n')
f_a.close()
def prepare_dataset(url, md5sum, target_dir, annotation_path):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_thchs30')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
os.remove(filepath)
else:
print("Skip downloading and unpacking. THCHS-30 data already exists in %s." % target_dir)
create_annotation_text(data_dir, annotation_path)
def main():
print_arguments(args)
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
annotation_path=args.annotation_text)
if __name__ == '__main__':
main()
import distutils.util
import hashlib
import os
import tarfile
import zipfile
def print_arguments(args):
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).items()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
type = distutils.util.strtobool if type == bool else type
argparser.add_argument("--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def getfile_insensitive(path):
"""Get the actual file path when given insensitive filename."""
directory, filename = os.path.split(path)
directory, filename = (directory or '.'), filename.lower()
for f in os.listdir(directory):
newpath = os.path.join(directory, f)
if os.path.isfile(newpath) and f.lower() == filename:
return newpath
def download_multi(url, target_dir, extra_args):
"""Download multiple files from url to target_dir."""
if not os.path.exists(target_dir): os.makedirs(target_dir)
print("Downloading %s ..." % url)
ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
target_dir)
return ret_code
def md5file(fname):
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
def download(url, md5sum, target_dir):
"""Download file from url to target_dir, and check md5sum."""
if not os.path.exists(target_dir): os.makedirs(target_dir)
filepath = os.path.join(target_dir, url.split("/")[-1])
if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print("Downloading %s ..." % url)
os.system("wget -c " + url + " -P " + target_dir)
print("\nMD5 Chesksum %s ..." % filepath)
if not md5file(filepath) == md5sum:
raise RuntimeError("MD5 checksum failed.")
else:
print("File exists, skip downloading. (%s)" % filepath)
return filepath
def unpack(filepath, target_dir, rm_tar=False):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
tar = tarfile.open(filepath)
tar.extractall(target_dir)
tar.close()
if rm_tar:
os.remove(filepath)
def unzip(filepath, target_dir):
    """Unzip the file to the target_dir."""
    print("Unpacking %s ..." % filepath)
    with zipfile.ZipFile(filepath, 'r') as zip_file:
        zip_file.extractall(target_dir)
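# Hedged usage sketch of the helpers above (the URL and MD5 are placeholders, not real resources):
# filepath = download('https://example.com/data.tgz', '<md5-of-the-archive>', './cache')
# unpack(filepath, './cache', rm_tar=True)  # for .tar/.tgz archives
# unzip(filepath, './cache')                # for .zip archives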
import argparse
import functools
import time
import paddle
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.error_rate import char_errors, word_errors
from decoders.ctc_greedy_decoder import greedy_decoder_batch
from utils.utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use the GPU for evaluation")
add_arg('batch_size', int, 64, "Batch size for evaluation")
add_arg('num_conv_layers', int, 2, "Number of convolution layers")
add_arg('num_rnn_layers', int, 3, "Number of RNN layers")
add_arg('rnn_layer_size', int, 1024, "Size of each RNN layer")
add_arg('beam_size', int, 300, "Beam search decoding parameter: beam size, range [5, 500]")
add_arg('alpha', float, 1.2, "Beam search decoding parameter: LM coefficient")
add_arg('num_proc_bsearch', int, 8, "Beam search decoding parameter: number of CPU processes")
add_arg('beta', float, 0.35, "Beam search decoding parameter: WC coefficient")
add_arg('cutoff_prob', float, 0.99, "Beam search decoding parameter: pruning probability")
add_arg('cutoff_top_n', int, 40, "Beam search decoding parameter: pruning cutoff")
add_arg('test_manifest', str, './dataset/manifest.train', "Manifest of the test data to evaluate")
add_arg('mean_std_path', str, './dataset/mean_std.npz', "Path to the dataset's mean/std npz file")
add_arg('vocab_path', str, './dataset/zh_vocab.txt', "Path to the dataset's vocabulary file")
add_arg('resume_model', str, './models/param/50.pdparams', "Path to the model checkpoint to restore")
add_arg('lang_model_path', str, './lm/zh_giga.no_cna_cmn.prune01244.klm', "Beam search decoding parameter: language model path")
add_arg('decoding_method', str, 'ctc_greedy', "Decoding method: beam search (ctc_beam_search) or greedy (ctc_greedy)", choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('error_rate_type', str, 'cer', "Error rate metric for evaluation: character error rate (cer) or word error rate (wer)", choices=['wer', 'cer'])
args = parser.parse_args()
# Evaluate the model
def evaluate():
    # Whether to use the GPU
    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
    # Create the data generator
data_generator = DataGenerator(vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
keep_transcription_text=True,
place=place,
is_training=False)
    # Create the evaluation data reader
batch_reader = data_generator.batch_reader_creator(manifest_path=args.test_manifest,
batch_size=args.batch_size,
shuffle_method=None)
    # Create the DeepSpeech2 model and set it to inference mode
ds2_model = DeepSpeech2Model(vocab_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
place=place,
resume_model=args.resume_model)
    # Read the data manifest
with open(args.test_manifest, 'r', encoding='utf-8') as f_m:
test_len = len(f_m.readlines())
    # Set up the beam search decoder if requested
    if args.decoding_method == "ctc_beam_search":
        try:
            from decoders.beam_search_decoder import BeamSearchDecoder
            beam_search_decoder = BeamSearchDecoder(args.alpha, args.beta, args.lang_model_path, data_generator.vocab_list)
        except ModuleNotFoundError:
            raise Exception('The swig_decoders library is missing; please install it per the documentation. On Windows, use ctc_greedy instead.')
    # Choose the error metric: character error rate (cer) or word error rate (wer)
errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
errors_sum, len_refs, num_ins = 0.0, 0, 0
ds2_model.logger.info("开始评估 ...")
start = time.time()
    # Run the evaluation
    for infer_data in batch_reader():
        # Get the recognition results for one batch
        probs_split = ds2_model.infer_batch_data(infer_data=infer_data)
        # Decode the outputs
        if args.decoding_method == 'ctc_greedy':
            # Greedy decoding strategy
            result_transcripts = greedy_decoder_batch(probs_split=probs_split, vocabulary=data_generator.vocab_list)
        else:
            # Beam search decoding strategy
result_transcripts = beam_search_decoder.decode_batch_beam_search(probs_split=probs_split,
beam_alpha=args.alpha,
beam_beta=args.beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
vocab_list=data_generator.vocab_list,
num_processes=args.num_proc_bsearch)
        # Ground-truth transcripts
        target_transcripts = infer_data[1]
        # Accumulate the error rate
        for target, result in zip(target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
print("错误率:[%s] (%d/%d) = %f" % (args.error_rate_type, num_ins, test_len, errors_sum / len_refs))
    end = time.time()
    print("Elapsed time: %ds, overall error rate: [%s] (%d/%d) = %f" % ((end - start), args.error_rate_type, num_ins, test_len, errors_sum / len_refs))
    ds2_model.logger.info("Evaluation finished!")
def main():
print_arguments(args)
evaluate()
if __name__ == '__main__':
main()
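# Hedged usage sketch (the script filename is an assumption; the flags are defined above):
# python eval.py --test_manifest ./dataset/manifest.test --resume_model ./models/param/50.pdparams --decoding_method ctc_greedy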
import argparse
import functools
import paddle
from model_utils.model import DeepSpeech2Model
from utils.utility import add_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('num_conv_layers', int, 2, "Number of convolution layers")
add_arg('num_rnn_layers', int, 3, "Number of RNN layers")
add_arg('rnn_layer_size', int, 1024, "Size of each RNN layer")
add_arg('use_gpu', bool, False, "Whether to use the GPU to load the model")
add_arg('vocab_path', str, './dataset/zh_vocab.txt', "Path to the dataset's vocabulary file")
add_arg('resume_model', str, './models/param/50.pdparams', "Path to the model checkpoint to restore")
add_arg('save_model_path', str, './models/infer/', "Directory to save the exported inference model")
args = parser.parse_args()
# Whether to use the GPU
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
with open(args.vocab_path, 'r', encoding='utf-8') as f:
vocab_size = len(f.readlines())
# Create the DeepSpeech2 model and set it to inference mode
ds2_model = DeepSpeech2Model(vocab_size=vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
resume_model=args.resume_model,
place=place)
ds2_model.export_model(model_path=args.save_model_path)
print('Model exported successfully; saved to: %s' % args.save_model_path)
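# Hedged sketch of loading the exported model with the Paddle inference API, assuming Paddle 2.x,
# where save_inference_model writes <prefix>.pdmodel and <prefix>.pdiparams:
# from paddle.inference import Config, create_predictor
# config = Config('./models/infer/inference.pdmodel', './models/infer/inference.pdiparams')
# predictor = create_predictor(config)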
import _thread
import argparse
import functools
import os
import time
import tkinter.messagebox
import wave
from tkinter.filedialog import *
import pyaudio
from data_utils.audio_process import AudioInferProcess
from utils.audio_vad import crop_audio_vad
from utils.predict import Predictor
from utils.utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use the GPU for prediction")
add_arg('enable_mkldnn', bool, False, "Whether to enable MKL-DNN acceleration")
add_arg('beam_size', int, 300, "Beam search decoding parameter: beam size, range [5, 500]")
add_arg('alpha', float, 1.2, "Beam search decoding parameter: LM coefficient")
add_arg('beta', float, 0.35, "Beam search decoding parameter: WC coefficient")
add_arg('cutoff_prob', float, 0.99, "Beam search decoding parameter: pruning probability")
add_arg('cutoff_top_n', int, 40, "Beam search decoding parameter: pruning cutoff")
add_arg('mean_std_path', str, './dataset/mean_std.npz', "Path to the dataset's mean/std npz file")
add_arg('vocab_path', str, './dataset/zh_vocab.txt', "Path to the dataset's vocabulary file")
add_arg('model_dir', str, './models/infer/', "Directory of the exported inference model")
add_arg('lang_model_path', str, './lm/zh_giga.no_cna_cmn.prune01244.klm', "Beam search decoding parameter: language model path")
add_arg('decoding_method', str, 'ctc_greedy', "Decoding method: beam search (ctc_beam_search) or greedy (ctc_greedy)", choices=['ctc_beam_search', 'ctc_greedy'])
args = parser.parse_args()
print_arguments(args)
class SpeechRecognitionApp:
def __init__(self, window: Tk, args):
self.window = window
self.wav_path = None
self.predicting = False
self.playing = False
self.recording = False
self.stream = None
self.to_an = True
        # Maximum recording duration in seconds
        self.max_record = 20
        # Directory where recordings are saved
        self.output_path = 'dataset/record'
        # Create a PyAudio instance
        self.p = pyaudio.PyAudio()
        # Window title
        self.window.title("夜雨飘零 Speech Recognition")
        # Fix the window size
        self.window.geometry('870x500')
self.window.resizable(False, False)
        # Button: recognize a short audio clip
        self.short_button = Button(self.window, text="Short audio recognition", width=20, command=self.predict_audio_thread)
        self.short_button.place(x=10, y=10)
        # Button: recognize a long audio clip
        self.long_button = Button(self.window, text="Long audio recognition", width=20, command=self.predict_long_audio_thread)
        self.long_button.place(x=170, y=10)
        # Button: record and recognize
        self.record_button = Button(self.window, text="Record & recognize", width=20, command=self.record_audio_thread)
        self.record_button.place(x=330, y=10)
        # Button: play audio
        self.play_button = Button(self.window, text="Play audio", width=20, command=self.play_audio_thread)
        self.play_button.place(x=490, y=10)
        # Output log text box
        self.result_label = Label(self.window, text="Output log:")
        self.result_label.place(x=10, y=70)
        self.result_text = Text(self.window, width=120, height=30)
        self.result_text.place(x=10, y=100)
        # Checkbox: convert Chinese numerals to Arabic numerals
        self.an_frame = Frame(self.window)
        self.check_var = BooleanVar()
        self.to_an_check = Checkbutton(self.an_frame, text='Convert Chinese numerals to Arabic numerals', variable=self.check_var, command=self.to_an_state)
        self.to_an_check.grid(row=0)
        self.to_an_check.select()
        self.an_frame.grid(row=1)
        self.an_frame.place(x=700, y=10)
        # Create the audio processor, needed for feature extraction and the vocabulary
        self.audio_process = AudioInferProcess(vocab_filepath=args.vocab_path, mean_std_filepath=args.mean_std_path)
        # Create the predictor
        self.predictor = Predictor(model_dir=args.model_dir, audio_process=self.audio_process,
                                   decoding_method=args.decoding_method, alpha=args.alpha, beta=args.beta,
                                   lang_model_path=args.lang_model_path, beam_size=args.beam_size,
                                   cutoff_prob=args.cutoff_prob, cutoff_top_n=args.cutoff_top_n, use_gpu=args.use_gpu,
                                   enable_mkldnn=args.enable_mkldnn)
    # Whether to convert Chinese numerals to Arabic numerals
    def to_an_state(self):
        self.to_an = self.check_var.get()
    # Thread: recognize a short audio clip
    def predict_audio_thread(self):
        if not self.predicting:
            self.wav_path = askopenfilename(filetypes=[("Audio files", "*.wav"), ("Audio files", "*.mp3")], initialdir='./dataset')
            if self.wav_path == '': return
            self.result_text.delete('1.0', 'end')
            self.result_text.insert(END, "Selected audio file: %s\n" % self.wav_path)
            self.result_text.insert(END, "Recognizing...\n")
            _thread.start_new_thread(self.predict_audio, (self.wav_path, ))
        else:
            tkinter.messagebox.showwarning('Warning', 'Prediction in progress; please wait for the previous run to finish!')
    # Recognize a short audio clip
    def predict_audio(self, wav_path):
        self.predicting = True
        try:
            start = time.time()
            score, text = self.predictor.predict(audio_path=wav_path, to_an=self.to_an)
            self.result_text.insert(END, "Elapsed: %dms, result: %s, score: %d\n" % (
                round((time.time() - start) * 1000), text, score))
        except Exception as e:
            print(e)
        self.predicting = False
    # Thread: recognize a long audio clip
    def predict_long_audio_thread(self):
        if not self.predicting:
            self.wav_path = askopenfilename(filetypes=[("Audio files", "*.wav"), ("Audio files", "*.mp3")], initialdir='./dataset')
            if self.wav_path == '': return
            self.result_text.delete('1.0', 'end')
            self.result_text.insert(END, "Selected audio file: %s\n" % self.wav_path)
            self.result_text.insert(END, "Recognizing...\n")
            _thread.start_new_thread(self.predict_long_audio, (self.wav_path, ))
        else:
            tkinter.messagebox.showwarning('Warning', 'Prediction in progress; please wait for the previous run to finish!')
    # Recognize a long audio clip
    def predict_long_audio(self, wav_path):
        self.predicting = True
        try:
            start = time.time()
            # Split the long audio with VAD
            audios_path = crop_audio_vad(wav_path)
            texts = ''
            scores = []
            # Run recognition on each segment
            for i, audio_path in enumerate(audios_path):
                score, text = self.predictor.predict(audio_path=audio_path, to_an=self.to_an)
                texts = texts + ',' + text
                scores.append(score)
                self.result_text.insert(END, "Segment %d, score: %d, result: %s\n" % (i, score, text))
            self.result_text.insert(END, "=====================================================\n")
            self.result_text.insert(END, "Final result, elapsed: %dms, score: %d, result: %s\n" %
                                    (round((time.time() - start) * 1000), sum(scores) / len(scores), texts))
        except Exception as e:
            print(e)
        self.predicting = False
    # Thread: record and recognize
    def record_audio_thread(self):
        if not self.playing and not self.recording:
            self.result_text.delete('1.0', 'end')
            _thread.start_new_thread(self.record_audio, ())
        else:
            if self.playing:
                tkinter.messagebox.showwarning('Warning', 'Audio is playing; cannot record!')
            else:
                # Stop recording
                self.recording = False
    def record_audio(self):
        self.record_button.configure(text='Stop recording')
        self.recording = True
        # Recording parameters
        chunk = 1024
        format = pyaudio.paInt16
        channels = 1
        rate = 16000
        # Open the input stream
        self.stream = self.p.open(format=format,
                                  channels=channels,
                                  rate=rate,
                                  input=True,
                                  frames_per_buffer=chunk)
        self.result_text.insert(END, "Recording...\n")
        start = time.time()
        frames = []
        while True:
            if not self.recording: break
            data = self.stream.read(chunk)
            frames.append(data)
            if len(frames) % 15 == 0:
                self.result_text.insert(END, "Recorded %.2f seconds\n" % (time.time() - start))
            if (time.time() - start) > self.max_record:
                self.result_text.insert(END, "Recording exceeded the maximum duration; stopping!")
                break
        if not os.path.exists(self.output_path):
            os.makedirs(self.output_path)
        save_path = os.path.join(self.output_path, '%s.wav' % str(int(time.time())))
        wf = wave.open(save_path, 'wb')
        wf.setnchannels(channels)
        wf.setsampwidth(self.p.get_sample_size(format))
        wf.setframerate(rate)
        wf.writeframes(b''.join(frames))
        wf.close()
        self.recording = False
        self.result_text.insert(END, "Recording finished; file saved to: %s\n" % save_path)
        # Recognize the recording
        self.result_text.insert(END, "Recognizing...\n")
        self.wav_path = save_path
        self.predict_audio(self.wav_path)
        self.record_button.configure(text='Record & recognize')
    # Thread: play audio
    def play_audio_thread(self):
        if self.wav_path is None or self.wav_path == '':
            tkinter.messagebox.showwarning('Warning', 'Audio path is empty!')
        else:
            if not self.playing and not self.recording:
                _thread.start_new_thread(self.play_audio, ())
            else:
                if self.recording:
                    tkinter.messagebox.showwarning('Warning', 'Recording in progress; cannot play audio!')
                else:
                    # Stop playing
                    self.playing = False
    # Play the audio file
    def play_audio(self):
        self.play_button.configure(text='Stop playing')
        self.playing = True
        CHUNK = 1024
        wf = wave.open(self.wav_path, 'rb')
        # Open the output stream
        self.stream = self.p.open(format=self.p.get_format_from_width(wf.getsampwidth()),
                                  channels=wf.getnchannels(),
                                  rate=wf.getframerate(),
                                  output=True)
        # Read the first chunk
        data = wf.readframes(CHUNK)
        # Play until the file ends or playback is stopped
        while data != b'':
            if not self.playing: break
            self.stream.write(data)
            data = wf.readframes(CHUNK)
        # Stop the stream
        self.stream.stop_stream()
        self.stream.close()
        self.playing = False
        self.play_button.configure(text='Play audio')
tk = Tk()
myapp = SpeechRecognitionApp(tk, args)
if __name__ == '__main__':
tk.mainloop()
import argparse
import shutil
import time
import paddle
from paddlespeech.cli import ASRExecutor
from PaddlePaddle_DeepSpeech2.data_utils.audio_process import AudioInferProcess
from PaddlePaddle_DeepSpeech2.utils.predict import Predictor
from PaddlePaddle_DeepSpeech2.utils.audio_vad import crop_audio_vad
import os
# Assumed narration reading speed (characters per second) used for the recommended caption length
normal_speed = 4
# from data_utils.audio_process import AudioInferProcess
# from utils.predict import Predictor
# from utils.audio_vad import crop_audio_vad
# from utils.utility import add_arguments, print_arguments
# parser = argparse.ArgumentParser(description=__doc__)
# add_arg = functools.partial(add_arguments, argparser=parser)
# add_arg('wav_path', str, './dataset/test.wav', "预测音频的路径")
# add_arg('is_long_audio', bool, False, "是否为长语音")
# add_arg('use_gpu', bool, False, "是否使用GPU预测")
# add_arg('enable_mkldnn', bool, False, "是否使用mkldnn加速")
# add_arg('to_an', bool, True, "是否转为阿拉伯数字")
# add_arg('beam_size', int, 300, "集束搜索解码相关参数,搜索的大小,范围:[5, 500]")
# add_arg('alpha', float, 1.2, "集束搜索解码相关参数,LM系数")
# add_arg('beta', float, 0.35, "集束搜索解码相关参数,WC系数")
# add_arg('cutoff_prob', float, 0.99, "集束搜索解码相关参数,剪枝的概率")
# add_arg('cutoff_top_n', int, 40, "集束搜索解码相关参数,剪枝的最大值")
# add_arg('mean_std_path', str, './PaddlePaddle_DeepSpeech2/dataset/mean_std.npz', "数据集的均值和标准值的npy文件路径")
# add_arg('vocab_path', str, './PaddlePaddle_DeepSpeech2/dataset/zh_vocab.txt', "数据集的词汇表文件路径")
# add_arg('model_dir', str, './PaddlePaddle_DeepSpeech2/models/infer/', "导出的预测模型文件夹路径")
# add_arg('lang_model_path', str, './PaddlePaddle_DeepSpeech2/lm/zh_giga.no_cna_cmn.prune01244.klm',
# "集束搜索解码相关参数,语言模型文件路径")
# add_arg('decoding_method', str, 'ctc_greedy', "结果解码方法,有集束搜索(ctc_beam_search)、贪婪策略(ctc_greedy)",
# choices=['ctc_beam_search', 'ctc_greedy'])
# args = parser.parse_args()
# print_arguments(args)
# Speech recognition with PaddleSpeech's pretrained model
def predict_long_audio_with_paddle(wav_path, pre_time, state):
    # DeepSpeech2 predictor configuration; note the Predictor built below is not used by the
    # recognition loop, which calls PaddleSpeech's ASRExecutor instead
vocab_path = './PaddlePaddle_DeepSpeech2/dataset/zh_vocab.txt'
mean_std_path = './PaddlePaddle_DeepSpeech2/dataset/mean_std.npz'
decoding_method = 'ctc_greedy'
alpha = 1.2
beta = 0.35
model_dir = './PaddlePaddle_DeepSpeech2/models/infer/'
lang_model_path = './PaddlePaddle_DeepSpeech2/lm/zh_giga.no_cna_cmn.prune01244.klm'
beam_size = 300
cutoff_prob = 0.99
cutoff_top_n = 40
use_gpu = False
enable_mkldnn = False
audio_process = AudioInferProcess(vocab_filepath=vocab_path, mean_std_filepath=mean_std_path)
predictor = Predictor(model_dir=model_dir, audio_process=audio_process, decoding_method=decoding_method,
alpha=alpha, beta=beta, lang_model_path=lang_model_path,
beam_size=beam_size,
cutoff_prob=cutoff_prob, cutoff_top_n=cutoff_top_n, use_gpu=use_gpu,
enable_mkldnn=enable_mkldnn)
asr_executor = ASRExecutor()
start = time.time()
    # Split the long audio with VAD
audios_path, time_stamps = crop_audio_vad(wav_path)
texts = ''
narratages = []
last_time = 0
    # Run recognition on each segment
    for i, audio_path in enumerate(audios_path):
        print("{} processing {}".format(paddle.get_device(), audio_path))
        # Track the progress of the recognition (fraction of segments processed, capped at 0.99)
        state[0] = float((i + 1) / len(audios_path)) if state[0] is None or state[0] < 0.99 else 0.99
text = asr_executor(
model='conformer_wenetspeech',
lang='zh',
sample_rate=16000,
config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
ckpt_path=None,
audio_file=audio_path,
force_yes=True,
device=paddle.get_device()
)
        if text:
            # If there is a gap of at least one second before this segment, recommend narration to fill it
            if i == 0 or (i > 0 and time_stamps[i][0] - last_time >= 1):
                recommend_lens = int(time_stamps[i][0] * normal_speed) if i == 0 else int(
                    (time_stamps[i][0] - last_time) * normal_speed)
                narratages.append(["", "", "", "Insert narration; recommended length: %d characters" % recommend_lens])
            narratages.append(
                [round(time_stamps[i][0] + pre_time, 2), round(time_stamps[i][1] + pre_time, 2), text, ''])
            texts = texts + ',' + text
            last_time = time_stamps[i][1]
        print(
            "Segment %d, time %.2f-%.2f, result: %s" % (i, time_stamps[i][0] + pre_time, time_stamps[i][1] + pre_time, text))
    print("Final result, elapsed: %dms, result: %s" % (round((time.time() - start) * 1000), texts))
    # Delete the segmented audio files when done
save_path = os.path.join(os.path.dirname(wav_path), 'crop_audio')
if os.path.exists(save_path):
shutil.rmtree(save_path)
return narratages
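# Hedged usage sketch (paths are illustrative; state[0] is updated in place with the progress):
# state = [None]
# narratages = predict_long_audio_with_paddle('dataset/test.wav', pre_time=0.0, state=state)
# for start_time, end_time, text, note in narratages:
#     print(start_time, end_time, text, note)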
# # Recognize with an off-the-shelf pretrained model (poor results)
# def predict_audio_with_paddle():
# start = time.time()
# text = asr_executor(
# model='conformer_wenetspeech',
# lang='zh',
# sample_rate=16000,
# config=None, # Set `config` and `ckpt_path` to None to use pretrained model.
# ckpt_path=None,
# audio_file=args.wav_path,
# force_yes=False,
# device=paddle.get_device()
# )
# print("消耗时间:%dms, 识别结果: %s" % (round((time.time() - start) * 1000), text))
#
#
# def predict_long_audio():
# start = time.time()
# # 分割长音频
# audios_path = crop_audio_vad(args.wav_path)
# texts = ''
# scores = []
# # 执行识别
# for i, audio_path in enumerate(audios_path):
# score, text = predictor.predict(audio_path=audio_path, to_an=args.to_an)
# texts = texts + ',' + text
# scores.append(score)
# print("第%d个分割音频, 得分: %d, 识别结果: %s" % (i, score, text))
# print("最终结果,消耗时间:%d, 得分: %d, 识别结果: %s" % (round((time.time() - start) * 1000), sum(scores) / len(scores), texts))
#
#
# def predict_audio():
# start = time.time()
# score, text = predictor.predict(audio_path=args.wav_path, to_an=args.to_an)
# print("消耗时间:%dms, 识别结果: %s, 得分: %d" % (round((time.time() - start) * 1000), text, score))
if __name__ == "__main__":
# if args.is_long_audio:
# # predict_long_audio()
# predict_long_audio_with_paddle()
# else:
# # predict_audio()
# predict_audio_with_paddle()
pass
import argparse
import functools
import json
import os
import sys
import time
from flask import request, Flask, render_template
from flask_cors import CORS
from data_utils.audio_process import AudioInferProcess
from utils.predict import Predictor
from utils.audio_vad import crop_audio_vad
from utils.utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("host", str, "0.0.0.0", "监听主机的IP地址")
add_arg("port", int, 5000, "服务所使用的端口号")
add_arg("save_path", str, 'dataset/upload/', "上传音频文件的保存目录")
add_arg('use_gpu', bool, True, "是否使用GPU预测")
add_arg('enable_mkldnn', bool, False, "是否使用mkldnn加速")
add_arg('to_an', bool, True, "是否转为阿拉伯数字")
add_arg('beam_size', int, 300, "集束搜索解码相关参数,搜索大小,范围:[5, 500]")
add_arg('alpha', float, 1.2, "集束搜索解码相关参数,LM系数")
add_arg('beta', float, 0.35, "集束搜索解码相关参数,WC系数")
add_arg('cutoff_prob', float, 0.99, "集束搜索解码相关参数,剪枝的概率")
add_arg('cutoff_top_n', int, 40, "集束搜索解码相关参数,剪枝的最大值")
add_arg('mean_std_path', str, './dataset/mean_std.npz', "数据集的均值和标准值的npy文件路径")
add_arg('vocab_path', str, './dataset/zh_vocab.txt', "数据集的词汇表文件路径")
add_arg('model_dir', str, './models/infer/', "导出的预测模型文件夹路径")
add_arg('lang_model_path', str, './lm/zh_giga.no_cna_cmn.prune01244.klm', "集束搜索解码相关参数,语言模型文件路径")
add_arg('decoding_method', str, 'ctc_greedy', "结果解码方法,有集束搜索(ctc_beam_search)、贪婪策略(ctc_greedy)", choices=['ctc_beam_search', 'ctc_greedy'])
args = parser.parse_args()
app = Flask(__name__, template_folder="templates", static_folder="static", static_url_path="/")
# Allow cross-origin requests
CORS(app)
# Create the audio processor, needed for feature extraction and the vocabulary
audio_process = AudioInferProcess(vocab_filepath=args.vocab_path, mean_std_filepath=args.mean_std_path)
predictor = Predictor(model_dir=args.model_dir, audio_process=audio_process, decoding_method=args.decoding_method,
alpha=args.alpha, beta=args.beta, lang_model_path=args.lang_model_path, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_top_n=args.cutoff_top_n, use_gpu=args.use_gpu,
enable_mkldnn=args.enable_mkldnn)
# Short speech recognition endpoint
@app.route("/recognition", methods=['POST'])
def recognition():
    f = request.files['audio']
    if f:
        # Temporary save path
        file_path = os.path.join(args.save_path, f.filename)
        f.save(file_path)
        try:
            start = time.time()
            # Run recognition
            score, text = predictor.predict(audio_path=file_path, to_an=args.to_an)
            end = time.time()
            print("Recognition time: %dms, result: %s, score: %f" % (round((end - start) * 1000), text, score))
            # Build the response with json.dumps rather than str().replace, which breaks on quotes in the text
            return json.dumps({"code": 0, "msg": "success", "result": text, "score": round(score, 3)}, ensure_ascii=False)
        except Exception:
            return json.dumps({"error": 1, "msg": "audio read fail!"})
    return json.dumps({"error": 3, "msg": "audio is None!"})
# Long speech recognition endpoint
@app.route("/recognition_long_audio", methods=['POST'])
def recognition_long_audio():
    f = request.files['audio']
    if f:
        # Temporary save path
        file_path = os.path.join(args.save_path, f.filename)
        f.save(file_path)
        try:
            start = time.time()
            # Split the long audio with VAD
            audios_path = crop_audio_vad(file_path)
            texts = ''
            scores = []
            # Run recognition on each segment
            for i, audio_path in enumerate(audios_path):
                score, text = predictor.predict(audio_path=audio_path, to_an=args.to_an)
                texts = texts + ',' + text
                scores.append(score)
            end = time.time()
            print("Recognition time: %dms, result: %s, score: %f" % (round((end - start) * 1000), texts, sum(scores) / len(scores)))
            # Build the response with json.dumps rather than str().replace, which breaks on quotes in the text
            return json.dumps({"code": 0, "msg": "success", "result": texts, "score": round(float(sum(scores) / len(scores)), 3)}, ensure_ascii=False)
        except Exception as e:
            print(e, file=sys.stderr)
            return json.dumps({"error": 1, "msg": "audio read fail!"})
    return json.dumps({"error": 3, "msg": "audio is None!"})
@app.route('/')
def home():
return render_template("index.html")
if __name__ == '__main__':
print_arguments(args)
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
app.run(host=args.host, port=args.port)
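# Hedged usage sketch (assumes the default host/port above and a local WAV file;
# the form field name must be "audio", matching request.files['audio']):
# curl -F "audio=@test.wav" http://127.0.0.1:5000/recognition
# curl -F "audio=@long_test.wav" http://127.0.0.1:5000/recognition_long_audio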
import logging
import os
import shutil
import time
import paddle
from datetime import datetime, timedelta
from distutils.dir_util import mkpath
import numpy as np
import paddle.fluid as fluid
from visualdl import LogWriter
from utils.error_rate import char_errors, word_errors
from decoders.ctc_greedy_decoder import greedy_decoder_batch
from model_utils.network import deep_speech_v2_network
logging.basicConfig(format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
paddle.enable_static()
class DeepSpeech2Model(object):
"""DeepSpeech2Model class.
:param vocab_size: 词汇表大小
:type vocab_size: int
:param num_conv_layers: 叠加卷积层数
:type num_conv_layers: int
:param num_rnn_layers: 叠加RNN层数
:type num_rnn_layers: int
:param rnn_layer_size: RNN层大小
:type rnn_layer_size: int
:param place: Program running place.
:type place: CPUPlace or CUDAPlace
:param resume_model: 恢复模型路径
:type resume_model: string|None
:param pretrained_model: 预训练模型路径
:type pretrained_model: string|None
:param output_model_dir: 保存模型的路径
:type output_model_dir: string|None
:param error_rate_type: 测试计算错误率的方式
:type error_rate_type: string|None
:param vocab_list: 词汇表列表
:type vocab_list: list|None
:param blank: 损失函数的空白索引
:type blank: int
"""
def __init__(self,
vocab_size,
num_conv_layers,
num_rnn_layers,
rnn_layer_size,
place=paddle.CPUPlace(),
resume_model=None,
pretrained_model=None,
output_model_dir=None,
error_rate_type='cer',
vocab_list=None,
blank=0):
self._vocab_size = vocab_size
self._num_conv_layers = num_conv_layers
self._num_rnn_layers = num_rnn_layers
self._rnn_layer_size = rnn_layer_size
self._place = place
self._blank = blank
self._pretrained_model = pretrained_model
self._resume_model = resume_model
self._output_model_dir = output_model_dir
self._ext_scorer = None
self.logger = logging.getLogger("")
self.logger.setLevel(level=logging.INFO)
self.error_rate_type = error_rate_type
self.vocab_list = vocab_list
self.save_model_path = ''
        # Inference-related members
self.infer_program = None
self.infer_compiled_prog = None
self.infer_feeder = None
self.infer_log_probs = None
self.infer_exe = None
def create_network(self, is_infer=False):
"""Create data layers and model network.
:param is_infer: Whether to create a network for Inference.
:type is_infer: bool
:return reader: Reader for input.
:rtype reader: read generater
:return log_probs: An output unnormalized log probability layer.
:rtype lig_probs: Varable
:return loss: A ctc loss layer.
:rtype loss: Variable
"""
if not is_infer:
input_fields = {
'names': ['audio_data', 'text_data', 'seq_len_data', 'masks'],
'shapes': [[None, 161, None], [None, 1], [None, 1], [None, 32, 81, None]],
'dtypes': ['float32', 'int32', 'int64', 'float32'],
'lod_levels': [0, 1, 0, 0]
}
inputs = [
paddle.static.data(name=input_fields['names'][i],
shape=input_fields['shapes'][i],
dtype=input_fields['dtypes'][i],
lod_level=input_fields['lod_levels'][i])
for i in range(len(input_fields['names']))
]
reader = fluid.io.DataLoader.from_generator(feed_list=inputs,
capacity=64,
iterable=False,
use_double_buffer=True)
(audio_data, text_data, seq_len_data, masks) = inputs
else:
audio_data = paddle.static.data(name='audio_data',
shape=[None, 161, None],
dtype='float32',
lod_level=0)
seq_len_data = paddle.static.data(name='seq_len_data',
shape=[None, 1],
dtype='int64',
lod_level=0)
masks = paddle.static.data(name='masks',
shape=[None, 32, 81, None],
dtype='float32',
lod_level=0)
text_data = None
reader = fluid.DataFeeder([audio_data, seq_len_data, masks], self._place)
log_probs, loss = deep_speech_v2_network(audio_data=audio_data,
text_data=text_data,
seq_len_data=seq_len_data,
masks=masks,
dict_size=self._vocab_size,
num_conv_layers=self._num_conv_layers,
num_rnn_layers=self._num_rnn_layers,
rnn_size=self._rnn_layer_size,
blank=self._blank)
return reader, log_probs, loss
    # Load model parameters
def load_param(self, program, model_path, ignore_opt=False):
if not os.path.exists(model_path):
raise Warning("The pretrained params [%s] do not exist." % model_path)
load_state_dict = paddle.load(model_path)
if ignore_opt:
for key in program.state_dict(mode='opt').keys():
load_state_dict.pop(key)
program.set_state_dict(load_state_dict)
        print('[{}] Successfully loaded model: {}'.format(datetime.now(), model_path))
    # Save model parameters
    def save_param(self, program, epoch):
        if not os.path.exists(self._output_model_dir):
            os.makedirs(self._output_model_dir)
        model_path = '{}/{}.pdparams'.format(self._output_model_dir, epoch)
        paddle.save(program.state_dict(), model_path)
        # Keep only recent checkpoints: drop the one from three epochs ago
        old_model_path = '{}/{}.pdparams'.format(self._output_model_dir, epoch - 3)
        if os.path.exists(old_model_path):
            os.remove(old_model_path)
        print("Model saved to: %s" % model_path)
return model_path
def train(self,
train_batch_reader,
dev_batch_reader,
learning_rate,
gradient_clipping,
num_epoch,
batch_size,
train_num_samples,
test_num_samples,
test_off=False):
"""Train the model.
:param train_batch_reader: Train data reader.
:type train_batch_reader: callable
:param dev_batch_reader: Validation data reader.
:type dev_batch_reader: callable
:param learning_rate: Learning rate for ADAM optimizer.
:type learning_rate: float
:param gradient_clipping: Gradient clipping threshold.
:type gradient_clipping: float
:param num_epoch: Number of training epochs.
:type num_epoch: int
:param batch_size: Number of batch size.
:type batch_size: int
:param train_num_samples: The num of train samples.
:type train_num_samples: int
:param test_num_samples: The num of test samples.
:type test_num_samples: int
:param test_off: Turn off testing.
:type test_off: bool
"""
shutil.rmtree('log', ignore_errors=True)
writer = LogWriter(logdir='log')
# prepare model output directory
if not os.path.exists(self._output_model_dir):
mkpath(self._output_model_dir)
if isinstance(self._place, paddle.CUDAPlace):
dev_count = len(paddle.static.cuda_places())
else:
dev_count = int(os.environ.get('CPU_NUM', 1))
pre_epoch = 0
if self._resume_model:
try:
pre_epoch = os.path.basename(self._resume_model).split('.')[0]
pre_epoch = int(pre_epoch)
except:
print("恢复模型命名不正确,epoch从0开始训练!")
# prepare the network
train_program = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_prog):
with paddle.utils.unique_name.guard():
train_reader, _, ctc_loss = self.create_network()
                # Learning rate schedule
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=learning_rate, gamma=0.83,
                                                                 last_epoch=pre_epoch - 1)
                # Set up the optimizer
optimizer = paddle.optimizer.Adam(
learning_rate=scheduler,
weight_decay=paddle.regularizer.L2Decay(5e-4),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=gradient_clipping))
optimizer.minimize(loss=ctc_loss)
exe = paddle.static.Executor(self._place)
exe.run(startup_prog)
        # Load the resume checkpoint or the pretrained model
if self._resume_model is not None or self._pretrained_model is not None:
if self._resume_model is not None:
self.load_param(train_program, self._resume_model)
else:
self.load_param(train_program, self._pretrained_model, ignore_opt=True)
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
# pass the build_strategy to with_data_parallel API
train_compiled_prog = paddle.static.CompiledProgram(train_program) \
.with_data_parallel(loss_name=ctc_loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
train_reader.set_batch_generator(train_batch_reader)
train_step = 0
test_step = 0
train_num_batch = train_num_samples // batch_size // dev_count
test_num_batch = test_num_samples // batch_size
sum_batch = train_num_batch * (num_epoch - pre_epoch)
# run train
for epoch_id in range(pre_epoch, num_epoch):
epoch_id += 1
train_reader.start()
epoch_loss = []
time_begin = time.time()
batch_id = 0
start = time.time()
while True:
try:
                    if batch_id % 100 == 0:
                        # Run a training step and fetch the loss
                        fetch = exe.run(program=train_compiled_prog, fetch_list=[ctc_loss.name], return_numpy=False)
each_loss = fetch[0]
epoch_loss.extend(np.array(each_loss[0]) / batch_size)
eta_sec = ((time.time() - start) * 1000) * (
sum_batch - (epoch_id - pre_epoch - 1) * train_num_batch - batch_id)
eta_str = str(timedelta(seconds=int(eta_sec / 1000)))
print(
"Train [%s] epoch: [%d/%d], batch: [%d/%d], learning rate: %.8f, train loss: %f, eta: %s" %
(datetime.now(), epoch_id, num_epoch, batch_id, train_num_batch, scheduler.get_lr(),
np.mean(each_loss[0]) / batch_size, eta_str))
                        # Log the training loss and learning rate
                        writer.add_scalar('Train loss', np.mean(each_loss[0]) / batch_size, train_step)
                        writer.add_scalar('Learning rate', scheduler.get_lr(), train_step)
train_step += 1
                    else:
                        # Run a training step without fetching the loss
                        _ = exe.run(program=train_compiled_prog, fetch_list=[], return_numpy=False)
                    # Save the model every 10000 batches
                    if batch_id % 10000 == 0 and batch_id != 0:
self.save_param(train_program, epoch_id)
batch_id = batch_id + 1
start = time.time()
except fluid.core.EOFException:
train_reader.reset()
break
scheduler.step()
            # Save the model once per epoch
self._resume_model = self.save_param(train_program, epoch_id)
used_time = time.time() - time_begin
if test_off:
print('======================last Train=====================')
print("Train time: %f sec, epoch: %d, train loss: %f\n" %
(used_time, epoch_id, float(np.mean(np.array(epoch_loss)))))
print('======================last Train=====================')
else:
print('\n======================Begin test=====================')
                # Run the test pass
test_result = self.test(test_reader=dev_batch_reader, epoch_id=epoch_id, test_num_batch=test_num_batch)
print("Test [%s] train time: %s, epoch: %d, train loss: %f, test %s: %f"
% (datetime.now(), str(timedelta(seconds=int(used_time))), epoch_id,
float(np.mean(np.array(epoch_loss))), self.error_rate_type, test_result))
print('======================Stop Test=====================\n')
                # Log the test result
writer.add_scalar('Test %s' % self.error_rate_type, test_result, test_step)
test_step += 1
self.save_param(train_program, num_epoch)
print("\n------------Training finished!!!-------------")
def test(self, test_reader, epoch_id, test_num_batch):
'''Test the model.
:param test_reader: Reader of test.
:type test_reader: Reader
:param epoch_id: Train epoch id
:type epoch_id: int
:param test_num_batch: Test batch number
:type test_num_batch: int
:return: Wer/Cer rate.
:rtype: float
'''
        # Initialize the inference program
        self.create_infer_program()
        # Load the trained model
        self.load_param(self.infer_program, self._resume_model)
errors_sum, len_refs = 0.0, 0
errors_func = char_errors if self.error_rate_type == 'cer' else word_errors
for batch_id, infer_data in enumerate(test_reader()):
# 执行预测
probs_split = self.infer_batch_data(infer_data=infer_data)
            # Decode with greedy (best-path) decoding
result_transcripts = greedy_decoder_batch(probs_split=probs_split, vocabulary=self.vocab_list)
target_transcripts = infer_data[1]
            # Accumulate the error rate
for target, result in zip(target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
if batch_id % 100 == 0:
print("Test [%s] epoch: %d, batch: [%d/%d], %s: %f" %
(datetime.now(), epoch_id, batch_id, test_num_batch, self.error_rate_type, errors_sum / len_refs))
return errors_sum / len_refs
    # Run inference on one batch of audio
def infer_batch_data(self, infer_data):
"""Infer the prob matrices for a batch of speech utterances.
:param infer_data: List of utterances to infer, with each utterance
consisting of a tuple of audio features and
transcription text (empty string).
:type infer_data: list
        :return: List of 2-D probability matrix, and each consists of prob
                 vectors for one speech utterance.
        :rtype: List of matrix
"""
# define inferer
if self.infer_exe is None:
            # Initialize the inference program
            self.create_infer_program()
            # Load the trained model
self.load_param(self.infer_program, self._resume_model)
infer_results = []
data = []
if isinstance(self._place, paddle.CUDAPlace):
num_places = len(paddle.static.cuda_places())
else:
num_places = int(os.environ.get('CPU_NUM', 1))
        # Run inference
        for i in range(infer_data[0].shape[0]):
            # Multi-device inference: group the batch by the number of devices
data.append([[infer_data[0][i], infer_data[2][i], infer_data[3][i]]])
if len(data) == num_places:
each_log_probs = self.infer_exe.run(program=self.infer_compiled_prog,
feed=list(self.infer_feeder.feed_parallel(
iterable=data, num_places=num_places)),
fetch_list=[self.infer_log_probs],
return_numpy=False)
data = []
infer_results.extend(np.array(each_log_probs[0]))
        # If the batch size is not a multiple of the device count, run the remainder one at a time
last_data_num = infer_data[0].shape[0] % num_places
if last_data_num != 0:
for i in range(infer_data[0].shape[0] - last_data_num, infer_data[0].shape[0]):
each_log_probs = self.infer_exe.run(program=self.infer_program,
feed=self.infer_feeder.feed(
[[infer_data[0][i], infer_data[2][i], infer_data[3][i]]]),
fetch_list=[self.infer_log_probs],
return_numpy=False)
infer_results.extend(np.array(each_log_probs[0]))
# slice result
infer_results = np.array(infer_results)
seq_len = (infer_data[2] - 1) // 3 + 1
start_pos = [0] * (infer_data[0].shape[0] + 1)
for i in range(infer_data[0].shape[0]):
start_pos[i + 1] = start_pos[i] + seq_len[i][0]
probs_split = [
infer_results[start_pos[i]:start_pos[i + 1]]
for i in range(0, infer_data[0].shape[0])
]
return probs_split
    # Initialize the inference program and load the trained model
def create_infer_program(self):
# define inferer
self.infer_program = paddle.static.Program()
startup_prog = paddle.static.Program()
# prepare the network
with paddle.static.program_guard(self.infer_program, startup_prog):
with paddle.utils.unique_name.guard():
self.infer_feeder, self.infer_log_probs, _ = self.create_network(is_infer=True)
self.infer_program = self.infer_program.clone(for_test=True)
self.infer_exe = paddle.static.Executor(self._place)
self.infer_exe.run(startup_prog)
        # Support multi-device inference
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
self.infer_compiled_prog = paddle.static.CompiledProgram(self.infer_program) \
.with_data_parallel(build_strategy=build_strategy,
exec_strategy=exec_strategy)
    # Export the inference model
def export_model(self, model_path):
self.create_infer_program()
        # Load the trained model
self.load_param(self.infer_program, self._resume_model)
audio_data = paddle.static.data(name='audio_data',
shape=[None, 161, None],
dtype='float32',
lod_level=0)
seq_len_data = paddle.static.data(name='seq_len_data',
shape=[None, 1],
dtype='int64',
lod_level=0)
masks = paddle.static.data(name='masks',
shape=[None, 32, 81, None],
dtype='float32',
lod_level=0)
if not os.path.exists(os.path.dirname(model_path)):
os.makedirs(os.path.dirname(model_path))
paddle.static.save_inference_model(path_prefix=model_path + '/inference',
feed_vars=[audio_data, seq_len_data, masks],
fetch_vars=[self.infer_log_probs],
executor=self.infer_exe,
program=self.infer_program)
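# Hedged usage sketch mirroring the export script above (paths and vocab size are placeholders):
# model = DeepSpeech2Model(vocab_size=4300, num_conv_layers=2, num_rnn_layers=3,
#                          rnn_layer_size=1024, resume_model='./models/param/50.pdparams')
# model.export_model(model_path='./models/infer/')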
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.static.nn as nn
def conv_bn_layer(input, filter_size, num_channels_out, stride, padding, act, masks):
"""卷积层与批处理归一化
:param input: Input layer.
:type input: Variable
:param filter_size: 卷积核大小
:type filter_size: int|tuple|list
:param num_channels_out: 输出通道数
:type num_channels_out: int
:param stride: 步幅大小
:type stride: int|tuple|list
:param padding: 填充大小
:type padding: int|tuple|list
:param act: 激活函数类型
:type act: string
:param masks:掩码层,用于填充
:type masks: Variable
:return: 批处理范数层后卷积层
:rtype: Variable
"""
conv_layer = nn.conv2d(input=input,
num_filters=num_channels_out,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=paddle.ParamAttr(),
bias_attr=False)
batch_norm = nn.batch_norm(input=conv_layer, act=act, param_attr=paddle.ParamAttr(), bias_attr=paddle.ParamAttr())
    # Reset the padded positions to zero
padding_reset = paddle.multiply(batch_norm, masks)
return padding_reset
def bidirectional_gru_bn_layer(input, size, act):
    """Bidirectional GRU layer with sequence-wise batch normalization; batch norm is only applied to the input-state projections.
    :param input: Input layer.
    :type input: Variable
    :param size: Cell size of the GRU
    :type size: int
    :param act: Activation type
    :type act: string
    :return: Bidirectional GRU layer
    :rtype: Variable
    """
input_proj_forward = nn.fc(x=input, size=size * 3, weight_attr=paddle.ParamAttr())
input_proj_reverse = nn.fc(x=input, size=size * 3, weight_attr=paddle.ParamAttr())
    # Batch norm is only applied to the input-related projections
input_proj_bn_forward = nn.batch_norm(input=input_proj_forward,
act=None,
param_attr=paddle.ParamAttr(),
bias_attr=paddle.ParamAttr())
input_proj_bn_reverse = nn.batch_norm(input=input_proj_reverse,
act=None,
param_attr=paddle.ParamAttr(),
bias_attr=paddle.ParamAttr())
# forward and backward in time
forward_gru = fluid.layers.dynamic_gru(input=input_proj_bn_forward,
size=size,
gate_activation='sigmoid',
candidate_activation=act,
param_attr=paddle.ParamAttr(),
bias_attr=paddle.ParamAttr(),
is_reverse=False)
reverse_gru = fluid.layers.dynamic_gru(input=input_proj_bn_reverse,
size=size,
gate_activation='sigmoid',
candidate_activation=act,
param_attr=paddle.ParamAttr(),
bias_attr=paddle.ParamAttr(),
is_reverse=True)
return paddle.concat(x=[forward_gru, reverse_gru], axis=1)
def conv_group(input, num_stacks, seq_len_data, masks):
    """Convolution group with stacked convolution layers.
    :param input: Input layer.
    :type input: Variable
    :param num_stacks: Number of stacked convolution layers
    :type num_stacks: int
    :param seq_len_data: Valid sequence length data layer
    :type seq_len_data: Variable
    :param masks: Mask data layer used to reset the padding
    :type masks: Variable
    :return: Output layer of the convolution group
    :rtype: Variable
    """
filter_size = (41, 11)
stride = (2, 3)
padding = (20, 5)
conv = conv_bn_layer(input=input,
filter_size=filter_size,
num_channels_out=32,
stride=stride,
padding=padding,
act="brelu",
masks=masks)
seq_len_data = (np.array(seq_len_data) - filter_size[1] + 2 * padding[1]) // stride[1] + 1
output_height = (161 - 1) // 2 + 1
for i in range(num_stacks - 1):
# reshape masks
output_height = (output_height - 1) // 2 + 1
masks = paddle.slice(masks, axes=[2], starts=[0], ends=[output_height])
conv = conv_bn_layer(input=conv,
filter_size=(21, 11),
num_channels_out=32,
stride=(2, 1),
padding=(10, 5),
act="brelu",
masks=masks)
output_num_channels = 32
return conv, output_num_channels, output_height, seq_len_data
def rnn_group(input, size, num_stacks):
    """RNN group with stacked bidirectional GRU layers.
    :param input: Input layer.
    :type input: Variable
    :param size: Cell size of each RNN layer
    :type size: int
    :param num_stacks: Number of stacked RNN layers
    :type num_stacks: int
    :return: Output layer of the RNN group
    :rtype: Variable
    """
output = input
for i in range(num_stacks):
output = bidirectional_gru_bn_layer(input=output, size=size, act="relu")
return output
def deep_speech_v2_network(audio_data,
text_data,
seq_len_data,
masks,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=256,
blank=0):
"""DeepSpeech2网络结构
:param audio_data: 音频输入层
:type audio_data: Variable
:param text_data: 标签输入层
:type text_data: Variable
:param seq_len_data: 输出长度输入层
:type seq_len_data: Variable
:param masks: 掩码数据输入层
:type masks: Variable
:param dict_size: 字典大小
:type dict_size: int
:param num_conv_layers: 叠加卷积层数
:type num_conv_layers: int
:param num_rnn_layers: 叠加RNN层数
:type num_rnn_layers: int
:param rnn_size: RNN层隐层的大小
:type rnn_size: int
:return: 模型概率分布和ctc损失
:rtype: tuple of LayerOutput
"""
audio_data = paddle.unsqueeze(audio_data, axis=[1])
    # Convolution group
conv_group_output, conv_group_num_channels, conv_group_height, seq_len_data = conv_group(input=audio_data,
num_stacks=num_conv_layers,
seq_len_data=seq_len_data,
masks=masks)
    # Reshape the convolution feature maps into a vector sequence
transpose = paddle.transpose(conv_group_output, perm=[0, 3, 1, 2])
reshape_conv_output = paddle.reshape(x=transpose, shape=[0, -1, conv_group_height * conv_group_num_channels])
    # Remove the padded part
    seq_len_data = paddle.reshape(seq_len_data, [-1])
    sequence = nn.sequence_unpad(x=reshape_conv_output, length=seq_len_data)
    # RNN group
rnn_group_output = rnn_group(input=sequence, size=rnn_size, num_stacks=num_rnn_layers)
fc = nn.fc(x=rnn_group_output, size=dict_size, weight_attr=paddle.ParamAttr(), bias_attr=paddle.ParamAttr())
    # Output probability distribution
    log_probs = paddle.nn.functional.softmax(fc)
    if text_data is None:
        return log_probs, None
    else:
        # Compute the CTC loss
ctc_loss = paddle.nn.functional.ctc_loss(log_probs=fc, labels=text_data, blank=blank, norm_by_times=True,
reduction='sum', input_lengths=None, label_lengths=None)
ctc_loss = paddle.sum(ctc_loss)
return log_probs, ctc_loss
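# Hedged sketch (an illustration, not part of the original code): the time-axis length reduction
# applied by the first convolution stack, mirroring the formula in conv_group above
# (filter 11, stride 3, padding 5 on the time axis).
def conv_out_len(seq_len, filter_size=11, stride=3, padding=5):
    return (seq_len - filter_size + 2 * padding) // stride + 1
# e.g. conv_out_len(300) == 100, which matches the (seq_len - 1) // 3 + 1 shortcut
# used by DeepSpeech2Model.infer_batch_data.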
# pip install -r requirements.txt
numpy~=1.21.2
scipy==1.6.1
tqdm==4.59.0
pytest-runner
librosa==0.8.0
python-Levenshtein==0.12.2
visualdl~=2.2.0
SoundFile~=0.10.3.post1
resampy==0.2.2
webrtcvad==2.0.10
python_speech_features==0.6
cn2an==0.5.11
zhconv==1.4.2
pyaudio~=0.2.11
flask~=1.1.2
flask-cors
lac
body {
padding-top: 100px;
background-color: #008e8e;
}
#content {
width: 900px;
height: 500px;
margin: 0 auto;
background-color: #00921d;
position: relative;
}
#player {
bottom: 50px;
left: 300px;
position: absolute;
}
#result {
background-color: white;
width: 300px;
height: 250px;
top: 100px;
right: 100px;
position: absolute;
}
#result_p{
width: 300px;
height: 250px;
}
#upload {
width: 80px;
height: 30px;
top: 100px;
left: 130px;
position: absolute;
}
#upload_long {
width: 80px;
height: 30px;
top: 100px;
left: 300px;
position: absolute;
}
#btn {
top: 180px;
left: 150px;
position: relative;
}
.file {
position: relative;
text-align: center;
display: inline-block;
background: #D0EEFF;
border: 1px solid #99D3F5;
border-radius: 4px;
padding: 8px 40px;
overflow: hidden;
color: #1E88C7;
text-decoration: none;
text-indent: 0;
line-height: 20px;
}
#upload_recod_btn {
top: 290px;
width: 200px;
height: 30px;
left: 155px;
margin-top: 10px;
position: absolute;
}
.file input {
position: absolute;
font-size: 100px;
right: 0;
top: 0;
opacity: 0;
}
.file:hover {
background: #AADFFD;
border-color: #78C3F3;
color: #004974;
text-decoration: none;
}
#record_btn {
position: relative;
width: 80px;
height: 80px;
left: 150px;
top: 185px;
}
#play_btn {
position: relative;
width: 60px;
height: 60px;
left: 250px;
top: 178px;
}
// Compatibility shims
window.URL = window.URL || window.webkitURL;
// Get the user's media devices: camera or microphone
navigator.getUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia || navigator.mozGetUserMedia || navigator.msGetUserMedia;
var HZRecorder = function (stream, config) {
config = config || {};
    config.sampleBits = config.sampleBits || 16; // sample size in bits: 8 or 16
    config.sampleRate = config.sampleRate || 16000; // sample rate: 16000
    // Create an audio context
    var audioContext = window.AudioContext || window.webkitAudioContext;
    var context = new audioContext();
    var audioInput = context.createMediaStreamSource(stream);
    // The second and third arguments are the input/output channel counts (1 = mono, 2 = stereo);
    // only channel 0 is captured below, so the recording is effectively mono.
    var recorder = context.createScriptProcessor(4096, 2, 2);
    var audioData = {
        size: 0 // length of the recording
        , buffer: [] // recording buffer
        , inputSampleRate: context.sampleRate // input sample rate
        , inputSampleBits: 16 // input sample size in bits: 8 or 16
        , outputSampleRate: config.sampleRate // output sample rate
        , outputSampleBits: config.sampleBits // output sample size in bits: 8 or 16
        , input: function (data) {
            this.buffer.push(new Float32Array(data));
            this.size += data.length;
        }
        , compress: function () { // merge and downsample
            // merge the buffered chunks
            var data = new Float32Array(this.size);
            var offset = 0;
            for (var i = 0; i < this.buffer.length; i++) {
                data.set(this.buffer[i], offset);
                offset += this.buffer[i].length;
            }
            // downsample by keeping every n-th sample
            var compression = parseInt(this.inputSampleRate / this.outputSampleRate);
var length = data.length / compression;
var result = new Float32Array(length);
var index = 0, j = 0;
while (index < length) {
result[index] = data[j];
j += compression;
index++;
}
return result;
}
, encodeWAV: function () {
var sampleRate = Math.min(this.inputSampleRate, this.outputSampleRate);
var sampleBits = Math.min(this.inputSampleBits, this.outputSampleBits);
var bytes = this.compress();
var dataLength = bytes.length * (sampleBits / 8);
var buffer = new ArrayBuffer(44 + dataLength);
var data = new DataView(buffer);
            var channelCount = 1; // mono
var offset = 0;
var writeString = function (str) {
for (var i = 0; i < str.length; i++) {
data.setUint8(offset + i, str.charCodeAt(i));
}
}
            // RIFF resource exchange file identifier
            writeString('RIFF'); offset += 4;
            // Total bytes from the next address to the end of the file, i.e. file size - 8
            data.setUint32(offset, 36 + dataLength, true); offset += 4;
            // WAV file flag
            writeString('WAVE'); offset += 4;
            // Format chunk flag
            writeString('fmt '); offset += 4;
            // Format chunk size, normally 0x10 = 16
            data.setUint32(offset, 16, true); offset += 4;
            // Format category (PCM sample data)
            data.setUint16(offset, 1, true); offset += 2;
            // Number of channels
            data.setUint16(offset, channelCount, true); offset += 2;
            // Sample rate: samples per second per channel
            data.setUint32(offset, sampleRate, true); offset += 4;
            // Byte rate (average bytes per second): channels x sample rate x bits per sample / 8
            data.setUint32(offset, channelCount * sampleRate * (sampleBits / 8), true); offset += 4;
            // Block align (bytes per sample frame): channels x bits per sample / 8
            data.setUint16(offset, channelCount * (sampleBits / 8), true); offset += 2;
            // Bits per sample
            data.setUint16(offset, sampleBits, true); offset += 2;
            // Data chunk identifier
            writeString('data'); offset += 4;
            // Total size of the sample data, i.e. total size - 44
            data.setUint32(offset, dataLength, true); offset += 4;
            // Write the sample data
if (sampleBits === 8) {
for (var i = 0; i < bytes.length; i++, offset++) {
var s = Math.max(-1, Math.min(1, bytes[i]));
var val = s < 0 ? s * 0x8000 : s * 0x7FFF;
val = parseInt(255 / (65535 / (val + 32768)));
data.setInt8(offset, val, true);
}
} else {
for (var i = 0; i < bytes.length; i++, offset += 2) {
var s = Math.max(-1, Math.min(1, bytes[i]));
data.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
}
}
return new Blob([data], { type: 'audio/wav' });
}
};
    // Start recording
this.start = function () {
audioInput.connect(recorder);
recorder.connect(context.destination);
}
    // Stop recording
    this.stop = function () {
        recorder.disconnect();
    }
    // Get the recording as a WAV blob
this.getBlob = function () {
this.stop();
return audioData.encodeWAV();
}
    // Play back
    this.play = function (audio) {
        audio.src = window.URL.createObjectURL(this.getBlob());
    }
    // Clear the recording buffer
    this.clear = function () {
        audioData.buffer = [];
        audioData.size = 0;
    }
    // Upload
    this.upload = function (url, callback) {
        var fd = new FormData();
        // Field name and data for the upload
fd.append("audio", this.getBlob());
var xhr = new XMLHttpRequest();
xhr.timeout = 60000
if (callback) {
xhr.upload.addEventListener("progress", function (e) {
callback('uploading', e);
}, false);
xhr.addEventListener("load", function (e) {
callback('ok', e);
}, false);
xhr.addEventListener("error", function (e) {
callback('error', e);
}, false);
xhr.addEventListener("abort", function (e) {
callback('cancel', e);
}, false);
}
xhr.open("POST", url);
xhr.send(fd);
}
    // Audio capture callback
    recorder.onaudioprocess = function (e) {
        audioData.input(e.inputBuffer.getChannelData(0));
        //record(e.inputBuffer.getChannelData(0));
}
};
// Throw an error
HZRecorder.throwError = function (message) {
alert(message);
throw new function () { this.toString = function () { return message; } }
}
// Whether recording is supported
HZRecorder.canRecording = (navigator.getUserMedia != null);
// Get a recorder instance
HZRecorder.get = function (callback, config) {
if (callback) {
if (navigator.getUserMedia) {
navigator.getUserMedia(
                { audio: true } // audio only
, function (stream) {
var rec = new HZRecorder(stream, config);
callback(rec);
}
, function (error) {
                    switch (error.code || error.name) {
                        case 'PERMISSION_DENIED':
                        case 'PermissionDeniedError':
                            HZRecorder.throwError('The user denied the permission request.');
                            break;
                        case 'NOT_SUPPORTED_ERROR':
                        case 'NotSupportedError':
                            HZRecorder.throwError('The browser does not support the hardware device.');
                            break;
                        case 'MANDATORY_UNSATISFIED_ERROR':
                        case 'MandatoryUnsatisfiedError':
                            HZRecorder.throwError('The specified hardware device could not be found.');
                            break;
                        default:
                            HZRecorder.throwError('Could not open the microphone. Error: ' + (error.code || error.name));
                            break;
                    }
});
        } else {
            HZRecorder.throwError('This browser does not support recording.'); return;
        }
}
};
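// Hedged usage sketch: record from the microphone, then upload to the server's /recognition endpoint.
// HZRecorder.get(function (rec) {
//     rec.start();
//     // ... later ...
//     rec.upload('/recognition', function (state, e) { console.log(state); });
// });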
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
    <title>Speech Recognition - 夜雨飘零</title>
<script type="text/javascript" src="record.js"></script>
<link href="index.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<div id="content">
<div>
<a id="upload" onclick="uploadAudioFile()" class="file">短音频文件识别</a>
<a id="upload_long" onclick="uploadLongAudioFile()" class="file">长音频文件识别</a>
<img id="record_btn" onclick="record()" src="record.png" alt="录音"/>
<img id="play_btn" onclick="play()" src="player.png" alt="播放"/>
<a onclick="uploadRecordAudio()" class="file" id="upload_recod_btn">上传录音文件</a>
</div>
<div id="result">
<textarea id="result_p"></textarea>
</div>
<div id="player">
<audio controls autoplay></audio>
</div>
</div>
<script>
var is_recording = false;
var is_playing = false;
var recorder;
var audio = document.querySelector('audio');
function record() {
if (is_recording) {
is_recording = false;
stopRecording()
document.getElementById('record_btn').src = 'record.png'
} else {
is_recording = true;
startRecording()
document.getElementById('record_btn').src = 'recording.gif'
}
}
function play() {
if (is_playing) {
is_playing = false;
stopPlay()
document.getElementById('play_btn').src = 'player.png'
} else {
is_playing = true;
startPlay()
document.getElementById('play_btn').src = 'stop.png'
}
}
function startRecording() {
HZRecorder.get(function (rec) {
recorder = rec;
recorder.start();
});
}
function stopRecording() {
recorder.stop();
}
function startPlay() {
recorder.play(audio);
}
function stopPlay() {
audio.pause();
}
function cancelAudio() {
recorder.stop();
recorder.clear();
}
function uploadRecordAudio() {
recorder.upload(location.origin + "/recognition", function (state, e) {
switch (state) {
case 'uploading':
var percentComplete = Math.round(e.loaded * 100 / e.total) + '%';
console.log(percentComplete);
break;
case 'ok':
console.log(e.target.responseText)
document.getElementById('result_p').innerHTML = e.target.responseText
break;
                    case 'error':
                        alert("Upload failed");
                        break;
                    case 'cancel':
                        alert("Upload cancelled");
break;
}
});
}
function uploadAudioFile(){
var input = document.createElement("input");
input.type = "file";
input.accept = "audio/*";
input.click();
input.onchange = function(){
var file = input.files[0];
upload_file(location.origin + "/recognition", file, function (state, e) {
switch (state) {
case 'uploading':
var percentComplete = Math.round(e.loaded * 100 / e.total) + '%';
console.log(percentComplete);
break;
case 'ok':
console.log(e.target.responseText)
document.getElementById('result_p').innerHTML = e.target.responseText
break;
                    case 'error':
                        alert("Upload failed");
                        break;
                    case 'cancel':
                        alert("Upload cancelled");
break;
}
});
}
}
function uploadLongAudioFile(){
var input = document.createElement("input");
input.type = "file";
input.accept = "audio/*";
input.click();
input.onchange = function(){
var file = input.files[0];
upload_file(location.origin + "/recognition_long_audio", file, function (state, e) {
switch (state) {
case 'uploading':
var percentComplete = Math.round(e.loaded * 100 / e.total) + '%';
console.log(percentComplete);
break;
case 'ok':
console.log(e.target.responseText)
document.getElementById('result_p').innerHTML = e.target.responseText
break;
                    case 'error':
                        alert("Upload failed");
                        break;
                    case 'cancel':
                        alert("Upload cancelled");
break;
}
});
}
}
        // Upload an audio file
        upload_file = function (url, file, callback) {
            var fd = new FormData();
            // Field name and data for the upload
fd.append("audio", file);
var xhr = new XMLHttpRequest();
xhr.timeout = 60000
if (callback) {
xhr.upload.addEventListener("progress", function (e) {
callback('uploading', e);
}, false);
xhr.addEventListener("load", function (e) {
callback('ok', e);
}, false);
xhr.addEventListener("error", function (e) {
callback('error', e);
}, false);
xhr.addEventListener("abort", function (e) {
callback('cancel', e);
}, false);
}
xhr.open("POST", url);
xhr.send(fd);
}
</script>
</body>
</html>
import argparse
import os
import shutil
import ijson
from pydub import AudioSegment
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--wenetspeech_json', type=str, default='/media/wenetspeech/WenetSpeech.json', help="Path to the WenetSpeech annotation JSON file")
parser.add_argument('--annotation_dir', type=str, default='../dataset/annotation/', help="Directory for the annotation list files")
args = parser.parse_args()
def process_wenetspeech(wenetspeech_json, annotation_dir):
input_dir = os.path.dirname(wenetspeech_json)
if not os.path.exists(annotation_dir):
os.makedirs(annotation_dir)
    # Mode 'a' creates the file if it does not exist, so no existence check is needed
    f_ann = open(os.path.join(annotation_dir, 'wenetspeech.txt'), 'a', encoding='utf-8')
    f_ann_test = open(os.path.join(annotation_dir, 'test.txt'), 'a', encoding='utf-8')
with open(wenetspeech_json, 'r', encoding='utf-8') as f:
objects = ijson.items(f, 'audios.item')
        while True:
            try:
                long_audio = next(objects)
try:
long_audio_path = os.path.realpath(os.path.join(input_dir, long_audio['path']))
aid = long_audio['aid']
segments_lists = long_audio['segments']
assert (os.path.exists(long_audio_path))
except AssertionError:
print(f'''Warning: {long_audio_path} 不存在或者已经处理过自动删除了,跳过''')
continue
except Exception:
print(f'''Warning: {aid} 数据读取错误,跳过''')
continue
else:
print(f'正在处理{long_audio_path}音频')
save_dir = long_audio_path[:-5]
os.makedirs(save_dir, exist_ok=True)
source_wav = AudioSegment.from_file(long_audio_path)
for segment_file in segments_lists:
try:
sid = segment_file['sid']
start_time = segment_file['begin_time']
end_time = segment_file['end_time']
text = segment_file['text']
confidence = segment_file['confidence']
if confidence < 0.95: continue
except Exception:
print(f'''Warning: {segment_file} something is wrong, skipped''')
continue
else:
start = int(start_time * 1000)
end = int(end_time * 1000)
target_audio = source_wav[start:end].set_frame_rate(16000)
save_audio_path = os.path.join(save_dir, sid.split('_')[-1] + '.wav')
target_audio.export(save_audio_path, format="wav")
if long_audio['path'].split('/')[1] != 'train':
f_ann_test.write('%s\t%s\n' % (save_audio_path, text))
else:
f_ann.write('%s\t%s\n' % (save_audio_path, text))
# Delete the long audio that has been processed
os.remove(long_audio_path)
except StopIteration:
print("数据读取完成")
break
shutil.copy(os.path.join(annotation_dir, 'wenetspeech.txt'), os.path.join(input_dir, 'wenetspeech.txt'))
shutil.copy(os.path.join(annotation_dir, 'test.txt'), os.path.join(input_dir, 'test.txt'))
if __name__ == '__main__':
process_wenetspeech(wenetspeech_json=args.wenetspeech_json, annotation_dir=args.annotation_dir)
\ No newline at end of file
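For reference, a minimal sketch of the millisecond-based pydub slicing used above (the file names are placeholders):

from pydub import AudioSegment

# AudioSegment indexing is in milliseconds, which is why begin_time/end_time
# are multiplied by 1000 before slicing.
source = AudioSegment.from_file('example_long_audio.opus')  # placeholder path
clip = source[1250:2500].set_frame_rate(16000)  # 1.25 s .. 2.50 s, resampled to 16 kHz
clip.export('example_clip.wav', format='wav')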
import re
from typing import Dict, List, Tuple
import numpy as np
import paddle
from parakeet.frontend.zh_frontend import Frontend as cnFrontend
class Frontend():
def __init__(self, phone_vocab_path=None, tone_vocab_path=None):
self.frontend = cnFrontend()
self.vocab_phones = {}
self.vocab_tones = {}
if phone_vocab_path:
with open(phone_vocab_path, 'rt', encoding='utf-8') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
self.vocab_phones[phn] = int(id)
if tone_vocab_path:
with open(tone_vocab_path, 'rt', encoding='utf-8') as f:
tone_id = [line.strip().split() for line in f.readlines()]
for tone, id in tone_id:
self.vocab_tones[tone] = int(id)
def _p2id(self, phonemes: List[str]) -> np.ndarray:
# replace unk phone with sp
phonemes = [
phn if phn in self.vocab_phones else "sp" for phn in phonemes
]
phone_ids = [self.vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64)
def _t2id(self, tones: List[str]) -> np.ndarray:
# replace unk tone with the neutral tone "0"
tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
tone_ids = [self.vocab_tones[item] for item in tones]
return np.array(tone_ids, np.int64)
def _get_phone_tone(self, phonemes: List[str],
get_tone_ids: bool=False) -> Tuple[List[str], List[str]]:
phones = []
tones = []
if get_tone_ids and self.vocab_tones:
for full_phone in phonemes:
# split tone from finals
match = re.match(r'^(\w+)([012345])$', full_phone)
if match:
phone = match.group(1)
tone = match.group(2)
# If the merged erhua final is not in the vocab:
# e.g. for input ['iaor3'] with 'iaor' not in self.vocab_phones, split 'iaor' into ['iao', 'er']
# and extend the tones from ['3'] to ['3', '2'], '2' being the tone of 'er2'
if (len(phone) >= 2 and phone != "er" and phone[-1] == 'r'
and phone not in self.vocab_phones and phone[:-1] in self.vocab_phones):
phones.append(phone[:-1])
phones.append("er")
tones.append(tone)
tones.append("2")
else:
phones.append(phone)
tones.append(tone)
else:
phones.append(full_phone)
tones.append('0')
else:
for phone in phonemes:
# If the merged erhua final is not in the vocab:
# e.g. for input ['iaor3'] with 'iaor' not in self.vocab_phones, change ['iaor3'] to ['iao3', 'er2']
if (len(phone) >= 3 and phone[:-1] != "er" and phone[-2] == 'r'
and phone not in self.vocab_phones and (phone[:-2] + phone[-1]) in self.vocab_phones):
phones.append((phone[:-2] + phone[-1]))
phones.append("er2")
else:
phones.append(phone)
return phones, tones
def get_input_ids(
self,
sentence: str,
merge_sentences: bool=True,
get_tone_ids: bool=False) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.frontend.get_phonemes(
sentence, merge_sentences=merge_sentences)
result = {}
phones = []
tones = []
temp_phone_ids = []
temp_tone_ids = []
for part_phonemes in phonemes:
phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids)
if tones:
tone_ids = self._t2id(tones)
tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids)
if phones:
phone_ids = self._p2id(phones)
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
if temp_tone_ids:
result["tone_ids"] = temp_tone_ids
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result
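To make the erhua fallback above concrete, here is a small self-contained sketch of the split performed in _get_phone_tone, using a toy vocabulary rather than the real phone_id_map:

# Toy vocabulary, for illustration only
toy_vocab_phones = {"iao", "er", "sp"}

def split_erhua(phone, tone):
    # 'iaor' is not in the vocab but 'iao' is, so ('iaor', '3') becomes
    # (['iao', 'er'], ['3', '2']), '2' being the tone of 'er2'
    if (len(phone) >= 2 and phone != "er" and phone[-1] == 'r'
            and phone not in toy_vocab_phones and phone[:-1] in toy_vocab_phones):
        return [phone[:-1], "er"], [tone, "2"]
    return [phone], [tone]

print(split_erhua("iaor", "3"))  # (['iao', 'er'], ['3', '2'])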
import argparse
import logging
import os
import random
from pathlib import Path
import numpy as np
import paddle
import soundfile
import yaml
from tqdm import tqdm
from yacs.config import CfgNode
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.modules.normalizer import ZScore
from frontend import Frontend
def generate(args, fastspeech2_config, pwg_config):
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for generate
sentences = []
with open(args.text, 'rt', encoding='utf-8') as f:
for line in f:
utt_id, sentence = line.strip().split()
sentences.append((utt_id, sentence))
with open(args.phones_dict, "r", encoding='utf-8') as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
with open(args.speaker_dict, 'rt', encoding='utf-8') as f:
spk_id = [line.strip().split() for line in f.readlines()]
num_speakers = len(spk_id)
print("num_speakers:", num_speakers)
odim = fastspeech2_config.n_mels
model = FastSpeech2(idim=vocab_size,
odim=odim,
num_speakers=num_speakers,
**fastspeech2_config["model"])
model.set_state_dict(paddle.load(args.fastspeech2_checkpoint)["main_params"])
model.eval()
vocoder = PWGGenerator(**pwg_config["generator_params"])
vocoder.set_state_dict(paddle.load(args.pwg_params))
vocoder.remove_weight_norm()
vocoder.eval()
print("model done!")
frontend = Frontend(args.phones_dict)
print("frontend done!")
stat = np.load(args.fastspeech2_stat)
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
fastspeech2_normalizer = ZScore(mu, std)
stat = np.load(args.pwg_stat)
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
pwg_normalizer = ZScore(mu, std)
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
pwg_inference = PWGInference(pwg_normalizer, vocoder)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
annotation_path = Path(os.path.dirname(args.annotation_path))
annotation_path.mkdir(parents=True, exist_ok=True)
# Resume from where a previous run stopped, based on the existing annotation file
start_num = 0
if os.path.exists(args.annotation_path):
with open(args.annotation_path, 'r', encoding='utf-8') as f_ann:
start_num = len(f_ann.readlines())
f_ann = open(args.annotation_path, 'a', encoding='utf-8')
# Start generating audio
for i in tqdm(range(start_num, len(sentences))):
utt_id, sentence = sentences[i]
# Pick a random speaker id
spk_id = random.randint(0, num_speakers - 1)
try:
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
except Exception:
# Skip sentences the frontend cannot handle
continue
phone_ids = input_ids["phone_ids"]
flags = 0
for part_phone_ids in phone_ids:
with paddle.no_grad():
mel = fastspeech2_inference(part_phone_ids, spk_id=paddle.to_tensor(spk_id))
temp_wav = pwg_inference(mel)
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
save_audio_path = str(output_dir / (utt_id + ".wav"))
soundfile.write(save_audio_path, wav.numpy(), samplerate=fastspeech2_config.fs)
# [6:] strips the leading '../../' of the default output_dir; punctuation is removed from the transcript
f_ann.write('%s\t%s\n' % (save_audio_path[6:].replace('\\', '/'), sentence.replace('。', '').replace(',', '')
.replace('!', '').replace('?', '')))
f_ann.flush()
def main():
parser = argparse.ArgumentParser(description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument("--fastspeech2-config",
type=str,
default='models/fastspeech2_nosil_aishell3_ckpt_0.4/default.yaml',
help="fastspeech2 config file to overwrite default config.")
parser.add_argument("--fastspeech2-checkpoint",
type=str,
default='models/fastspeech2_nosil_aishell3_ckpt_0.4/snapshot_iter_96400.pdz',
help="fastspeech2 checkpoint to load.")
parser.add_argument("--fastspeech2-stat",
type=str,
default='models/fastspeech2_nosil_aishell3_ckpt_0.4/speech_stats.npy',
help="mean and standard deviation used to normalize spectrogram when training fastspeech2.")
parser.add_argument("--pwg-config",
type=str,
default='models/parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml',
help="parallel wavegan config file to overwrite default config.")
parser.add_argument("--pwg-params",
type=str,
default='models/parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams',
help="parallel wavegan generator parameters to load.")
parser.add_argument("--pwg-stat",
type=str,
default='models/parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy',
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan.")
parser.add_argument("--phones-dict",
type=str,
default="models/fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt",
help="phone vocabulary file.")
parser.add_argument("--speaker-dict",
type=str,
default="models/fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt",
help="speaker id map file.")
parser.add_argument("--text",
type=str,
default='corpus.txt',
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output_dir", type=str, default='../../dataset/audio/generate', help="output audio dir.")
parser.add_argument("--annotation_path", type=str, default='../../dataset/annotation/generate.txt',
help="audio annotation path.")
parser.add_argument("--device", type=str, default="gpu", help="device type to use.")
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
args = parser.parse_args()
with open(args.fastspeech2_config) as f:
fastspeech2_config = CfgNode(yaml.safe_load(f))
with open(args.pwg_config) as f:
pwg_config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(fastspeech2_config)
print(pwg_config)
generate(args, fastspeech2_config, pwg_config)
if __name__ == "__main__":
main()
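The ZScore objects built in generate() wrap the saved mean/std statistics; a minimal numpy sketch of the z-score transform they apply (and its inverse) is:

import numpy as np

mu, std = np.array([0.5]), np.array([2.0])  # illustrative statistics
mel = np.array([1.5])
normalized = (mel - mu) / std     # -> [0.5]
restored = normalized * std + mu  # -> [1.5]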
import os
import re
import cn2an
# Check whether a string consists solely of Chinese characters
def is_uchar(in_str):
return all(u'\u4e00' <= ch <= u'\u9fa5' for ch in in_str)
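# e.g. is_uchar('你好') -> True; is_uchar('你好a') -> False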
# Build the Chinese corpus
utt_id = 0
corpus_dir = 'dgk_lost_conv/results/'
with open('corpus.txt', 'w', encoding='utf-8') as f_write:
for corpus_path in os.listdir(corpus_dir):
if corpus_path[-5:] != '.conv': continue
corpus_path = os.path.join(corpus_dir, corpus_path)
print(corpus_path)
if 'dgk_shooter_z.conv' in corpus_path:
lines = []
with open(corpus_path, 'r', encoding='utf-8') as f:
while True:
try:
line = f.readline().replace('\n', '')
lines.append(line)
if len(line) == 0: break
except Exception:
# Skip lines that fail to decode
continue
else:
with open(corpus_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
for line in lines:
line = line[2:].replace('/', '').replace('\n', '').replace('?', '?').replace(' ', '%').replace('.', '')
line = line.replace('~', '!').replace(',', ',').replace('、', ',').replace('!', '!').replace('"', '')
# Collapse runs of repeated punctuation marks into one
line = re.sub(r'([,。!?])\1+', r'\1', line)
line = cn2an.transform(line, "an2cn")
if len(line) < 2: continue
if not is_uchar(line.replace(',', '').replace('。', '').replace('?', '').replace('!', '')): continue
# Skip lines that still contain Latin letters or digits
if re.search(r'[A-Za-z0-9]', line): continue
f_write.write('%d %s\n' % (utt_id, line))
utt_id += 1
\ No newline at end of file
"""查找最优的集束搜索方法的alpha参数和beta参数"""
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import numpy as np
import argparse
import functools
import paddle
from tqdm import tqdm
from decoders.beam_search_decoder import BeamSearchDecoder
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.error_rate import char_errors, word_errors
from utils.utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('num_batches', int, -1, "Number of batches used for evaluation; -1 means use all data")
add_arg('batch_size', int, 64, "Batch size for evaluation")
add_arg('beam_size', int, 300, "Beam size, in the range [5, 500]")
add_arg('num_proc_bsearch', int, 10, "Number of CPU processes used for beam search")
add_arg('num_conv_layers', int, 2, "Number of convolution layers")
add_arg('num_rnn_layers', int, 3, "Number of recurrent layers")
add_arg('rnn_layer_size', int, 1024, "Size of the recurrent layers")
add_arg('num_alphas', int, 45, "Number of alpha candidates for tuning")
add_arg('num_betas', int, 8, "Number of beta candidates for tuning")
add_arg('alpha_from', float, 1.0, "Start value of the alpha sweep")
add_arg('alpha_to', float, 0.45, "End value of the alpha sweep")
add_arg('beta_from', float, 0.1, "Start value of the beta sweep")
add_arg('beta_to', float, 0.35, "End value of the beta sweep")
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning")
add_arg('use_gpu', bool, True, "Whether to use the GPU")
add_arg('tune_manifest', str, 'dataset/manifest.test', "Manifest of the test data used for tuning")
add_arg('mean_std_path', str, 'dataset/mean_std.npz', "Path to the npz file with the dataset mean and std")
add_arg('vocab_path', str, 'dataset/zh_vocab.txt', "Path to the dataset vocabulary file")
add_arg('lang_model_path', str, 'lm/zh_giga.no_cna_cmn.prune01244.klm', "Path to the language model file")
add_arg('model_path', str, 'models/param/50.pdparams', "Path to the trained model parameters")
add_arg('error_rate_type', str, 'cer', "Error rate metric for evaluation: character error rate (cer) or word error rate (wer)", choices=['wer', 'cer'])
args = parser.parse_args()
def tune():
# Sanity-check the numbers of alpha and beta candidates
if not args.num_alphas >= 0:
raise ValueError("num_alphas must be non-negative!")
if not args.num_betas >= 0:
raise ValueError("num_betas must be non-negative!")
# Whether to use the GPU
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
# Create the data generator
data_generator = DataGenerator(vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
keep_transcription_text=True,
place=place,
is_training=False)
# Create the evaluation data reader
batch_reader = data_generator.batch_reader_creator(manifest_path=args.tune_manifest,
batch_size=args.batch_size,
shuffle_method=None)
# Build the DeepSpeech2 model and set it to inference mode
ds2_model = DeepSpeech2Model(vocab_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
place=place,
resume_model=args.model_path)
# Select the error metric: character error rate (cer) or word error rate (wer)
errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
# Build the grid of alpha and beta candidates
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas]
err_sum = [0.0 for _ in range(len(params_grid))]
err_ave = [0.0 for _ in range(len(params_grid))]
num_ins, len_refs, cur_batch = 0, 0, 0
# Tune the parameters incrementally over multiple batches
ds2_model.logger.info("start tuning ...")
for infer_data in batch_reader():
if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
break
# Run inference
probs_split = ds2_model.infer_batch_data(infer_data=infer_data)
target_transcripts = infer_data[1]
num_ins += len(target_transcripts)
# Grid-search the alpha and beta parameters
for index, (alpha, beta) in enumerate(tqdm(params_grid)):
# Initialize the beam search decoder
beam_search_decoder = BeamSearchDecoder(alpha, beta, args.lang_model_path, data_generator.vocab_list)
result_transcripts = beam_search_decoder.decode_batch_beam_search(probs_split=probs_split,
beam_alpha=alpha,
beam_beta=beta,
beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob,
cutoff_top_n=args.cutoff_top_n,
vocab_list=data_generator.vocab_list,
num_processes=args.num_proc_bsearch)
for target, result in zip(target_transcripts, result_transcripts):
errors, len_ref = errors_func(target, result)
err_sum[index] += errors
if args.alpha_from == alpha and args.beta_from == beta:
len_refs += len_ref
err_ave[index] = err_sum[index] / len_refs
# Report the best (alpha, beta) so far after each batch
err_ave_min = min(err_ave)
min_index = err_ave.index(err_ave_min)
print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
" min [%s] = %f" % (cur_batch, num_ins, "%.3f" % params_grid[min_index][0],
"%.3f" % params_grid[min_index][1], args.error_rate_type, err_ave_min))
cur_batch += 1
# Print the error rate for every (alpha, beta) pair
print("\nFinal %s:\n" % args.error_rate_type)
for index in range(len(params_grid)):
print("(alpha, beta) = (%s, %s), [%s] = %f"
% ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1], args.error_rate_type, err_ave[index]))
err_ave_min = min(err_ave)
min_index = err_ave.index(err_ave_min)
print("\n一共使用了 %d 批数据推理, 最优的参数为 (alpha, beta) = (%s, %s)"
% (cur_batch, "%.3f" % params_grid[min_index][0], "%.3f" % params_grid[min_index][1]))
ds2_model.logger.info("finish tuning")
def main():
print_arguments(args)
tune()
if __name__ == '__main__':
main()
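As a quick illustration of the candidate grid built in tune() above (using tiny counts instead of the default 45 alphas x 8 betas):

import numpy as np

cand_alphas = np.linspace(1.0, 0.45, 3)  # -> [1.0, 0.725, 0.45]
cand_betas = np.linspace(0.1, 0.35, 2)   # -> [0.1, 0.35]
params_grid = [(a, b) for a in cand_alphas for b in cand_betas]
# 6 (alpha, beta) candidates, swept from (1.0, 0.1) to (0.45, 0.35)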
import argparse
import functools
import io
import os
from datetime import datetime
from model_utils.model import DeepSpeech2Model
from data_utils.data import DataGenerator
from utils.utility import add_arguments, print_arguments, get_data_len
import paddle
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use the GPU")
add_arg('batch_size', int, 16, "Batch size for training")
add_arg('num_epoch', int, 50, "Number of training epochs")
add_arg('num_conv_layers', int, 2, "Number of convolution layers")
add_arg('num_rnn_layers', int, 3, "Number of recurrent layers")
add_arg('rnn_layer_size', int, 1024, "Size of the recurrent layers")
add_arg('learning_rate', float, 5e-4, "Initial learning rate")
add_arg('min_duration', float, 0.5, "Minimum audio duration used for training, in seconds")
add_arg('max_duration', float, 20.0, "Maximum audio duration used for training, in seconds")
add_arg('test_off', bool, False, "Whether to skip evaluation during training")
add_arg('resume_model', str, None, "Checkpoint to resume training from; None starts from scratch")
add_arg('pretrained_model', str, None, "Path to a pretrained model; None means no pretrained model")
add_arg('train_manifest', str, './dataset/manifest.train', "Manifest of the training data")
add_arg('test_manifest', str, './dataset/manifest.test', "Manifest of the test data")
add_arg('mean_std_path', str, './dataset/mean_std.npz', "Path to the npz file with the dataset mean and std")
add_arg('vocab_path', str, './dataset/zh_vocab.txt', "Path to the dataset vocabulary file")
add_arg('output_model_dir', str, './models/param', "Directory to save trained models")
add_arg('augment_conf_path', str, './conf/augmentation.json', "Data augmentation config file, in JSON format")
add_arg('shuffle_method', str, 'batch_shuffle_clipped', "Method used to shuffle the data", choices=['instance_shuffle', 'batch_shuffle', 'batch_shuffle_clipped'])
args = parser.parse_args()
# Train the model
def train():
# Whether to use the GPU
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
# Create the training data generator
augmentation_config = io.open(args.augment_conf_path, mode='r', encoding='utf8').read() if args.augment_conf_path is not None else '{}'
train_generator = DataGenerator(vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
augmentation_config=augmentation_config,
max_duration=args.max_duration,
min_duration=args.min_duration,
place=place)
if args.resume_model:
try:
# Recover the starting epoch from the checkpoint file name
pre_epoch = os.path.basename(args.resume_model).split('.')[0]
train_generator.epoch = int(pre_epoch)
except ValueError:
pass
# Create the test data generator
test_generator = DataGenerator(vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
keep_transcription_text=True,
place=place,
is_training=False)
# Create the training data reader
train_batch_reader = train_generator.batch_reader_creator(manifest_path=args.train_manifest,
batch_size=args.batch_size,
shuffle_method=args.shuffle_method)
# Create the test data reader
test_batch_reader = test_generator.batch_reader_creator(manifest_path=args.test_manifest,
batch_size=args.batch_size,
shuffle_method=None)
# Build the DeepSpeech2 model
ds2_model = DeepSpeech2Model(vocab_size=train_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
place=place,
pretrained_model=args.pretrained_model,
resume_model=args.resume_model,
output_model_dir=args.output_model_dir,
vocab_list=train_generator.vocab_list)
# Get the number of training samples
train_num_samples = get_data_len(args.train_manifest, args.max_duration, args.min_duration)
print("[%s] 训练数据数量:%d\n" % (datetime.now(), train_num_samples))
# Get the number of test samples
test_num_samples = get_data_len(args.test_manifest, args.max_duration, args.min_duration)
print("[%s] 测试数据数量:%d\n" % (datetime.now(), test_num_samples))
# Start training
ds2_model.train(train_batch_reader=train_batch_reader,
dev_batch_reader=test_batch_reader,
learning_rate=args.learning_rate,
gradient_clipping=400,
batch_size=args.batch_size,
train_num_samples=train_num_samples,
test_num_samples=test_num_samples,
num_epoch=args.num_epoch,
test_off=args.test_off)
def main():
print_arguments(args)
train()
if __name__ == '__main__':
main()
import collections
import contextlib
import os
import wave
import webrtcvad
def read_wave(path):
"""Reads a .wav file.
Takes the path, and returns (PCM audio data, sample rate).
"""
with contextlib.closing(wave.open(path, 'rb')) as wf:
num_channels = wf.getnchannels()
assert num_channels == 1
sample_width = wf.getsampwidth()
assert sample_width == 2
sample_rate = wf.getframerate()
assert sample_rate in (8000, 16000, 32000, 48000)
pcm_data = wf.readframes(wf.getnframes())
return pcm_data, sample_rate
def write_wave(path, audio, sample_rate):
"""Writes a .wav file.
Takes path, PCM audio data, and sample rate.
"""
with contextlib.closing(wave.open(path, 'wb')) as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio)
class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, bytes, timestamp, duration):
self.bytes = bytes
self.timestamp = timestamp
self.duration = duration
def frame_generator(frame_duration_ms, audio, sample_rate):
"""Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and the sample rate.
Yields Frames of the requested duration.
"""
n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
offset = 0
timestamp = 0.0
duration = (float(n) / sample_rate) / 2.0
while offset + n < len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += duration
offset += n
def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
"""Filters out non-voiced audio frames.
Given a webrtcvad.Vad and a source of audio frames, yields only
the voiced audio.
Uses a padded, sliding window algorithm over the audio frames.
When more than 90% of the frames in the window are voiced (as
reported by the VAD), the collector triggers and begins yielding
audio frames. Then the collector waits until 90% of the frames in
the window are unvoiced to detrigger.
The window is padded at the front and back to provide a small
amount of silence or the beginnings/endings of speech around the
voiced frames.
Arguments:
sample_rate - The audio sample rate, in Hz.
frame_duration_ms - The frame duration in milliseconds.
padding_duration_ms - The amount to pad the window, in milliseconds.
vad - An instance of webrtcvad.Vad.
frames - a source of audio frames (sequence or generator).
Returns: A generator that yields PCM audio data.
"""
num_padding_frames = int(padding_duration_ms / frame_duration_ms)
# We use a deque for our sliding window/ring buffer.
ring_buffer = collections.deque(maxlen=num_padding_frames)
# We have two states: TRIGGERED and NOTTRIGGERED. We start in the NOTTRIGGERED state.
triggered = False
voiced_frames = []
for i, frame in enumerate(frames):
is_speech = vad.is_speech(frame.bytes, sample_rate)
if not triggered:
ring_buffer.append((frame, is_speech))
num_voiced = len([f for f, speech in ring_buffer if speech])
# If we're NOTTRIGGERED and more than 90% of the frames in
# the ring buffer are voiced frames, then enter the
# TRIGGERED state.
if num_voiced > 0.9 * ring_buffer.maxlen:
triggered = True
# We want to yield all the audio we see from now until
# we are NOTTRIGGERED, but we have to start with the
# audio that's already in the ring buffer.
for f, s in ring_buffer:
voiced_frames.append(f)
ring_buffer.clear()
else:
# We're in the TRIGGERED state, so collect the audio data
# and add it to the ring buffer.
voiced_frames.append(frame)
ring_buffer.append((frame, is_speech))
num_unvoiced = len([f for f, speech in ring_buffer if not speech])
# If more than 90% of the frames in the ring buffer are
# unvoiced, then enter NOTTRIGGERED and yield whatever
# audio we've collected.
if num_unvoiced > 0.9 * ring_buffer.maxlen:
triggered = False
yield b''.join([f.bytes for f in voiced_frames]), voiced_frames[0].timestamp, voiced_frames[-1].timestamp
ring_buffer.clear()
voiced_frames = []
# If we have any leftover voiced audio when we run out of input, yield it.
if voiced_frames:
yield b''.join([f.bytes for f in voiced_frames]), voiced_frames[0].timestamp, voiced_frames[-1].timestamp
def crop_audio_vad(audio_path, aggressiveness=1, frame_duration_ms=30):
audio, sample_rate = read_wave(audio_path)
vad = webrtcvad.Vad(aggressiveness)
frames = frame_generator(frame_duration_ms, audio, sample_rate)
frames = list(frames)
segments = vad_collector(sample_rate, frame_duration_ms, 300, vad, frames)
audios_path = []
time_stamps = []
save_path = os.path.join(os.path.dirname(audio_path), 'crop_audio')
if not os.path.exists(save_path):
os.makedirs(save_path)
for i, segment in enumerate(segments):
path = os.path.join(save_path, '%s_%d.wav' % (os.path.basename(audio_path)[:-4], i))
write_wave(path, segment[0], sample_rate)
audios_path.append(path)
time_stamps.append(segment[1:])
return audios_path, time_stamps
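A hypothetical usage sketch of crop_audio_vad (the wav path is a placeholder; read_wave above requires mono 16-bit PCM at 8/16/32/48 kHz):

audio_paths, time_stamps = crop_audio_vad('example.wav', aggressiveness=2)
for path, (start, end) in zip(audio_paths, time_stamps):
    print('%s: %.2f s -> %.2f s' % (path, start, end))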
"""用于计算字错率和词错率"""
import numpy as np
def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference
between two sequences. Informally, the levenshtein disctance is defined as
the minimum number of single-character edits (substitutions, insertions or
deletions) required to change one word into the other. We can naturally
extend the edits to word level when calculate levenshtein disctance for
two sentences.
"""
m = len(ref)
n = len(hyp)
# special case
if ref == hyp:
return 0
if m == 0:
return n
if n == 0:
return m
if m < n:
ref, hyp = hyp, ref
m, n = n, m
# use O(min(m, n)) space
distance = np.zeros((2, n + 1), dtype=np.int32)
# initialize distance matrix
for j in range(n + 1):
distance[0][j] = j
# calculate levenshtein distance
for i in range(1, m + 1):
prev_row_idx = (i - 1) % 2
cur_row_idx = i % 2
distance[cur_row_idx][0] = i
for j in range(1, n + 1):
if ref[i - 1] == hyp[j - 1]:
distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
else:
s_num = distance[prev_row_idx][j - 1] + 1
i_num = distance[cur_row_idx][j - 1] + 1
d_num = distance[prev_row_idx][j] + 1
distance[cur_row_idx][j] = min(s_num, i_num, d_num)
return distance[m % 2][n]
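# e.g. _levenshtein_distance('kitten', 'sitting') == 3
# (substitute 'k'->'s', substitute 'e'->'i', insert 'g')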
def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
"""Compute the levenshtein distance between reference sequence and
hypothesis sequence in word-level.
:param reference: The reference sentence.
:type reference: str
:param hypothesis: The hypothesis sentence.
:type hypothesis: str
:param ignore_case: Whether to ignore case.
:type ignore_case: bool
:param delimiter: Delimiter of input sentences.
:type delimiter: char
:return: Levenshtein distance and word number of reference sentence.
:rtype: list
"""
if ignore_case:
reference = reference.lower()
hypothesis = hypothesis.lower()
ref_words = list(filter(None, reference.split(delimiter)))
hyp_words = list(filter(None, hypothesis.split(delimiter)))
edit_distance = _levenshtein_distance(ref_words, hyp_words)
return float(edit_distance), len(ref_words)
def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
"""Compute the levenshtein distance between reference sequence and
hypothesis sequence in char-level.
:param reference: The reference sentence.
:type reference: str
:param hypothesis: The hypothesis sentence.
:type hypothesis: str
:param ignore_case: Whether to ignore case.
:type ignore_case: bool
:param remove_space: Whether to remove internal space characters.
:type remove_space: bool
:return: Levenshtein distance and length of reference sentence.
:rtype: list
"""
if ignore_case:
reference = reference.lower()
hypothesis = hypothesis.lower()
join_char = ' '
if remove_space:
join_char = ''
reference = join_char.join(list(filter(None, reference.split(' '))))
hypothesis = join_char.join(list(filter(None, hypothesis.split(' '))))
edit_distance = _levenshtein_distance(reference, hypothesis)
return float(edit_distance), len(reference)
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
"""Calculate word error rate (WER). WER compares reference text and
hypothesis text in word-level. WER is defined as:
.. math::
WER = (Sw + Dw + Iw) / Nw
where
.. code-block:: text
Sw is the number of words substituted,
Dw is the number of words deleted,
Iw is the number of words inserted,
Nw is the number of words in the reference
We can use the Levenshtein distance to calculate WER. Note that empty
items will be removed when splitting sentences by the delimiter.
:param reference: The reference sentence.
:type reference: str
:param hypothesis: The hypothesis sentence.
:type hypothesis: str
:param ignore_case: Whether to ignore case.
:type ignore_case: bool
:param delimiter: Delimiter of input sentences.
:type delimiter: char
:return: Word error rate.
:rtype: float
:raises ValueError: If word number of reference is zero.
"""
edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
delimiter)
if ref_len == 0:
raise ValueError("Reference's word number should be greater than 0.")
wer = float(edit_distance) / ref_len
return wer
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
"""Calculate charactor error rate (CER). CER compares reference text and
hypothesis text in char-level. CER is defined as:
.. math::
CER = (Sc + Dc + Ic) / Nc
where
.. code-block:: text
Sc is the number of characters substituted,
Dc is the number of characters deleted,
Ic is the number of characters inserted
Nc is the number of characters in the reference
We can use the Levenshtein distance to calculate CER. Chinese input should
be encoded to unicode. Note that leading and trailing space characters
will be truncated and multiple consecutive space characters in a sentence
will be replaced by one space character.
:param reference: The reference sentence.
:type reference: str
:param hypothesis: The hypothesis sentence.
:type hypothesis: str
:param ignore_case: Whether to ignore case.
:type ignore_case: bool
:param remove_space: Whether to remove internal space characters.
:type remove_space: bool
:return: Character error rate.
:rtype: float
:raises ValueError: If the reference length is zero.
"""
edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
remove_space)
if ref_len == 0:
raise ValueError("Length of reference should be greater than 0.")
cer = float(edit_distance) / ref_len
return cer
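Two quick sanity checks of the metrics above:

print(wer('hello world', 'hello word'))  # 1 edit / 2 reference words -> 0.5
print(cer('hello', 'helo'))              # 1 edit / 5 reference chars -> 0.2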
import os
import sys
from LAC import LAC
import cn2an
import numpy as np
import paddle.inference as paddle_infer
from decoders.ctc_greedy_decoder import greedy_decoder
class Predictor:
def __init__(self, model_dir, audio_process, decoding_method='ctc_greedy', alpha=1.2, beta=0.35,
lang_model_path=None, beam_size=10, cutoff_prob=1.0, cutoff_top_n=40, use_gpu=True, gpu_mem=500,
enable_mkldnn=False, num_threads=10):
self.audio_process = audio_process
self.decoding_method = decoding_method
self.alpha = alpha
self.beta = beta
self.lang_model_path = lang_model_path
self.beam_size = beam_size
self.cutoff_prob = cutoff_prob
self.cutoff_top_n = cutoff_top_n
self.use_gpu = use_gpu
self.lac = None
# Set up the beam search decoder if requested
if decoding_method == "ctc_beam_search":
try:
from decoders.beam_search_decoder import BeamSearchDecoder
self.beam_search_decoder = BeamSearchDecoder(alpha, beta, lang_model_path, audio_process.vocab_list)
except ModuleNotFoundError:
raise Exception('缺少swig_decoders库,请根据文档安装,如果是Windows系统,请使用ctc_greedy。')
# Create the inference config
model_path = os.path.join(model_dir, 'inference.pdmodel')
params_path = os.path.join(model_dir, 'inference.pdiparams')
if not os.path.exists(model_path) or not os.path.exists(params_path):
raise Exception("模型文件不存在,请检查%s和%s是否存在!" % (model_path, params_path))
self.config = paddle_infer.Config(model_path, params_path)
if self.use_gpu:
self.config.enable_use_gpu(gpu_mem, 0)
else:
self.config.disable_gpu()
self.config.set_cpu_math_library_num_threads(num_threads)
if enable_mkldnn:
self.config.set_mkldnn_cache_capacity(10)
self.config.enable_mkldnn()
# enable memory optim
self.config.enable_memory_optim()
self.config.disable_glog_info()
# Create the predictor from the config
self.predictor = paddle_infer.create_predictor(self.config)
# Get the input handles
self.audio_data_handle = self.predictor.get_input_handle('audio_data')
self.seq_len_data_handle = self.predictor.get_input_handle('seq_len_data')
self.masks_handle = self.predictor.get_input_handle('masks')
# Get the output names
self.output_names = self.predictor.get_output_names()
# Warm up the predictor
warmup_audio_path = 'dataset/test.wav'
if os.path.exists(warmup_audio_path):
self.predict(warmup_audio_path, to_an=True)
else:
print('预热文件不存在,忽略预热!', file=sys.stderr)
# Run prediction on a single audio file
def predict(self, audio_path, to_an=False):
# Load the audio file and extract its features
audio_feature = self.audio_process.process_utterance(audio_path)
audio_len = audio_feature.shape[1]
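# Build a mask that matches the model's conv subsampling: frequency is
# (presumably) downsampled by 2 and time by 3, the mask is repeated over the
# 32 channels of the conv output, and padded time steps are zeroed out.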
mask_shape0 = (audio_feature.shape[0] - 1) // 2 + 1
mask_shape1 = (audio_feature.shape[1] - 1) // 3 + 1
mask_max_len = (audio_len - 1) // 3 + 1
mask_ones = np.ones((mask_shape0, mask_shape1))
mask_zeros = np.zeros((mask_shape0, mask_max_len - mask_shape1))
mask = np.repeat(np.reshape(np.concatenate((mask_ones, mask_zeros), axis=1),
(1, mask_shape0, mask_max_len)), 32, axis=0)
audio_data = np.array(audio_feature).astype('float32')[np.newaxis, :]
seq_len_data = np.array([audio_len]).astype('int64')
masks = np.array(mask).astype('float32')[np.newaxis, :]
# Set the inputs
self.audio_data_handle.reshape([audio_data.shape[0], audio_data.shape[1], audio_data.shape[2]])
self.seq_len_data_handle.reshape([audio_data.shape[0]])
self.masks_handle.reshape([masks.shape[0], masks.shape[1], masks.shape[2], masks.shape[3]])
self.audio_data_handle.copy_from_cpu(audio_data)
self.seq_len_data_handle.copy_from_cpu(seq_len_data)
self.masks_handle.copy_from_cpu(masks)
# Run the predictor
self.predictor.run()
# Fetch the output
output_handle = self.predictor.get_output_handle(self.output_names[0])
output_data = output_handle.copy_to_cpu()
# Decode the output probabilities
if self.decoding_method == 'ctc_beam_search':
# Beam search decoding
result = self.beam_search_decoder.decode_beam_search(probs_split=output_data,
beam_alpha=self.alpha,
beam_beta=self.beta,
beam_size=self.beam_size,
cutoff_prob=self.cutoff_prob,
cutoff_top_n=self.cutoff_top_n,
vocab_list=self.audio_process.vocab_list)
else:
# Greedy decoding
result = greedy_decoder(probs_seq=output_data, vocabulary=self.audio_process.vocab_list)
score, text = result[0], result[1]
# Optionally convert Chinese numerals to Arabic numerals
if to_an:
text = self.cn2an(text)
return score, text
# Convert Chinese numerals in the text to Arabic numerals
def cn2an(self, text):
# Lazily load the LAC lexical analysis model
if self.lac is None:
self.lac = LAC(mode='lac', use_cuda=self.use_gpu)
lac_result = self.lac.run(text)
result_text = ''
for t, r in zip(lac_result[0], lac_result[1]):
if r == 'm' or r == 'TIME':
t = cn2an.transform(t, "cn2an")
result_text += t
return result_text
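A hedged sketch of how this class is presumably driven (the paths and the audio_process object are assumptions, not values confirmed by this repository):

# predictor = Predictor(model_dir='models/infer', audio_process=audio_featurizer,
#                       decoding_method='ctc_greedy', use_gpu=False)
# score, text = predictor.predict('dataset/test.wav', to_an=True)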
"""Contains common utility functions."""
import distutils.util
import librosa
import soundfile
from data_utils.utility import read_manifest
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).items()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
# Get the number of samples in a manifest
def get_data_len(manifest_path, max_duration, min_duration):
manifest = read_manifest(manifest_path=manifest_path,
max_duration=max_duration,
min_duration=min_duration)
return len(manifest)
# Resample the audio to 16000 Hz
def change_rate(audio_path):
audio_path = audio_path.replace('\\', '/')
data, sr = soundfile.read(audio_path)
if sr != 16000:
data, sr = librosa.load(audio_path, sr=16000)
soundfile.write(audio_path, data, samplerate=16000)
@@ -6,47 +6,6 @@ from openpyxl.styles import PatternFill, Alignment
 from split_wav import *
-def create_sheet(path, sheet_name, value):
-    """
-    Initialize the workbook with the given header row.
-    :param path: str, where the workbook is stored
-    :param sheet_name: str, name of the sheet
-    :param value: list, header row ['start time', 'end time', 'subtitle', 'suggestion', 'narration script']
-    :return: None
-    """
-    index = len(value)
-    workbook = openpyxl.Workbook()
-    sheet = workbook.active
-    sheet.title = sheet_name
-    # Widen the column holding the subtitles
-    sheet.column_dimensions['C'].width = 50
-    for i in range(0, index):
-        for j in range(0, len(value[i])):
-            sheet.cell(row=i + 1, column=j + 1, value=str(value[i][j]))
-    workbook.save(path)
-def write_to_sheet(path, sheet_name, value):
-    """
-    Append rows to an existing workbook.
-    :param path:
-    :param sheet_name:
-    :param value:
-    :return:
-    """
-    index = len(value)
-    workbook = openpyxl.load_workbook(path)
-    sheet = workbook.get_sheet_by_name(sheet_name)
-    cur_row = sheet.max_row
-    for i in range(0, index):
-        for j in range(0, len(value[i])):
-            sheet.cell(row=cur_row + i + 1, column=j + 1, value=str(value[i][j]))
-            if value[i][j] == '' or value[i][j] == '插入旁白':
-                sheet.cell(row=cur_row + i + 1, column=j + 1).fill = PatternFill(fill_type='solid', fgColor='ffff00')
-            if j == 2:
-                sheet.cell(row=cur_row + i + 1, column=j + 1).alignment = Alignment(wrapText=True)
-    workbook.save(path)
@@ -95,15 +54,12 @@ def detect_with_asr(video_path, book_path, start_time=0, end_time=-1, state=None
     book_name_xlsx = book_path
     sheet_name_xlsx = "旁白插入位置建议"
-    # If no workbook for this video exists yet, create one to hold the output
-    if not os.path.exists(book_name_xlsx):
-        table_head = [["起始时间", "终止时间", "字幕", '建议', '解说脚本']]
-        create_sheet(book_name_xlsx, sheet_name_xlsx, table_head)
     sys.path.append("./PaddlePaddle_DeepSpeech2")
     from infer_path import predict_long_audio_with_paddle
-    table_content = predict_long_audio_with_paddle(audio_path, start_time, state)
-    write_to_sheet(book_name_xlsx, sheet_name_xlsx, table_content)
+    table_head = [["起始时间", "终止时间", "字幕", '建议', '解说脚本']]
+    table_content = table_head + predict_long_audio_with_paddle(audio_path, start_time, state)
+    from detect_with_ocr import write_excel_xlsx
+    write_excel_xlsx(book_name_xlsx, sheet_name_xlsx, table_content)
     state[0] = 1
     # Delete intermediate files
     # shutil.rmtree(tmp_root)
@@ -5,7 +5,7 @@ import numpy as np
 from paddleocr import PaddleOCR
 import difflib
 import openpyxl
-from openpyxl.styles import PatternFill
+from openpyxl.styles import PatternFill, Alignment
 # Upper and lower bounds of the subtitle region
 up_b, down_b = 0, 0
@@ -221,7 +221,7 @@ def write_excel_xlsx(path, sheet_name, value):
     sheet.column_dimensions['D'].width = 30
     for i in range(0, index):
         for j in range(0, len(value[i])):
-            sheet.cell(row=i + 1, column=j + 1, value=str(value[i][j]))
+            sheet.cell(row=i + 1, column=j + 1, value=str(value[i][j])).alignment = Alignment(wrapText=True)
             if value[i][j] == '' or '插入旁白' in str(value[i][j]) or value[i][j] == '翻译':
                 sheet.cell(row=i + 1, column=j + 1).fill = PatternFill(fill_type='solid', fgColor='ffff00')
     workbook.save(path)
eagle64.ico (new binary file, 16.6 KB)
@@ -192,6 +192,9 @@ def ss_and_export(video_path, sheet_path, output_dir, speed, caption_file, state
     if not os.path.exists(root_path):
         os.mkdir(root_path)
+    global tmp_file
+    tmp_file = os.path.join(output_dir, tmp_file)
     # Read the sheet and get the narration texts and their insertion positions
     sheet_content = read_sheet(book_path)
     narratages, start_timestamp, end_timestamp = get_narratage_text(sheet_content, speed)
@@ -204,7 +207,7 @@ def ss_and_export(video_path, sheet_path, output_dir, speed, caption_file, state
         wav_path = os.path.join(root_path, '%.2f.wav' % start_timestamp[i])
         narratage_paths.append(wav_path)
         speech_synthesis(text, wav_path, speed)
-        time.sleep(1)
+        time.sleep(2)
         print("目前正在处理{}".format(wav_path))
         if state is not None:
             state[0] = float((i + 1) / len(narratages)) * 0.97
@@ -12,6 +12,9 @@ window = tk.Tk()
 window.title('无障碍电影辅助工具')  # title
 window.geometry('600x400')  # window size
 window.resizable(0, 0)
+# window.iconphoto(False, tk.PhotoImage(file="cropped-eagle.png"))
+window.iconbitmap("eagle64.ico")
+# window.tk.call("wm", "iconphoto", window._w, tk.PhotoImage(file="cropped-eagle.png"))
 def open_video_file():
@@ -455,7 +458,7 @@ synthesis_command.place(relx=0.05, rely=0.45, relwidth=0.9, relheight=0.4)
 # synthesis_command.grid_columnconfigure(i, weight=1)
 audioDir_label = ttk.Label(synthesis_command, text="输出音频存放于")
-audioDir_label.grid(column=0, row=0)
+audioDir_label.grid(column=0, row=0, sticky="")
 audioDir = tk.StringVar()
 audioDir_input = ttk.Entry(synthesis_command, width=30, textvariable=audioDir)
 audioDir_input.grid(column=1, row=0)