Commit 0824724a authored by smile2019's avatar smile2019

Merge remote-tracking branch 'refs/remotes/origin/feat_1' into feat_1

parents a5f36591 9f2adad2
......@@ -12,12 +12,16 @@ import os
class Content:
    """Column-layout constants for the combined subtitle/aside table
    (the widget named ``all_tableWidget``).

    Row layout: [start time, end time, subtitle, suggested length,
    aside text, speed, audio preview] — so the aside text lives in
    column 4 and the speed combo box in column 5.
    """

    # Column index of the start-time cell.
    StartTimeColumn = 0
    # Column index of the editable aside (narration) text.
    AsideColumnNumber = 4
    # Column index of the speech-speed selector.
    SpeedColumnNumber = 5
    # Columns the user may edit; edits here trigger audio regeneration.
    ActivateColumns = [4, 5]
    # Qt objectName of the table this layout describes.
    ObjectName = "all_tableWidget"
    # Columns rendered in hh:mm:ss format (start and end time).
    TimeFormatColumns = [0, 1]
    # Selectable TTS speeds: multiplier plus characters-per-second label.
    SpeedList = ["1.00(4字/秒)", "1.10(4.5字/秒)", "1.25(5字/秒)",
                 "1.50(6字/秒)", "1.75(7字/秒)", "2.00(8字/秒)", "2.50(10字/秒)"]
class Aside:
......
......@@ -24,17 +24,20 @@ class Ui_Dialog(object):
self.name_input.setObjectName("name_input")
self.root_input = QtWidgets.QLineEdit(Dialog)
self.root_input.setObjectName("root_input")
self.gridLayout.addWidget(self.root_input, 1, 1, 1, 1)
self.gridLayout.addWidget(self.root_input, 0, 1, 1, 1)
self.get_dir = QtWidgets.QPushButton(Dialog)
self.get_dir.setObjectName("get_dir")
self.gridLayout.addWidget(self.get_dir, 1, 2, 1, 1)
self.gridLayout.addWidget(self.get_dir, 0, 2, 1, 1)
self.rootLabel = QtWidgets.QLabel(Dialog)
self.rootLabel.setObjectName("rootLabel")
self.gridLayout.addWidget(self.rootLabel, 1, 0, 1, 1)
self.gridLayout.addWidget(self.rootLabel, 0, 0, 1, 1)
self.nameLabel = QtWidgets.QLabel(Dialog)
self.nameLabel.setObjectName("nameLabel")
self.gridLayout.addWidget(self.nameLabel, 0, 0, 1, 1)
self.gridLayout.addWidget(self.name_input, 0, 1, 1, 1)
self.gridLayout.addWidget(self.nameLabel, 1, 0, 1, 1)
self.gridLayout.addWidget(self.name_input, 1, 1, 1, 1)
self.gridLayout_2.addLayout(self.gridLayout, 0, 0, 1, 1)
self.horizontalLayout = QtWidgets.QHBoxLayout()
self.horizontalLayout.setObjectName("horizontalLayout")
......@@ -64,7 +67,7 @@ class Ui_Dialog(object):
_translate = QtCore.QCoreApplication.translate
Dialog.setWindowTitle(_translate("Dialog", "Dialog"))
self.nameLabel.setText(_translate("Dialog", "工程名称"))
self.rootLabel.setText(_translate("Dialog", "工程文件夹"))
self.rootLabel.setText(_translate("Dialog", "目标路径"))
self.get_dir.setText(_translate("Dialog", "打开文件夹"))
self.confirm.setText(_translate("Dialog", "确认"))
self.cancel.setText(_translate("Dialog", "取消"))
......@@ -44,7 +44,7 @@ ocr = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False, use_gpu=False, cl
normal_speed = 4
def get_position(video_path: str, start_time: float) -> Tuple[float, float]:
def get_position(video_path: str, start_time: float, rate: float, rate_bottom: float) -> Tuple[float, float]:
# return (885.0, 989.0)
"""根据对视频中的画面进行分析,确定字幕的位置,以便后续的字幕识别
......@@ -65,68 +65,79 @@ def get_position(video_path: str, start_time: float) -> Tuple[float, float]:
txt_cnt = 0
pre_txt = None
video.set(cv2.CAP_PROP_POS_FRAMES, start)
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT) * 0.6)
while True:
_, img = video.read()
# print("img:", img)
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# cv2.imshow('img', gray)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
cnt += 1
if img is None or cnt > 10000:
break
if cnt % int(fps / 3) != 0:
continue
img = img[height:]
res = ocr.ocr(img, cls=True)
sorted(res, key=lambda text: text[0][0][1])
bottom_position = None
if len(res) == 0:
continue
log = []
print("cnt:", cnt, "rect_num:", len(res))
for x in res:
# print("x:", x)
rect, (txt, confidence) = x
[x1,y1],[x2,y2],[x3,y3],[x4,y4] = rect
# font_size = rect[2][1] - rect[0][1]
mid = (x1 + x2) / 2
gradient = np.arctan(abs((y2 - y1) / (x2 - x1)))
# 可能是字幕的文本
conf_thred = 0.9
# conf_thred = 0.8
if confidence > conf_thred and 0.4 * img.shape[1] < mid < 0.6 * img.shape[1] and gradient < 0.1:
if bottom_position is None:
bottom_position = y1
# 判断是否与前一文本相同(是不是同一个字幕),非同一字幕的前提下,取对应上下边界,
keys = subtitle_position.keys()
if abs(y1 - bottom_position) < 10:
if pre_txt is None or pre_txt != txt:
txt_cnt += 1
pre_txt = txt
if (y1, y3) in keys:
subtitle_position[(y1, y3)] += 1
else:
replace = False
for k in keys:
# 更新键值为最宽的上下限
if abs(y1 - k[0]) + abs(y3 - k[1]) < 10:
subtitle_position[k] += 1
new_k = min(k[0], y1), max(k[1], y3)
if new_k != k:
subtitle_position[new_k] = subtitle_position[k]
subtitle_position.pop(k)
replace = True
break
if not replace:
subtitle_position[(y1, y3)] = 1
if txt_cnt == 3:
break
print(subtitle_position)
up_bounding, down_bounding = max(subtitle_position, key=subtitle_position.get)
return int(up_bounding + height), int(down_bounding + height)
# height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT) * 0.6)
print(">>>>>>>>>>>>video height")
print(cv2.CAP_PROP_FRAME_HEIGHT)
print(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
up = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT) * (rate))
down = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT) * (rate_bottom))
# down = up + 20
# down = video.get(cv2.CAP_PROP_FRAME_HEIGHT) * (0.73)
print(up)
# print(down)
return int(up), int(down)
# while True:
# _, img = video.read()
# # print("img:", img)
# # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# # cv2.imshow('img', gray)
# # cv2.waitKey(0)
# # cv2.destroyAllWindows()
# cnt += 1
# if img is None or cnt > 10000:
# break
# if cnt % int(fps / 3) != 0:
# continue
# img = img[height:]
# res = ocr.ocr(img, cls=True)
# sorted(res, key=lambda text: text[0][0][1])
# bottom_position = None
# if len(res) == 0:
# continue
# log = []
# print("cnt:", cnt, "rect_num:", len(res))
# for x in res:
# # print("x:", x)
# rect, (txt, confidence) = x
# [x1,y1],[x2,y2],[x3,y3],[x4,y4] = rect
# # font_size = rect[2][1] - rect[0][1]
# mid = (x1 + x2) / 2
# gradient = np.arctan(abs((y2 - y1) / (x2 - x1)))
# # 可能是字幕的文本
# conf_thred = 0.9
# # conf_thred = 0.8
# if confidence > conf_thred and 0.4 * img.shape[1] < mid < 0.6 * img.shape[1] and gradient < 0.1:
# if bottom_position is None:
# bottom_position = y1
# # 判断是否与前一文本相同(是不是同一个字幕),非同一字幕的前提下,取对应上下边界,
# keys = subtitle_position.keys()
# if abs(y1 - bottom_position) < 10:
# if pre_txt is None or pre_txt != txt:
# txt_cnt += 1
# pre_txt = txt
# if (y1, y3) in keys:
# subtitle_position[(y1, y3)] += 1
# else:
# replace = False
# for k in keys:
# # 更新键值为最宽的上下限
# if abs(y1 - k[0]) + abs(y3 - k[1]) < 10:
# subtitle_position[k] += 1
# new_k = min(k[0], y1), max(k[1], y3)
# if new_k != k:
# subtitle_position[new_k] = subtitle_position[k]
# subtitle_position.pop(k)
# replace = True
# break
# if not replace:
# subtitle_position[(y1, y3)] = 1
# if txt_cnt == 3:
# break
# print(subtitle_position)
# up_bounding, down_bounding = max(subtitle_position, key=subtitle_position.get)
# return int(up_bounding + height), int(down_bounding + height)
def erasePunc(txt: str) -> str:
......@@ -193,19 +204,29 @@ def detect_subtitle(img: np.ndarray) -> Tuple[Union[str, None], float]:
Tuple[Union[str, None]]: 字幕信息(没有字幕时返回None)和置信度
"""
subTitle = ''
# up_b = 276
# down_b = 297
height = down_b - up_b
img = img[int(up_b - height * 0.7):int(down_b + height * 0.7)]
img = img[int(up_b - height*0.7):int(down_b + height*0.7)]
# 针对低帧率的视频做图像放大处理
print(">>>>>>>>>>>>>>>>>>>>>img shape")
print(height)
print(up_b)
print(down_b)
print(img.shape)
if img.shape[1] < 1000:
img = cv2.resize(img, (int(img.shape[1] * 1.5), int(img.shape[0] * 1.5)))
cv2.imwrite('./cap.png', img)
res = ocr.ocr(img, cls=True)
print('--------> res', res)
sorted(res, key=lambda text: text[0][0][1])
sorted(res, key=lambda text: text[0][0][0])
if len(res) == 0:
return None, 0
possible_txt = []
conf = 0
print(res)
print('res --------->', res)
for x in res:
# cv2.imshow("cut", img)
# cv2.waitKey(0)
......@@ -312,15 +333,24 @@ def process_video(video_path: str, begin: float, end: float, book_path: str, she
mainWindow.projectContext.last_time = cur_time
subTitle, conf = detect_subtitle(frame)
print(">>>>>>>>>>>>111111111")
if subTitle is not None:
print(">>>>>>>>>>>>111111111 2222222")
subTitle = normalize(subTitle)
if len(subTitle) == 0:
print(">>>>>>>>>>>>111111111 3333333")
subTitle = None
print(">>>>>>>>>>>>222222222")
# 第一次找到字幕
if lastSubTitle is None and subTitle is not None:
start_time = cur_time
print(">>>>>>>>>>>>333333333")
# 字幕消失
elif lastSubTitle is not None and subTitle is None:
print(">>>>>>>>>>>>4444444444")
end_time = cur_time
res.append([start_time, end_time, lastSubTitle])
if (len(res) == 1 and res[-1][0] - last_time >= 1) or (len(res) > 1 and res[-1][0] - res[-2][1]) >= 1:
......@@ -334,8 +364,10 @@ def process_video(video_path: str, begin: float, end: float, book_path: str, she
# write_to_sheet(book_path, sheet_name, [round(start_time, 2), round(end_time, 2), lastSubTitle, ''])
add_to_list(mainWindow, "字幕", [round(start_time, 3), round(end_time, 3), lastSubTitle, ''])
elif lastSubTitle is not None and subTitle is not None:
print(">>>>>>>>>>>>5555555555")
# 两句话连在一起,但是两句话不一样
if string_similar(lastSubTitle, subTitle) < 0.7:
print(">>>>>>>>>>>66666666666")
end_time = cur_time
res.append([start_time, end_time, lastSubTitle])
if (len(res) == 1 and res[-1][0] - last_time >= 1) or (len(res) > 1 and res[-1][0] - res[-2][1]) >= 1:
......@@ -349,6 +381,7 @@ def process_video(video_path: str, begin: float, end: float, book_path: str, she
add_to_list(mainWindow, "字幕", [round(start_time, 3), round(end_time, 3), lastSubTitle, ''])
start_time = end_time
else:
print(">>>>>>>>>>>>777777777")
lastSubTitle = subTitle if conf > lastConf else lastSubTitle
continue
# 当前字幕与上一段字幕不一样
......@@ -404,7 +437,7 @@ def detect_with_ocr(video_path: str, book_path: str, start_time: float, end_time
up_b, down_b = context.caption_boundings[0], context.caption_boundings[1]
else:
# 此处start_time + 300是为了节省用户调整视频开始时间的功夫(强行跳过前5分钟)
up_b, down_b = get_position(video_path, 0)
up_b, down_b = get_position(video_path, 0, mainWindow.rate, mainWindow.rate_bottom)
context.caption_boundings = [up_b, down_b]
context.detected = True
......@@ -418,4 +451,4 @@ def detect_with_ocr(video_path: str, book_path: str, start_time: float, end_time
if __name__ == '__main__':
path = "D:/mystudy/Eagle/accessibility_movie_1/test.mp4"
print("get_pos:", get_position(path, 0))
# print("get_pos:", get_position(path, 0))
......@@ -132,12 +132,16 @@ class MainWindow(QMainWindow, Ui_MainWindow):
"""
self.setting.triggered.connect(self.show_setting_dialog) # 设置
self.action_3.triggered.connect(self.show_detect_dialog)
self.action_3.setEnabled(False)
self.action_4.triggered.connect(self.show_assemble_dialog)
self.action_4.setEnabled(False)
self.action_5.triggered.connect(self.import_excel)
self.action_5.setEnabled(False)
# menu4_action = QtWidgets.QAction()
# menu4_action.setObjectName("menu4_action")
# self.menu_4.addAction(menu4_action)
# menu4_action.triggered.connect(self.show_detect_dialog)
# self.action_3.triggered.connect(self.show_detect_dialog)
# self.action_3.setEnabled(True)
# self.action_4.triggered.connect(self.show_assemble_dialog)
# self.action_4.setEnabled(False)
# self.action_5.triggered.connect(self.import_excel)
# self.action_5.setEnabled(False)
self.action_create.triggered.connect(self.show_create_dialog) # 新建工程
self.action_save.triggered.connect(self.save_project)
self.action_save.setEnabled(False)
......@@ -250,11 +254,17 @@ class MainWindow(QMainWindow, Ui_MainWindow):
# 表格双击和发生change时的处理
self.zm_tableWidget.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.zm_tableWidget.itemDoubleClicked.connect(self.change_video_time)
# self.all_tableWidget.itemDoubleClicked.connect(self.change_video_time)
# self.all_tableWidget.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.all_tableWidget.itemDoubleClicked.connect(self.writeHistory)
self.all_tableWidget.itemChanged.connect(self.rewriteHistory)
self.all_tableWidget.itemChanged.connect(self.write2ProjectFromContent)
self.all_tableWidget.itemChanged.connect(self.generate_audio_slot_all)
self.all_tableWidget.itemDoubleClicked.connect(self.change_video_time)
self.all_tableWidget.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.all_tableWidget.itemDoubleClicked.connect(
self.all_item_changed_by_double_clicked_slot)
self.pb_tableWidget.itemDoubleClicked.connect(self.writeHistory)
self.pb_tableWidget.itemDoubleClicked.connect(
self.pb_item_changed_by_double_clicked_slot)
......@@ -289,6 +299,8 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self.projectContext.Init(project_path)
self.update_ui()
self.rate = 0
# 打印到log文件中
t = RunThread(funcName=make_print_to_file, args=os.path.join(os.getcwd(), 'log'), name="logging")
print(t)
......@@ -311,6 +323,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
# 如果选择的行索引小于2,弹出上下文菜单
menu = QMenu()
item1 = menu.addAction("删除")
add_pb_item = menu.addAction("新增旁白")
# 转换坐标系
screenPos = self.all_tableWidget.mapToGlobal(pos)
......@@ -323,6 +336,10 @@ class MainWindow(QMainWindow, Ui_MainWindow):
rowNum, 1).text(), self.all_tableWidget.item(rowNum, 2).text())
self.del_line_operation_slot(rowNum + 1)
return
if action == add_pb_item:
to_be_delete_element = self.projectContext.all_elements[rowNum]
self.insert_aside_from_cur_time(float(to_be_delete_element.st_time_sec))
return
# 重写关闭Mmainwindow窗口
def closeEvent(self, event):
......@@ -417,6 +434,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
project_name = os.path.basename(project_path)
self.setWindowTitle(f"无障碍电影制作软件(当前工程为:{project_name})")
self.projectContext.Init(project_path)
self.setting_dialog.refresh(self.projectContext)
self.update_ui()
# 导入视频
......@@ -440,6 +458,36 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self.action_insert_aside_from_now.setEnabled(True)
self.insert_aside_from_now_btn.setEnabled(True)
def up_ocr(self):
    """Nudge the top subtitle-boundary marker up by 3 px and recompute
    ``self.rate`` (marker position as a fraction of the video widget height,
    with a fixed 10 px offset)."""
    marker_y = self.widget.up(3)
    view_height = self.wgt_video.height()
    self.rate = float(marker_y - 10) / float(view_height)
    print(">>>>>up h:" + str(marker_y))
    print(self.wgt_video.height())
    print(">>>>>>>>>rate" + str(self.rate))
def down_ocr(self):
    """Nudge the top subtitle-boundary marker down by 3 px and recompute
    ``self.rate`` (marker position as a fraction of the video widget height,
    with a fixed 10 px offset)."""
    marker_y = self.widget.down(3)
    view_height = self.wgt_video.height()
    self.rate = float(marker_y - 10) / float(view_height)
    print(">>>>>down h:" + str(marker_y))
    print(self.wgt_video.height())
    print(">>>>>>>>>rate" + str(self.rate))
def up_ocr_bottom(self):
    """Nudge the bottom subtitle-boundary marker up by 3 px and recompute
    ``self.rate_bottom`` as a fraction of the video widget height
    (fixed 6 px offset)."""
    marker_y = self.widget_bottom.up(3)
    self.rate_bottom = float(marker_y - 6) / float(self.wgt_video.height())
def down_ocr_bottom(self):
    """Nudge the bottom subtitle-boundary marker down by 3 px and recompute
    ``self.rate_bottom`` as a fraction of the video widget height
    (fixed 6 px offset)."""
    marker_y = self.widget_bottom.down(3)
    self.rate_bottom = float(marker_y - 6) / float(self.wgt_video.height())
#导入旁白excel
def import_excel(self):
# excel_path = self.openExcelFile()
......@@ -584,7 +632,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self.projectContext.excel_path = book_path
# 获取视频的时长等信息,初始化开始结束时间
startTime = "00:00:20"
startTime = "00:00:00"
video = cv2.VideoCapture(video_path)
fps = video.get(cv2.CAP_PROP_FPS)
......@@ -988,6 +1036,11 @@ class MainWindow(QMainWindow, Ui_MainWindow):
if table.objectName() == constant.Content.ObjectName:
elem_list = elem.to_short_list()
time_format_col_list = constant.Content.TimeFormatColumns
btn = QPushButton()
btn.setText(f"预览{idx}")
col = len(elem_list)
btn.clicked.connect(self.audio_preview_slot_all)
table.setCellWidget(idx, col, btn)
if table.objectName() == constant.Subtitle.ObjectName:
elem_list = elem.to_subtitle_list()
time_format_col_list = constant.Subtitle.TimeFormatColumns
......@@ -1012,6 +1065,12 @@ class MainWindow(QMainWindow, Ui_MainWindow):
# 需要格式化成hh:mm:ss格式
if j in time_format_col_list and type(text) == str and len(text) != 0:
text = utils.transfer_second_to_time(text)
if table.objectName() == constant.Content.ObjectName and j == constant.Content.SpeedColumnNumber:
qcombo = QtWidgets.QComboBox()
qcombo.addItems(constant.Content.SpeedList)
qcombo.setCurrentIndex(constant.Content.SpeedList.index(text))
qcombo.currentIndexChanged.connect(self.change_audio_speed_all)
table.setCellWidget(idx, j, qcombo)
if table.objectName() == constant.Aside.ObjectName and j == constant.Aside.SpeedColumnNumber:
qcombo = QtWidgets.QComboBox()
qcombo.addItems(constant.Aside.SpeedList)
......@@ -1064,6 +1123,43 @@ class MainWindow(QMainWindow, Ui_MainWindow):
else:
self.prompt_dialog.show_with_msg("暂无音频可供预览,请重新生成")
def audio_preview_slot_all(self):
    """Preview the narration audio of the clicked row in the combined
    subtitle/aside table, seeking the video player to the same position.

    Triggered by the per-row "预览" button; the row is resolved from the
    button's physical position inside the table. If the wav file for the
    row is still being synthesized (or does not exist), the user is
    prompted instead of playing.
    """
    btn = self.sender()
    # NOTE(review): the original comment flagged position-based row lookup
    # as "not quite right" — verify against an index captured at button
    # creation time.
    idx = self.all_tableWidget.indexAt(btn.pos())
    print("index:", idx.row())
    # The start time in column 0 names the temporary wav file.
    item = self.all_tableWidget.item(idx.row(), 0)
    pos_sec = utils.trans_to_seconds(item.text())
    audio_path = os.path.dirname(self.projectContext.excel_path) + (
        "/tmp/%.3f.wav" % pos_sec)
    print("待播放的音频文件为", audio_path)
    # If a synthesis thread is still producing this exact file, ask the
    # user to wait instead of playing a half-written wav.
    for t in self.all_threads:
        if t.name == "single_speech_synthesis" and t.is_alive():
            if audio_path in t._args:
                self.prompt_dialog.show_with_msg("音频正在合成,请稍候")
                return
    # Play the audio on a worker thread and sync the video position.
    if os.path.exists(audio_path):
        print(audio_path)
        t = RunThread(funcName=self.play_audio,
                      args=(audio_path, ),
                      name="play_audio")
        t.start()
        self.all_threads.append(t)
        self.player.setPosition(int(pos_sec * 1000))
        # playVideo() toggles based on is_video_playing, so force "paused"
        # first to make it start playing.
        self.is_video_playing = False
        self.playVideo()
    else:
        self.prompt_dialog.show_with_msg("暂无音频可供预览,请重新生成")
def checkIfTableItemCanChange(self, table: QTableWidget, i: int, j: int):
"""确认单元格是否可编辑
......@@ -1075,10 +1171,12 @@ class MainWindow(QMainWindow, Ui_MainWindow):
Returns:
bool: True(可编辑) or False(不可编辑)
"""
if table.objectName() == self.all_tableWidget.objectName():
return True
# if table.objectName() == self.all_tableWidget.objectName():
# return True
if table.objectName() == self.pb_tableWidget.objectName() and j in constant.Aside.ActivateColumns:
return True
if table.objectName() == self.all_tableWidget.objectName() and j in constant.Content.ActivateColumns:
return True
return False
def save_project(self):
......@@ -1146,7 +1244,9 @@ class MainWindow(QMainWindow, Ui_MainWindow):
row = item.row() # 获取行数
col = item.column() # 获取列数 注意是column而不是col哦
text = item.text() # 获取内容
if col == constant.Aside.AsideColumnNumber:
# if col == constant.Aside.AsideColumnNumber:
# self.projectContext.history_push(row, text, text)
if col == constant.Content.AsideColumnNumber:
self.projectContext.history_push(row, text, text)
......@@ -1182,6 +1282,82 @@ class MainWindow(QMainWindow, Ui_MainWindow):
# 合成这一段语音
self.do_generate_audio_by_aside_row(int(row))
def generate_audio_slot_all(self, item):
    """Regenerate the temporary narration audio for an edited table cell.

    Connected to ``all_tableWidget.itemChanged``; fires only on genuine
    user edits (double-click text edit or speed change), not while the
    table is being programmatically populated — guarded by
    ``projectContext.initial_ing`` and the user-editing flag.

    Args:
        item: the changed QTableWidgetItem (its row/column locate the cell).
    """
    try:
        # Ignore change signals emitted during table (re)initialization.
        if self.projectContext.initial_ing == True:
            return
        if item is None:
            print("WRONG!!!! item Is None")
            return
        # Only react to edits made by the user.
        if self.is_user_editing() == False:
            return
        self.set_user_edit(False)
        row = item.row()  # row index of the edited cell
        col = item.column()  # column index (column(), not col)
        # Only aside-text or speed changes require re-synthesizing audio.
        if col not in constant.Content.ActivateColumns:
            return
        # Stop any preview so the wav file is released before being rewritten.
        self.audio_player.setMedia(QMediaContent())
        # Synthesize this row's narration audio.
        self.do_generate_audio_by_aside_row_all(int(row))
    except Exception as e:
        # NOTE(review): broad best-effort guard — failures are only printed.
        print(e)
def do_generate_audio_by_aside_row_all(self, row: int):
    """Generate (asynchronously) the temporary narration wav for one row.

    The wav file lives under ``<excel dir>/tmp/`` and is named by the
    row's start time. If the row's aside text has been cleared, any
    previously generated wav is deleted instead.

    Args:
        row (int): row index into ``projectContext.all_elements``
            (the combined subtitle/aside table).
    """
    from speech_synthesis import speech_synthesis, Speaker, choose_speaker
    audio_dir = os.path.dirname(self.projectContext.excel_path)
    # Wav filename is keyed on the row's start time (3 decimal places).
    wav_path = audio_dir + \
        '/tmp/%.3f.wav' % float(
            self.projectContext.all_elements[int(row)].st_time_sec)
    print("wav_path:", wav_path)
    try:
        # Use the row's own speed rather than the project-wide default.
        speed_info = self.projectContext.all_elements[int(row)].speed
        speaker_info = self.projectContext.speaker_info
        # speed strings look like "1.00(4字/秒)" — the multiplier precedes '('.
        speed = float(speed_info.split('(')[0])
        speaker_name = speaker_info.split(",")[0]
        speaker = self.projectContext.choose_speaker(speaker_name)
        text = self.projectContext.all_elements[int(row)].aside
        self.projectContext.all_elements[int(row)].print_self()
        # If the aside was cleared, remove the stale wav and stop here.
        if text is None or len(text) == 0:
            if os.path.exists(wav_path):
                os.remove(wav_path)
            return
        # Kill any synthesis thread still writing this same wav file.
        # NOTE(review): reaches into Thread._args (private) to match the path.
        for t in self.all_threads:
            if wav_path in t._args and t.name == "single_speech_synthesis" and t.is_alive():
                stop_thread(t)
        t = RunThread(funcName=speech_synthesis,
                      args=(text, wav_path, speaker, speed),
                      name="single_speech_synthesis")
        t.setDaemon(True)
        t.start()
        self.all_threads.append(t)
    except Exception as e:
        # Best-effort: synthesis failures are printed, not raised.
        print(e)
def do_generate_audio_by_aside_row(self, row: int):
"""根据行号生成对应旁白文本的临时音频
......@@ -1261,7 +1437,6 @@ class MainWindow(QMainWindow, Ui_MainWindow):
row = item.row() # 获取行数
col = item.column() # 获取列数 注意是column而不是col
text = item.text() # 获取内容
if self.can_write_history == False:
self.can_write_history = True
return
......@@ -1277,7 +1452,9 @@ class MainWindow(QMainWindow, Ui_MainWindow):
col = item.column() # 获取列数 注意是column而不是col哦
text = item.text() # 获取内容
if col != constant.Aside.AsideColumnNumber:
# if col != constant.Aside.AsideColumnNumber:
# return
if col != constant.Content.AsideColumnNumber:
return
opt = self.projectContext.history_pop()
......@@ -1350,6 +1527,44 @@ class MainWindow(QMainWindow, Ui_MainWindow):
int(idx), constant.Content.SpeedColumnNumber, QTableWidgetItem(text))
self.projectContext.refresh_aside_speed(row, text)
def write2ProjectFromContent(self, item):
    """Propagate an edited cell of the combined subtitle/aside table
    back into the project context.

    Args:
        item: the modified table cell.
    """
    # Skip programmatic changes and anything that is not a user edit.
    if self.projectContext.initial_ing == True:
        return
    if self.is_user_editing() == False:
        return
    if item is None:
        print("WRONG!!!! item Is None")
        return
    row, col, text = item.row(), item.column(), item.text()
    # Only the aside-text and speed columns are written back.
    if col not in constant.Content.ActivateColumns:
        return
    try:
        print("行号", row)
        element = self.projectContext.all_elements[row]
        print("开始时间", element.st_time_sec)
        idx = self.projectContext.aside_subtitle_2contentId(element)
        print("对应index", idx)
        if col == constant.Content.AsideColumnNumber:
            self.projectContext.refresh_element(row, text)
        elif col == constant.Content.SpeedColumnNumber:
            self.projectContext.refresh_speed(row, text)
    except Exception as e:
        # Best-effort: failures are printed, not raised.
        print(e)
def undo_slot(self):
"""撤销之前对表格内容的修改操作
......@@ -1360,9 +1575,13 @@ class MainWindow(QMainWindow, Ui_MainWindow):
print('[undo_slot] record=%s' % (record.to_string()))
item = QTableWidgetItem(record.old_str)
row = int(record.row)
self.projectContext.aside_list[row].aside = record.old_str
self.pb_tableWidget.setItem(
row, constant.Aside.AsideColumnNumber, item)
# self.projectContext.aside_list[row].aside = record.old_str
# self.pb_tableWidget.setItem(
# row, constant.Aside.AsideColumnNumber, item)
self.projectContext.all_elements[row].aside = record.old_str
self.all_tableWidget.setItem(
row, constant.Content.AsideColumnNumber, item)
self.action_redo.setEnabled(True)
def redo_slot(self):
......@@ -1376,9 +1595,12 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self.action_redo.setEnabled(False)
item = QTableWidgetItem(record.new_str)
row = int(record.row)
self.projectContext.aside_list[row].aside = record.new_str
self.pb_tableWidget.setItem(
row, constant.Aside.AsideColumnNumber, item)
# self.projectContext.aside_list[row].aside = record.new_str
# self.pb_tableWidget.setItem(
# row, constant.Aside.AsideColumnNumber, item)
self.projectContext.all_elements[row].aside = record.new_str
self.all_tableWidget.setItem(
row, constant.Content.AsideColumnNumber, item)
def view_history_slot(self):
"""查看历史操作
......@@ -1425,7 +1647,7 @@ class MainWindow(QMainWindow, Ui_MainWindow):
"""
cur_time = self.player.position() / 1000
if self.curTab == 0:
if not self.is_user_editing() and self.curTab == 0:
all_elements = self.projectContext.all_elements
for i in range(len(all_elements) - 1, -1, -1):
if utils.trans_to_seconds(all_elements[i].st_time_sec) <= cur_time:
......@@ -1495,7 +1717,26 @@ class MainWindow(QMainWindow, Ui_MainWindow):
print("self.player.position()", self.player.position())
cur_time = round(self.player.position()/1000, 3)
idx = self.calculate_element_row(cur_time)
print(">>>>>>>>>>>>>>>>>>>>>>>>>add row")
print("idex :" + str(idx))
print("[insert_aside_from_now_slot] idx=", idx)
# 其实end_time目前是没啥用的,可以删掉了
print("cur_lens", len(self.projectContext.all_elements))
if idx < len(self.projectContext.all_elements) - 1:
self.add_line_operation_slot(idx, str(cur_time), self.projectContext.all_elements[idx+1].st_time_sec, "", "插入旁白,推荐字数为0", "", self.projectContext.speaker_speed)
else:
self.add_line_operation_slot(idx, str(cur_time), str(cur_time+1), "", "插入旁白,推荐字数为0", "", self.projectContext.speaker_speed)
def insert_aside_from_cur_time(self,cur_time:float):
"""在当前位置插入旁白
根据当前时间找到表格中合适插入的位置,然后在对应位置添加旁白
"""
if self.player.duration() == 0 or self.projectContext.project_base_dir in [None, ""]:
self.prompt_dialog.show_with_msg("插入失败!未检测到视频或工程!")
return
print("self.player.position()", self.player.position())
idx = self.calculate_element_row(cur_time)
print("idex :" + str(idx))
print("[insert_aside_from_now_slot] idx=", idx)
# 其实end_time目前是没啥用的,可以删掉了
......@@ -1544,7 +1785,6 @@ class MainWindow(QMainWindow, Ui_MainWindow):
same_flag = True
break
if float(cur_time) < float(self.projectContext.all_elements[idx].st_time_sec):
print(">>>>>>>>>bbbbbbbb")
break
idx += 1
return idx,same_flag
......@@ -1645,7 +1885,6 @@ class MainWindow(QMainWindow, Ui_MainWindow):
# new_list = []
for new_element in elements:
start_time_map[new_element.st_time_sec] = ""
# print(">>>>>>remove start")
for aside in self.projectContext.aside_list: # 使用切片复制整个列表
aside.print_self()
if aside.aside != None and aside.aside != "" and aside.st_time_sec not in start_time_map:
......@@ -1656,7 +1895,6 @@ class MainWindow(QMainWindow, Ui_MainWindow):
# new_list.append(aside)
# self.projectContext.aside_list=new_list
# self.refresh_tab_slot(True)
print(">>>>>>remove end")
for item in remove_list:
item.print_self()
idx = 0
......@@ -1839,3 +2077,23 @@ class MainWindow(QMainWindow, Ui_MainWindow):
self.projectContext.all_elements[int(all_idx)].speed = combo.currentText()
self.all_tableWidget.setItem(int(all_idx), constant.Content.SpeedColumnNumber, QTableWidgetItem(combo.currentText()))
self.do_generate_audio_by_aside_row(int(row))
def change_audio_speed_all(self):
    """Handle a speed-combo change in the combined subtitle/aside table.

    Locates the edited row from the combo box's physical position,
    releases the currently loaded preview audio, writes the new speed
    into the project element (and its mapped table row), updates the
    table cell, and regenerates the row's narration audio.
    """
    combo = self.sender()
    # Resolve the row from the combo box's position inside the table.
    idx = self.all_tableWidget.indexAt(combo.pos())
    row = idx.row()
    print("index:", row)
    # Release the audio_player's media so the wav file can be rewritten.
    self.audio_player.setMedia(QMediaContent())
    self.projectContext.all_elements[row].speed = combo.currentText()
    # Map the element to its row id in the combined table and update it too.
    # NOTE(review): if all_idx == row this assignment is redundant with the
    # one above — confirm whether aside_subtitle_2contentId can differ here.
    all_idx = self.projectContext.aside_subtitle_2contentId(
        self.projectContext.all_elements[row])
    self.projectContext.all_elements[int(all_idx)].speed = combo.currentText()
    self.all_tableWidget.setItem(int(all_idx), constant.Content.SpeedColumnNumber, QTableWidgetItem(combo.currentText()))
    self.do_generate_audio_by_aside_row_all(int(row))
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<!-- <?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>MainWindow</class>
<widget class="QMainWindow" name="MainWindow">
......@@ -709,4 +709,4 @@
</customwidgets>
<resources/>
<connections/>
</ui>
</ui> -->
......@@ -7,7 +7,33 @@
# WARNING! All changes made in this file will be lost!
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QMainWindow, QFileDialog, QTableWidget, QTableWidgetItem, QAbstractItemView, QProgressBar, QLabel, QApplication, QPushButton, QMenu, QWidget
from PyQt5.QtCore import QUrl, Qt, QTimer, QRect, pyqtSignal, QPersistentModelIndex
from PyQt5.QtMultimedia import *
from PyQt5.QtGui import QIcon, QPainter, QColor, QPen
class MyWidget(QWidget):
    """A thin horizontal red marker line overlaid on the video area.

    Two instances mark the upper and lower subtitle-OCR boundaries;
    up()/down() nudge the widget vertically and report its new y
    coordinate so callers can recompute the boundary ratio.
    """

    def paintEvent(self, event):
        """Draw a 2 px solid red horizontal line across the widget.

        Debug printing removed here: paintEvent fires on every repaint,
        so a print() in it floods stdout and slows the UI.
        """
        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing)
        painter.setPen(QPen(Qt.red, 2, Qt.SolidLine))
        # NOTE(review): 800 px line length is hard-coded; should track
        # self.width() if the window is resizable — confirm layout.
        painter.drawLine(0, 1, 800, 1)

    def up(self, mov_len):
        """Move the marker up by ``mov_len`` pixels; return the new y."""
        self.move(0, self.y() - mov_len)
        return self.y()

    def down(self, mov_len):
        """Move the marker down by ``mov_len`` pixels; return the new y."""
        self.move(0, self.y() + mov_len)
        return self.y()
class Ui_MainWindow(object):
def setupUi(self, MainWindow):
......@@ -32,8 +58,12 @@ class Ui_MainWindow(object):
self.verticalLayout_2 = QtWidgets.QVBoxLayout()
self.verticalLayout_2.setObjectName("verticalLayout_2")
self.wgt_video = myVideoWidget(self.centralwidget)
self.wgt_video.setMinimumSize(QtCore.QSize(410, 200))
self.wgt_video.setMaximumSize(QtCore.QSize(16777215, 16777215))
# self.wgt_video.setMinimumSize(QtCore.QSize(410, 200))
# self.wgt_video.setMaximumSize(QtCore.QSize(16777215, 16777215))
self.widget = MyWidget(self.centralwidget)
self.widget.setGeometry(0,400,800,3)
self.widget_bottom = MyWidget(self.centralwidget)
self.widget_bottom.setGeometry(0,430,800,3)
palette = QtGui.QPalette()
brush = QtGui.QBrush(QtGui.QColor(0, 0, 0))
brush.setStyle(QtCore.Qt.SolidPattern)
......@@ -225,7 +255,7 @@ class Ui_MainWindow(object):
self.zm_tableWidget.setRowCount(0)
self.zm_tableWidget.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
self.horizontalLayout_2.addWidget(self.zm_tableWidget)
self.tabWidget.addTab(self.zm_tab, "")
# self.tabWidget.addTab(self.zm_tab, "")
self.pb_tab = QtWidgets.QWidget()
self.pb_tab.setObjectName("pb_tab")
self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.pb_tab)
......@@ -236,7 +266,7 @@ class Ui_MainWindow(object):
self.pb_tableWidget.setRowCount(0)
self.pb_tableWidget.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
self.horizontalLayout_3.addWidget(self.pb_tableWidget)
self.tabWidget.addTab(self.pb_tab, "")
# self.tabWidget.addTab(self.pb_tab, "")
self.shuiping.addWidget(self.tabWidget)
self.shuiping.setStretch(0, 3)
self.shuiping.setStretch(1, 5)
......@@ -327,8 +357,14 @@ class Ui_MainWindow(object):
self.menu.setObjectName("menu")
self.menu_2 = QtWidgets.QMenu(self.menubar)
self.menu_2.setObjectName("menu_2")
self.menu_3 = QtWidgets.QMenu(self.menubar)
self.menu_3.setObjectName("menu_3")
# self.menu_3 = QtWidgets.QMenu(self.menubar)
# self.menu_3.setObjectName("menu_3")
self.menu_4 = QtWidgets.QMenu(self.menubar)
self.menu_4.setObjectName("menu_4")
self.menu_5 = QtWidgets.QMenu(self.menubar)
self.menu_5.setObjectName("menu_5")
self.menu_6 = QtWidgets.QMenu(self.menubar)
self.menu_6.setObjectName("menu_6")
MainWindow.setMenuBar(self.menubar)
self.statusbar = QtWidgets.QStatusBar(MainWindow)
self.statusbar.setObjectName("statusbar")
......@@ -355,12 +391,26 @@ class Ui_MainWindow(object):
self.action_redo = QtWidgets.QAction(MainWindow)
# self.action_redo.setFont(font)
self.action_redo.setObjectName("action_redo")
self.action_3 = QtWidgets.QAction(MainWindow)
self.action_3.setObjectName("action_3")
self.action_4 = QtWidgets.QAction(MainWindow)
self.action_4.setObjectName("action_4")
self.action_5 = QtWidgets.QAction(MainWindow)
self.action_5.setObjectName("action_5")
self.action_3 = QtWidgets.QAction("旁白区间检测",self,triggered=self.show_detect_dialog)
self.action_3.setEnabled(False)
self.action_4 = QtWidgets.QAction("旁白音频合成",self,triggered=self.show_assemble_dialog)
self.action_4.setEnabled(False)
self.action_5 = QtWidgets.QAction("旁白导入",self,triggered=self.import_excel)
self.action_5.setEnabled(False)
self.action_6 = QtWidgets.QAction("字幕上边界++",self,triggered=self.up_ocr)
self.action_6.setEnabled(True)
self.action_7 = QtWidgets.QAction("字幕上边界--",self,triggered=self.down_ocr)
self.action_7.setEnabled(True)
self.action_8 = QtWidgets.QAction("字幕下边界++",self,triggered=self.up_ocr_bottom)
self.action_8.setEnabled(True)
self.action_9 = QtWidgets.QAction("字幕下边界--",self,triggered=self.down_ocr_bottom)
self.action_9.setEnabled(True)
# self.action_3.setObjectName("action_3")
# self.action_4 = QtWidgets.QAction(MainWindow)
# self.action_4.setObjectName("action_4")
# self.action_5 = QtWidgets.QAction(MainWindow)
# self.action_5.setObjectName("action_5")
self.action_operate = QtWidgets.QAction(MainWindow)
self.action_operate.setObjectName("action_operate")
self.action_export = QtWidgets.QAction(MainWindow)
......@@ -384,13 +434,22 @@ class Ui_MainWindow(object):
self.menu_2.addSeparator()
self.menu_2.addAction(self.action_insert_aside_from_now)
self.menu_2.addAction(self.action_operate)
self.menu_3.addAction(self.action_3)
self.menu_3.addAction(self.action_4)
self.menu_3.addAction(self.action_5)
self.menu_3.addSeparator()
# self.menu_3.addAction(self.action_3)
# self.menu_3.addAction(self.action_4)
# self.menu_3.addAction(self.action_5)
# self.menu_3.addSeparator()
self.menubar.addAction(self.menu.menuAction())
self.menubar.addAction(self.menu_2.menuAction())
self.menubar.addAction(self.menu_3.menuAction())
self.menubar.addAction(self.action_3)
self.menubar.addAction(self.action_4)
self.menubar.addAction(self.action_5)
self.menubar.addAction(self.action_6)
self.menubar.addAction(self.action_7)
self.menubar.addAction(self.action_8)
self.menubar.addAction(self.action_9)
# self.menubar.addAction(self.menu_5.menuAction())
# self.menubar.addAction(self.menu_6.menuAction())
# self.menubar.addAction(self.menu_3.menuAction())
self.retranslateUi(MainWindow)
self.tabWidget.setCurrentIndex(0)
......@@ -410,7 +469,10 @@ class Ui_MainWindow(object):
self.pb_label.setText(_translate("MainWindow", "刻度"))
self.menu.setTitle(_translate("MainWindow", "文件"))
self.menu_2.setTitle(_translate("MainWindow", "编辑"))
self.menu_3.setTitle(_translate("MainWindow", "功能按键"))
# self.menu_3.setTitle(_translate("MainWindow", "功能按键"))
self.menu_4.setTitle(_translate("MainWindow", "旁白区间检测"))
self.menu_5.setTitle(_translate("MainWindow", "旁白音频合成"))
self.menu_6.setTitle(_translate("MainWindow", "旁白导入"))
self.setting.setText(_translate("MainWindow", "设置"))
self.action_open_project.setText(_translate("MainWindow", "打开"))
self.import_movie.setText(_translate("MainWindow", "视频导入"))
......@@ -418,15 +480,14 @@ class Ui_MainWindow(object):
self.action_save.setText(_translate("MainWindow", "保存并备份"))
self.action_undo.setText(_translate("MainWindow", "撤销"))
self.action_redo.setText(_translate("MainWindow", "重做"))
self.action_3.setText(_translate("MainWindow", "旁白区间检测"))
self.action_4.setText(_translate("MainWindow", "旁白音频合成"))
self.action_5.setText(_translate("MainWindow", "旁白导入"))
# self.action_3.setText(_translate("MainWindow", "旁白区间检测"))
# self.action_4.setText(_translate("MainWindow", "旁白音频合成"))
# self.action_5.setText(_translate("MainWindow", "旁白导入"))
self.action_operate.setText(_translate("MainWindow", "操作表格"))
self.action_export.setText(_translate("MainWindow", "导出"))
self.action_insert_aside_from_now.setText(_translate("MainWindow", "当前位置插入旁白"))
self.action_create.setText(_translate("MainWindow", "新建"))
from myVideoWidget import myVideoWidget
from myvideoslider import myVideoSlider
from mywidgetcontents import myWidgetContents
......@@ -96,7 +96,8 @@ class Element:
def to_list(self):
    """Return every field of this element as a flat list (full-row order)."""
    fields = (self.st_time_sec, self.ed_time_sec, self.subtitle,
              self.suggest, self.aside, self.speed)
    return list(fields)
def to_short_list(self):
    """Return this element as a full table row (same fields as ``to_list``).

    NOTE(review): the merge left an old, uncommented 4-field ``return``
    ahead of the intended 6-field one, making the latter unreachable;
    the intended 6-field version is kept here.
    """
    return [self.st_time_sec, self.ed_time_sec, self.subtitle,
            self.suggest, self.aside, self.speed]
def to_aside_list(self):
    """Return this element as a row for the aside (narration) table.

    Fields: start time, suggestion (presumably the suggested character
    count — TODO confirm against the table headers), narration text, speed.
    """
    # Removed the commented-out 5-field variant that previously lingered here.
    return [self.st_time_sec, self.suggest, self.aside, self.speed]
......@@ -119,6 +120,7 @@ class ProjectContext:
self.subtitle_list = []
self.aside_list = []
self.all_elements = []
self.speaker_type = None
self.speaker_info = None
self.speaker_speed = None
self.duration = 0
......@@ -128,7 +130,8 @@ class ProjectContext:
self.aside_header = ['起始时间', '推荐字数', '解说脚本',"语速", "预览音频"]
self.subtitle_header = ["起始时间", "终止时间", "字幕"]
self.contentHeader = ["起始时间", "字幕", "解说脚本", "语速"]
# self.contentHeader = ["起始时间", "字幕", "解说脚本", "语速"]
self.contentHeader = ["起始时间", "结束时间", "字幕", "推荐字数", "解说脚本", "语速", "预览音频"]
self.excel_sheet_name = "旁白插入位置建议"
self.history_records = []
self.records_pos = 0
......@@ -174,6 +177,7 @@ class ProjectContext:
if not os.path.exists(self.conf_path):
print("conf file does not exist, 找管理员要")
return
print(self.conf_path)
with open(self.conf_path, 'r', encoding='utf8') as f:
info = json.load(f)
# print(json.dumps(info, ensure_ascii=False, indent=4))
......@@ -181,17 +185,20 @@ class ProjectContext:
self.excel_path = info["excel_path"]
self.speaker_info = info["speaker_info"]["speaker_id"]
self.speaker_speed = info["speaker_info"]["speaker_speed"]
self.speaker_type = info["speaker_info"]["speaker_type"] if "speaker_type" in info["speaker_info"] else "科大讯飞"
self.detected = info["detection_info"]["detected"]
self.nd_process = info["detection_info"]["nd_process"]
self.last_time = info["detection_info"]["last_time"]
self.caption_boundings = info["detection_info"]["caption_boundings"]
self.has_subtitle = info["detection_info"]["has_subtitle"]
# 当前工程下没有配置文件,就初始化一份
# 当前工程下没有配置文件,就初始化一份``
if self.conf_path != this_conf_path:
self.conf_path = this_conf_path
print("11111sava")
self.save_conf()
def save_conf(self):
print(self.speaker_speed)
with open(self.conf_path, 'w', encoding='utf-8') as f:
# if len(self.caption_boundings) > 0:
# print(type(self.caption_boundings[0]))
......@@ -207,6 +214,7 @@ class ProjectContext:
"has_subtitle": self.has_subtitle
},
"speaker_info": {
"speaker_type": self.speaker_type,
"speaker_id": self.speaker_info,
"speaker_speed": self.speaker_speed
}
......@@ -223,6 +231,7 @@ class ProjectContext:
# 先备份文件,再覆盖主文件,可选是否需要备份,默认需要备份
# 20221030:添加旁白检测的进度
def save_project(self, need_save_new: bool=False) -> str:
print("22222sava")
self.save_conf()
# all_element = sorted(all_element, key=lambda x: float(x.st_time_sec))
print("current excel_path:", self.excel_path)
......@@ -254,6 +263,11 @@ class ProjectContext:
if not self.initial_ing:
save_excel_to_path(self.all_elements, self.excel_path, self.write_header, self.excel_sheet_name)
def refresh_speed(self, row, speed: str) -> None:
    """Update one element's speech speed and persist the table to Excel.

    Args:
        row: row index into ``all_elements`` (string or int).
        speed: new speed label for that element.
    """
    idx = int(row)
    self.all_elements[idx].speed = speed
    # While the project is still being initialised, skip persisting.
    if self.initial_ing:
        return
    save_excel_to_path(self.all_elements, self.excel_path,
                       self.write_header, self.excel_sheet_name)
# Load a whole project and populate this ProjectContext with it.
def load_project(self):
    """Placeholder: load an existing project into this context.

    Currently a no-op stub — loading is handled elsewhere for now.
    """
    pass
......@@ -344,6 +358,22 @@ class ProjectContext:
self.speaker_info = speaker_name[0]
return tuple(speaker_name)
def get_all_speaker_zju_info(self):
    """Return display strings for all local (ZJU) TTS speakers.

    Each entry is "name,gender,age_group", shown in the settings dialog so
    the user can inspect and pick a speaker.  Also initialises
    ``speaker_info`` with the first speaker when none has been chosen yet.

    Returns:
        tuple[str, ...]: one formatted entry per speaker in
        ``speaker_zju_details``.
    """
    # Fix: the original opened the config file without ever closing it;
    # a context manager guarantees the handle is released.
    with open(constant.Pathes.speaker_conf_path, encoding="utf-8") as f:
        content = json.load(f)
    speaker_name = [
        ",".join([sp["name"], sp["gender"], sp["age_group"]])
        for sp in content["speaker_zju_details"]
    ]
    if self.speaker_info is None:
        self.speaker_info = speaker_name[0]
    return tuple(speaker_name)
def init_speakers(self):
"""初始化说话人信息
......@@ -354,6 +384,8 @@ class ProjectContext:
content = json.load(f)
for speaker_info in content["speaker_details"]:
self.speakers.append(Speaker(speaker_info))
for speaker_info in content["speaker_zju_details"]:
self.speakers.append(Speaker(speaker_info))
def choose_speaker(self, speaker_name: str) -> Speaker:
"""选择说话人
......
from PyQt5.QtMultimediaWidgets import QVideoWidget
from PyQt5.QtCore import *
from PyQt5.QtMultimedia import QMediaPlayer
class myVideoWidget(QVideoWidget):
......@@ -7,6 +8,8 @@ class myVideoWidget(QVideoWidget):
def __init__(self, parent=None):
    """Video surface that stretches to fill its widget.

    IgnoreAspectRatio makes the rendered frame follow the widget geometry
    instead of letterboxing.
    """
    super(QVideoWidget, self).__init__(parent)
    self.setAspectRatioMode(Qt.IgnoreAspectRatio)
def mouseDoubleClickEvent(self, QMouseEvent): #双击事件
......
{"video_path": null, "excel_path": null, "detection_info": {"detected": false, "nd_process": 0.0, "last_time": 0.0, "caption_boundings": [], "has_subtitle": true}, "speaker_info": {"speaker_id": "\u6653\u6653\uff0c\u5973\uff0c\u5e74\u8f7b\u4eba", "speaker_speed": "1.10(4.5\u5b57/\u79d2)"}}
\ No newline at end of file
{"video_path": null, "excel_path": null, "detection_info": {"detected": false, "nd_process": 0.0, "last_time": 0.0, "caption_boundings": [], "has_subtitle": true}, "speaker_info": {"speaker_type": "\u6d59\u5927\u5185\u90e8tts", "speaker_id": "test\uff0c\u5973\uff0c\u5e74\u8f7b\u4eba", "speaker_speed": "1.00(4\u5b57/\u79d2)"}}
\ No newline at end of file
......@@ -139,5 +139,16 @@
"audio_path": "./res/speaker_audio/Yunye.wav",
"speaker_code": "zh-CN-YunyeNeural"
}
]
],
"speaker_zju_details": [{
"id": 0,
"name": "test",
"language": "中文(普通话,简体)",
"age_group": "年轻人",
"gender": "女",
"description": "休闲、放松的语音,用于自发性对话和会议听录。",
"audio_path": "./res/speaker_zju_audio/local_tts_example.wav",
"speaker_code": "",
"speaker_type":"1"
}]
}
\ No newline at end of file
......@@ -8,6 +8,7 @@ from setting_dialog_ui import Ui_Dialog
from utils import validate_and_get_filepath, replace_path_suffix
import winsound
import constant
audioPlayed = winsound.PlaySound(None, winsound.SND_NODEFAULT)
......@@ -19,41 +20,98 @@ class Setting_Dialog(QDialog, Ui_Dialog):
self.setupUi(self)
self.setWindowTitle("设置")
self.projectContext = projectContext
self.refresh(self.projectContext)
self.refresh_flag = False
self.clear_flag = False
self.comboBox_0.currentIndexChanged.connect(self.choose)
self.comboBox.currentIndexChanged.connect(self.speaker_change_slot)
self.comboBox_2.currentIndexChanged.connect(self.speed_change_slot)
self.pushButton.clicked.connect(self.play_audio_slot)
def refresh(self,projectContext):
try:
self.refresh_flag = True
self.clear_flag = True
self.comboBox_0.clear()
self.comboBox.clear()
self.comboBox_2.clear()
# todo 把所有说话人都加上来
self.speaker_li = self.projectContext.get_all_speaker_info()
for i in self.speaker_li:
self.comboBox.addItem(i)
self.speaker_li = projectContext.get_all_speaker_info()
self.speaker_zju_li = projectContext.get_all_speaker_zju_info() #本地tts
self.speed_list_zju = ["1.00(4字/秒)", "1.10(4.5字/秒)", "1.25(5字/秒)", "1.50(6字/秒)", "1.75(7字/秒)", "2.00(8字/秒)", "2.50(10字/秒)"] #本地tts
# for i in self.speaker_li:
# self.comboBox.addItem(i)
self.speed_li_2 = ["1.00(4字/秒)", "1.10(4.5字/秒)", "1.25(5字/秒)", "1.50(6字/秒)", "1.75(7字/秒)", "2.00(8字/秒)", "2.50(10字/秒)"]
# self.comboBox_2.addItems(self.speed_li_2)
self.speaker_types = ["科大讯飞", "浙大内部tts"]
self.comboBox_0.addItems(self.speaker_types)
print(projectContext.speaker_type)
if projectContext.speaker_type is None or projectContext.speaker_type == "":
self.comboBox_0.setCurrentIndex(0)
else:
self.comboBox_0.setCurrentIndex(self.speaker_types.index(projectContext.speaker_type))
if self.comboBox_0.currentIndex() ==0: #讯飞
self.comboBox.addItems(self.speaker_li)
self.comboBox_2.addItems(self.speed_li_2)
else:
# local
self.comboBox.addItems(self.speaker_zju_li)
self.comboBox_2.addItems(self.speed_list_zju)
self.clear_flag = False
if self.projectContext.speaker_info is None:
if projectContext.speaker_info is None or projectContext.speaker_info == "":
self.comboBox.setCurrentIndex(0)
else:
self.comboBox.setCurrentIndex(self.speaker_li.index(self.projectContext.speaker_info))
if self.projectContext.speaker_speed is None:
print(projectContext.speaker_info)
self.comboBox.setCurrentIndex(self.speaker_li.index(projectContext.speaker_info) if self.comboBox_0.currentIndex() ==0 else self.speaker_zju_li.index(projectContext.speaker_info))
print(projectContext.speaker_speed)
if projectContext.speaker_speed is None or projectContext.speaker_speed == "":
self.comboBox_2.setCurrentIndex(0)
else:
self.comboBox_2.setCurrentIndex(self.speed_li_2.index(self.projectContext.speaker_speed))
self.comboBox.currentIndexChanged.connect(self.speaker_change_slot)
self.comboBox_2.currentIndexChanged.connect(self.speed_change_slot)
self.pushButton.clicked.connect(self.play_audio_slot)
self.comboBox_2.setCurrentIndex(self.speed_li_2.index(projectContext.speaker_speed) if self.comboBox_0.currentIndex() ==0 else self.speed_list_zju.index(projectContext.speaker_speed))
finally:
self.refresh_flag = False
def choose(self):
    """Repopulate the speaker/speed combo boxes after a TTS-engine switch.

    Slot for ``comboBox_0.currentIndexChanged``: index 0 selects the
    iFlytek lists, anything else the local (ZJU) ones.  Skipped while
    ``refresh()`` is rebuilding the widgets.
    """
    if self.refresh_flag:
        return
    engine_idx = self.comboBox_0.currentIndex()
    print(engine_idx)
    self.comboBox.clear()
    self.comboBox_2.clear()
    self.projectContext.speaker_type = self.comboBox_0.currentText()
    if engine_idx == 0:
        print("讯飞")
        speakers, speeds = self.speaker_li, self.speed_li_2
    else:
        print("local")
        speakers, speeds = self.speaker_zju_li, self.speed_list_zju
    self.comboBox.addItems(speakers)
    self.comboBox_2.addItems(speeds)
def content_fresh(self):
"""刷新界面中的内容
将工程信息中的说话人信息、说话人语速更新到界面中,如果未选择则初始化为第一个选项
"""
if self.projectContext.speaker_info is None:
print(self.projectContext.speaker_info)
if self.projectContext.speaker_info is None or self.projectContext.speaker_info == "" :
self.comboBox.setCurrentIndex(0)
else:
self.comboBox.setCurrentIndex(self.speaker_li.index(self.projectContext.speaker_info))
if self.projectContext.speaker_speed is None:
self.comboBox.setCurrentIndex(self.speaker_li.index(self.projectContext.speaker_info) if self.comboBox_0.currentIndex() ==0 else self.speaker_zju_li.index(self.projectContext.speaker_info))
if self.projectContext.speaker_speed is None or self.projectContext.speaker_speed == "":
self.comboBox_2.setCurrentIndex(0)
else:
self.comboBox_2.setCurrentIndex(self.speed_li_2.index(self.projectContext.speaker_speed))
self.comboBox_2.setCurrentIndex(self.speed_li_2.index(self.projectContext.speaker_speed) if self.comboBox_0.currentIndex() ==0 else self.speed_list_zju.index(self.projectContext.speaker_speed))
def speaker_change_slot(self):
"""切换说话人
......@@ -61,6 +119,8 @@ class Setting_Dialog(QDialog, Ui_Dialog):
将当前的说话人设置为工程的说话人,并保存到配置文件中
"""
if self.clear_flag:
return
self.projectContext.speaker_info = self.comboBox.currentText()
self.projectContext.save_conf()
# print("self.projectContext.speaker_info:", self.projectContext.speaker_info)
......@@ -71,6 +131,8 @@ class Setting_Dialog(QDialog, Ui_Dialog):
将当前的语速设置为工程的语速,并保存到配置文件中
"""
if self.clear_flag:
return
self.projectContext.speaker_speed = self.comboBox_2.currentText()
self.projectContext.save_conf()
......
......@@ -19,20 +19,32 @@ class Ui_Dialog(object):
self.gridLayout_2.setObjectName("gridLayout_2")
self.gridLayout = QtWidgets.QGridLayout()
self.gridLayout.setObjectName("gridLayout")
self.label_2 = QtWidgets.QLabel(Dialog)
self.label_2.setObjectName("label_2")
self.gridLayout.addWidget(self.label_2, 0, 0, 1, 1)
self.comboBox_0 = QtWidgets.QComboBox(Dialog)
self.comboBox_0.setCurrentText("")
self.comboBox_0.setObjectName("comboBox_0")
self.gridLayout.addWidget(self.comboBox_0, 0, 1, 1, 1)
self.label_3 = QtWidgets.QLabel(Dialog)
self.label_3.setObjectName("label_3")
self.gridLayout.addWidget(self.label_3, 0, 0, 1, 1)
self.gridLayout.addWidget(self.label_3, 1, 0, 1, 1)
self.comboBox = QtWidgets.QComboBox(Dialog)
self.comboBox.setCurrentText("")
self.comboBox.setObjectName("comboBox")
self.gridLayout.addWidget(self.comboBox, 0, 1, 1, 1)
self.gridLayout.addWidget(self.comboBox, 1, 1, 1, 1)
self.label_4 = QtWidgets.QLabel(Dialog)
self.label_4.setObjectName("label_4")
self.gridLayout.addWidget(self.label_4, 1, 0, 1, 1)
self.gridLayout.addWidget(self.label_4, 2, 0, 1, 1)
self.comboBox_2 = QtWidgets.QComboBox(Dialog)
self.comboBox_2.setCurrentText("")
self.comboBox_2.setObjectName("comboBox_2")
self.gridLayout.addWidget(self.comboBox_2, 1, 1, 1, 1)
self.gridLayout.addWidget(self.comboBox_2, 2, 1, 1, 1)
self.gridLayout.setRowMinimumHeight(0, 60)
self.gridLayout.setRowMinimumHeight(1, 60)
self.gridLayout.setColumnStretch(1, 1)
......@@ -50,6 +62,8 @@ class Ui_Dialog(object):
def retranslateUi(self, Dialog):
_translate = QtCore.QCoreApplication.translate
Dialog.setWindowTitle(_translate("Dialog", "Dialog"))
self.label_2.setText(_translate("Dialog", "TTS引擎"))
self.label_3.setText(_translate("Dialog", "旁白说话人:"))
self.label_3.setText(_translate("Dialog", "旁白说话人:"))
self.label_4.setText(_translate("Dialog", "旁白语速:"))
self.pushButton.setText(_translate("Dialog", "播放样例音频"))
......@@ -27,6 +27,7 @@ from azure.cognitiveservices.speech import SpeechConfig, SpeechSynthesizer, Resu
from azure.cognitiveservices.speech.audio import AudioOutputConfig
import openpyxl
import shutil
from vits_chinese import tts
tmp_file = 'tmp.wav'
adjusted_wav_path = "adjusted.wav"
......@@ -53,6 +54,8 @@ class Speaker:
self.speaker_code = speaker_info["speaker_code"]
self.description = speaker_info["description"]
self.voice_example = speaker_info["audio_path"]
self.speaker_type = speaker_info["speaker_type"] if "speaker_type" in speaker_info else None #speakers.json里面新加字段speaker_type =1 表示用local tts
def init_speakers():
......@@ -65,6 +68,8 @@ def init_speakers():
global speakers
for speaker_info in content["speaker_details"]:
speakers.append(Speaker(speaker_info))
for speaker_info in content["speaker_zju_details"]:
speakers.append(Speaker(speaker_info))
def choose_speaker(speaker_name: str) -> Speaker:
......@@ -94,6 +99,12 @@ def speech_synthesis(text: str, output_file: str, speaker: Speaker, speed: float
speed (float, optional): 指定的音频语速. Defaults to 1.0.
"""
if not os.path.exists(os.path.dirname(output_file)): # 如果路径不存在
print("output_file路径不存在,创建:", os.path.dirname(output_file))
os.makedirs(os.path.dirname(output_file))
if speaker.speaker_type != None and speaker.speaker_type == "1":
tts(text, speed, output_file)
else:
speech_config = SpeechConfig(
subscription="db34d38d2d3447d482e0f977c66bd624",
region="eastus"
......@@ -103,9 +114,7 @@ def speech_synthesis(text: str, output_file: str, speaker: Speaker, speed: float
speech_config.speech_synthesis_voice_name = speaker.speaker_code
# 先把合成的语音文件输出得到tmp.wav中,便于可能的调速需求
if not os.path.exists(os.path.dirname(output_file)): # 如果路径不存在
print("output_file路径不存在,创建:", os.path.dirname(output_file))
os.makedirs(os.path.dirname(output_file))
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
ssml_string = f"""
<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{speech_config.speech_synthesis_language}">
......
......@@ -80,7 +80,8 @@ if __name__ == '__main__':
mainWindow.setWindowTitle(f"无障碍电影制作软件(当前工程为:{project_name})")
mainWindow.renew_signal.connect(change_project_path)
apply_stylesheet(app, theme='dark_amber.xml')
mainWindow.show()
# mainWindow.show()
mainWindow.showMaximized()
currentExitCode = app.exec_()
app = None
except Exception as e:
......
### 安装环境
```
pip install -r requirements.txt
```
### 接口
infer.py
\ No newline at end of file
import sys
import os
sys.path.append(os.path.dirname(__file__))
from .infer import tts
from .utils import get_hparams_from_file
\ No newline at end of file
import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import commons
import modules
from modules import LayerNorm
class Encoder(nn.Module):
    """Transformer encoder stack with relative-position self-attention.

    Each of the ``n_layers`` blocks applies masked self-attention followed
    by a convolutional FFN, each wrapped in dropout + residual + LayerNorm
    (post-norm).  ``window_size`` bounds the learned relative-position
    embeddings used by the attention layers.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        **kwargs
    ):
        # kwargs is accepted but ignored — presumably for config
        # compatibility with callers; TODO confirm.
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        """Encode ``x`` under the validity mask ``x_mask``.

        Args:
            x: input tensor — presumably (batch, hidden_channels, time),
               matching the Conv1d-based FFN layout; TODO confirm.
            x_mask: broadcastable 0/1 mask, presumably (batch, 1, time).

        Returns:
            Tensor of the same shape as ``x``, zeroed outside the mask.
        """
        # Pairwise attention mask: outer product of the position mask
        # with itself.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            # Self-attention sub-block with residual + post-norm.
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            # Feed-forward sub-block with residual + post-norm.
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
class Decoder(nn.Module):
    """Transformer decoder stack: causal self-attention + encoder-decoder attention.

    Each of the ``n_layers`` blocks runs (1) self-attention under a
    subsequent (look-ahead) mask, (2) attention over the encoder output,
    and (3) a causal convolutional FFN — each followed by dropout +
    residual + LayerNorm.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        # kwargs is accepted but ignored — presumably for config
        # compatibility with callers; TODO confirm.
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,  # decoder FFN must not look ahead in time
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        x_mask / h_mask: 0/1 validity masks for x and h respectively
        Returns a tensor shaped like ``x``, zeroed outside ``x_mask``.
        """
        # Causal mask: position t may attend only to positions <= t.
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        # Cross-attention mask pairs decoder positions with valid encoder ones.
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            # Masked self-attention sub-block.
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            # Encoder-decoder attention sub-block.
            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            # Causal feed-forward sub-block.
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional extras.

    Optional behaviours (all off by default):
      * ``window_size``: learned relative-position embeddings clipped to
        +/- ``window_size`` (self-attention only — asserted at runtime);
      * ``proximal_bias``: adds a ``-log(1+|i-j|)`` bias favouring nearby
        positions (self-attention only);
      * ``block_length``: restricts attention to a local band of that width
        (applied only when an explicit mask is also given);
      * ``proximal_init``: copies the query projection's weights/bias into
        the key projection at construction time.

    The attention weights of the last forward pass are kept in ``self.attn``.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None  # last forward pass's attention weights

        self.k_channels = channels // n_heads
        # 1x1 convolutions act as per-position linear projections.
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            # Relative-position embeddings span offsets [-window_size, window_size];
            # shared across heads unless heads_share is False.
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            # Start keys identical to queries.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` (queries) to ``c`` (keys/values).

        Both are presumably (batch, channels, time) — the 1x1 Conv1d
        projections imply channel-first layout.  Stores the attention map
        in ``self.attn`` as a side effect.
        """
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)
        x, self.attn = self.attention(q, k, v, mask=attn_mask)
        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention core; returns (output, weights)."""
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            # Add content-to-position logits from the relative-key embeddings.
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            # -1e4 (not -inf) keeps fp16-safe masking.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            # Add position-dependent values from the relative-value embeddings.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Pad/slice the stored embeddings to the 2*length-1 offsets needed."""
        max_relative_position = 2 * self.window_size + 1  # NOTE(review): unused local
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # padd along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig, BertTokenizer
class CharEmbedding(nn.Module):
    """BERT-based character encoder producing 256-dim embeddings per token."""

    def __init__(self, model_dir):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        self.bert_config = BertConfig.from_pretrained(model_dir)
        self.hidden_size = self.bert_config.hidden_size
        self.bert = BertModel(self.bert_config)
        self.proj = nn.Linear(self.hidden_size, 256)
        # Extra head kept for checkpoint compatibility; not used in forward().
        self.linear = nn.Linear(256, 3)

    def text2Token(self, text):
        """Tokenise ``text`` and map the tokens to vocabulary ids."""
        tokens = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(tokens)

    def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
        """Return the projected BERT sequence output (last dim 256)."""
        hidden = self.bert(input_ids=inputs_ids,
                           attention_mask=inputs_masks,
                           token_type_ids=tokens_type_ids)[0]
        return self.proj(hidden)
class TTSProsody(object):
    """Loads the BERT prosody model and produces character embeddings.

    Wraps ``CharEmbedding``, restoring its weights from
    ``<path>/prosody_model.pt`` and keeping the model in eval mode on
    ``device``.
    """

    def __init__(self, path, device):
        self.device = device
        self.char_model = CharEmbedding(path)
        state_dict = torch.load(
            os.path.join(path, 'prosody_model.pt'),
            map_location="cpu"
        )
        # strict=False tolerates key mismatches between checkpoint and model.
        self.char_model.load_state_dict(state_dict, strict=False)
        self.char_model.eval()
        self.char_model.to(self.device)

    def get_char_embeds(self, text):
        """Embed ``text`` token-by-token.

        Returns a CPU tensor of shape (num_tokens, 256), computed under
        ``torch.no_grad``.
        """
        token_ids = self.char_model.text2Token(text)
        count = len(token_ids)
        ids = torch.LongTensor([token_ids]).to(self.device)
        masks = torch.LongTensor([[1] * count]).to(self.device)
        types = torch.LongTensor([[0] * count]).to(self.device)
        with torch.no_grad():
            embeds = self.char_model(ids, masks, types).squeeze(0).cpu()
        return embeds

    def expand_for_phone(self, char_embeds, length):  # length of phones for char
        """Repeat each character vector once per phone.

        Args:
            char_embeds: tensor of shape (num_chars, dim).
            length: per-character phone counts; len must equal num_chars.

        Returns:
            numpy array of shape (sum(length), dim).
        """
        assert char_embeds.size(0) == len(length)
        rows = [vec.expand(count, -1) for vec, count in zip(char_embeds, length)]
        expanded = torch.cat(rows, 0)
        assert expanded.size(0) == sum(length)
        return expanded.numpy()
if __name__ == "__main__":
    # Manual smoke test: load the prosody model from ./bert/ and embed
    # console input in an endless loop (interrupt to quit).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prosody = TTSProsody('./bert/', device)
    while True:
        text = input("请输入文本:")
        prosody.get_char_embeds(text)
from .ProsodyModel import TTSProsody
\ No newline at end of file
{
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"type_vocab_size": 2,
"vocab_size": 21128
}
def is_chinese(uchar):
    """Return True if `uchar` is a CJK unified ideograph (U+4E00..U+9FA5).

    Note: this intentionally preserves the original BMP-only range; CJK
    extension blocks (e.g. U+3400..U+4DBF) are not matched.
    """
    # A chained comparison replaces the verbose if/else returning True/False.
    return u'\u4e00' <= uchar <= u'\u9fa5'
pinyin_dict = {
"a": ("^", "a"),
"ai": ("^", "ai"),
"an": ("^", "an"),
"ang": ("^", "ang"),
"ao": ("^", "ao"),
"ba": ("b", "a"),
"bai": ("b", "ai"),
"ban": ("b", "an"),
"bang": ("b", "ang"),
"bao": ("b", "ao"),
"be": ("b", "e"),
"bei": ("b", "ei"),
"ben": ("b", "en"),
"beng": ("b", "eng"),
"bi": ("b", "i"),
"bian": ("b", "ian"),
"biao": ("b", "iao"),
"bie": ("b", "ie"),
"bin": ("b", "in"),
"bing": ("b", "ing"),
"bo": ("b", "o"),
"bu": ("b", "u"),
"ca": ("c", "a"),
"cai": ("c", "ai"),
"can": ("c", "an"),
"cang": ("c", "ang"),
"cao": ("c", "ao"),
"ce": ("c", "e"),
"cen": ("c", "en"),
"ceng": ("c", "eng"),
"cha": ("ch", "a"),
"chai": ("ch", "ai"),
"chan": ("ch", "an"),
"chang": ("ch", "ang"),
"chao": ("ch", "ao"),
"che": ("ch", "e"),
"chen": ("ch", "en"),
"cheng": ("ch", "eng"),
"chi": ("ch", "iii"),
"chong": ("ch", "ong"),
"chou": ("ch", "ou"),
"chu": ("ch", "u"),
"chua": ("ch", "ua"),
"chuai": ("ch", "uai"),
"chuan": ("ch", "uan"),
"chuang": ("ch", "uang"),
"chui": ("ch", "uei"),
"chun": ("ch", "uen"),
"chuo": ("ch", "uo"),
"ci": ("c", "ii"),
"cong": ("c", "ong"),
"cou": ("c", "ou"),
"cu": ("c", "u"),
"cuan": ("c", "uan"),
"cui": ("c", "uei"),
"cun": ("c", "uen"),
"cuo": ("c", "uo"),
"da": ("d", "a"),
"dai": ("d", "ai"),
"dan": ("d", "an"),
"dang": ("d", "ang"),
"dao": ("d", "ao"),
"de": ("d", "e"),
"dei": ("d", "ei"),
"den": ("d", "en"),
"deng": ("d", "eng"),
"di": ("d", "i"),
"dia": ("d", "ia"),
"dian": ("d", "ian"),
"diao": ("d", "iao"),
"die": ("d", "ie"),
"ding": ("d", "ing"),
"diu": ("d", "iou"),
"dong": ("d", "ong"),
"dou": ("d", "ou"),
"du": ("d", "u"),
"duan": ("d", "uan"),
"dui": ("d", "uei"),
"dun": ("d", "uen"),
"duo": ("d", "uo"),
"e": ("^", "e"),
"ei": ("^", "ei"),
"en": ("^", "en"),
"ng": ("^", "en"),
"eng": ("^", "eng"),
"er": ("^", "er"),
"fa": ("f", "a"),
"fan": ("f", "an"),
"fang": ("f", "ang"),
"fei": ("f", "ei"),
"fen": ("f", "en"),
"feng": ("f", "eng"),
"fo": ("f", "o"),
"fou": ("f", "ou"),
"fu": ("f", "u"),
"ga": ("g", "a"),
"gai": ("g", "ai"),
"gan": ("g", "an"),
"gang": ("g", "ang"),
"gao": ("g", "ao"),
"ge": ("g", "e"),
"gei": ("g", "ei"),
"gen": ("g", "en"),
"geng": ("g", "eng"),
"gong": ("g", "ong"),
"gou": ("g", "ou"),
"gu": ("g", "u"),
"gua": ("g", "ua"),
"guai": ("g", "uai"),
"guan": ("g", "uan"),
"guang": ("g", "uang"),
"gui": ("g", "uei"),
"gun": ("g", "uen"),
"guo": ("g", "uo"),
"ha": ("h", "a"),
"hai": ("h", "ai"),
"han": ("h", "an"),
"hang": ("h", "ang"),
"hao": ("h", "ao"),
"he": ("h", "e"),
"hei": ("h", "ei"),
"hen": ("h", "en"),
"heng": ("h", "eng"),
"hong": ("h", "ong"),
"hou": ("h", "ou"),
"hu": ("h", "u"),
"hua": ("h", "ua"),
"huai": ("h", "uai"),
"huan": ("h", "uan"),
"huang": ("h", "uang"),
"hui": ("h", "uei"),
"hun": ("h", "uen"),
"huo": ("h", "uo"),
"ji": ("j", "i"),
"jia": ("j", "ia"),
"jian": ("j", "ian"),
"jiang": ("j", "iang"),
"jiao": ("j", "iao"),
"jie": ("j", "ie"),
"jin": ("j", "in"),
"jing": ("j", "ing"),
"jiong": ("j", "iong"),
"jiu": ("j", "iou"),
"ju": ("j", "v"),
"juan": ("j", "van"),
"jue": ("j", "ve"),
"jun": ("j", "vn"),
"ka": ("k", "a"),
"kai": ("k", "ai"),
"kan": ("k", "an"),
"kang": ("k", "ang"),
"kao": ("k", "ao"),
"ke": ("k", "e"),
"kei": ("k", "ei"),
"ken": ("k", "en"),
"keng": ("k", "eng"),
"kong": ("k", "ong"),
"kou": ("k", "ou"),
"ku": ("k", "u"),
"kua": ("k", "ua"),
"kuai": ("k", "uai"),
"kuan": ("k", "uan"),
"kuang": ("k", "uang"),
"kui": ("k", "uei"),
"kun": ("k", "uen"),
"kuo": ("k", "uo"),
"la": ("l", "a"),
"lai": ("l", "ai"),
"lan": ("l", "an"),
"lang": ("l", "ang"),
"lao": ("l", "ao"),
"le": ("l", "e"),
"lei": ("l", "ei"),
"leng": ("l", "eng"),
"li": ("l", "i"),
"lia": ("l", "ia"),
"lian": ("l", "ian"),
"liang": ("l", "iang"),
"liao": ("l", "iao"),
"lie": ("l", "ie"),
"lin": ("l", "in"),
"ling": ("l", "ing"),
"liu": ("l", "iou"),
"lo": ("l", "o"),
"long": ("l", "ong"),
"lou": ("l", "ou"),
"lu": ("l", "u"),
"lv": ("l", "v"),
"luan": ("l", "uan"),
"lve": ("l", "ve"),
"lue": ("l", "ve"),
"lun": ("l", "uen"),
"luo": ("l", "uo"),
"ma": ("m", "a"),
"mai": ("m", "ai"),
"man": ("m", "an"),
"mang": ("m", "ang"),
"mao": ("m", "ao"),
"me": ("m", "e"),
"mei": ("m", "ei"),
"men": ("m", "en"),
"meng": ("m", "eng"),
"mi": ("m", "i"),
"mian": ("m", "ian"),
"miao": ("m", "iao"),
"mie": ("m", "ie"),
"min": ("m", "in"),
"ming": ("m", "ing"),
"miu": ("m", "iou"),
"mo": ("m", "o"),
"mou": ("m", "ou"),
"mu": ("m", "u"),
"na": ("n", "a"),
"nai": ("n", "ai"),
"nan": ("n", "an"),
"nang": ("n", "ang"),
"nao": ("n", "ao"),
"ne": ("n", "e"),
"nei": ("n", "ei"),
"nen": ("n", "en"),
"neng": ("n", "eng"),
"ni": ("n", "i"),
"nia": ("n", "ia"),
"nian": ("n", "ian"),
"niang": ("n", "iang"),
"niao": ("n", "iao"),
"nie": ("n", "ie"),
"nin": ("n", "in"),
"ning": ("n", "ing"),
"niu": ("n", "iou"),
"nong": ("n", "ong"),
"nou": ("n", "ou"),
"nu": ("n", "u"),
"nv": ("n", "v"),
"nuan": ("n", "uan"),
"nve": ("n", "ve"),
"nue": ("n", "ve"),
"nuo": ("n", "uo"),
"o": ("^", "o"),
"ou": ("^", "ou"),
"pa": ("p", "a"),
"pai": ("p", "ai"),
"pan": ("p", "an"),
"pang": ("p", "ang"),
"pao": ("p", "ao"),
"pe": ("p", "e"),
"pei": ("p", "ei"),
"pen": ("p", "en"),
"peng": ("p", "eng"),
"pi": ("p", "i"),
"pian": ("p", "ian"),
"piao": ("p", "iao"),
"pie": ("p", "ie"),
"pin": ("p", "in"),
"ping": ("p", "ing"),
"po": ("p", "o"),
"pou": ("p", "ou"),
"pu": ("p", "u"),
"qi": ("q", "i"),
"qia": ("q", "ia"),
"qian": ("q", "ian"),
"qiang": ("q", "iang"),
"qiao": ("q", "iao"),
"qie": ("q", "ie"),
"qin": ("q", "in"),
"qing": ("q", "ing"),
"qiong": ("q", "iong"),
"qiu": ("q", "iou"),
"qu": ("q", "v"),
"quan": ("q", "van"),
"que": ("q", "ve"),
"qun": ("q", "vn"),
"ran": ("r", "an"),
"rang": ("r", "ang"),
"rao": ("r", "ao"),
"re": ("r", "e"),
"ren": ("r", "en"),
"reng": ("r", "eng"),
"ri": ("r", "iii"),
"rong": ("r", "ong"),
"rou": ("r", "ou"),
"ru": ("r", "u"),
"rua": ("r", "ua"),
"ruan": ("r", "uan"),
"rui": ("r", "uei"),
"run": ("r", "uen"),
"ruo": ("r", "uo"),
"sa": ("s", "a"),
"sai": ("s", "ai"),
"san": ("s", "an"),
"sang": ("s", "ang"),
"sao": ("s", "ao"),
"se": ("s", "e"),
"sen": ("s", "en"),
"seng": ("s", "eng"),
"sha": ("sh", "a"),
"shai": ("sh", "ai"),
"shan": ("sh", "an"),
"shang": ("sh", "ang"),
"shao": ("sh", "ao"),
"she": ("sh", "e"),
"shei": ("sh", "ei"),
"shen": ("sh", "en"),
"sheng": ("sh", "eng"),
"shi": ("sh", "iii"),
"shou": ("sh", "ou"),
"shu": ("sh", "u"),
"shua": ("sh", "ua"),
"shuai": ("sh", "uai"),
"shuan": ("sh", "uan"),
"shuang": ("sh", "uang"),
"shui": ("sh", "uei"),
"shun": ("sh", "uen"),
"shuo": ("sh", "uo"),
"si": ("s", "ii"),
"song": ("s", "ong"),
"sou": ("s", "ou"),
"su": ("s", "u"),
"suan": ("s", "uan"),
"sui": ("s", "uei"),
"sun": ("s", "uen"),
"suo": ("s", "uo"),
"ta": ("t", "a"),
"tai": ("t", "ai"),
"tan": ("t", "an"),
"tang": ("t", "ang"),
"tao": ("t", "ao"),
"te": ("t", "e"),
"tei": ("t", "ei"),
"teng": ("t", "eng"),
"ti": ("t", "i"),
"tian": ("t", "ian"),
"tiao": ("t", "iao"),
"tie": ("t", "ie"),
"ting": ("t", "ing"),
"tong": ("t", "ong"),
"tou": ("t", "ou"),
"tu": ("t", "u"),
"tuan": ("t", "uan"),
"tui": ("t", "uei"),
"tun": ("t", "uen"),
"tuo": ("t", "uo"),
"wa": ("^", "ua"),
"wai": ("^", "uai"),
"wan": ("^", "uan"),
"wang": ("^", "uang"),
"wei": ("^", "uei"),
"wen": ("^", "uen"),
"weng": ("^", "ueng"),
"wo": ("^", "uo"),
"wu": ("^", "u"),
"xi": ("x", "i"),
"xia": ("x", "ia"),
"xian": ("x", "ian"),
"xiang": ("x", "iang"),
"xiao": ("x", "iao"),
"xie": ("x", "ie"),
"xin": ("x", "in"),
"xing": ("x", "ing"),
"xiong": ("x", "iong"),
"xiu": ("x", "iou"),
"xu": ("x", "v"),
"xuan": ("x", "van"),
"xue": ("x", "ve"),
"xun": ("x", "vn"),
"ya": ("^", "ia"),
"yan": ("^", "ian"),
"yang": ("^", "iang"),
"yao": ("^", "iao"),
"ye": ("^", "ie"),
"yi": ("^", "i"),
"yin": ("^", "in"),
"ying": ("^", "ing"),
"yo": ("^", "iou"),
"yong": ("^", "iong"),
"you": ("^", "iou"),
"yu": ("^", "v"),
"yuan": ("^", "van"),
"yue": ("^", "ve"),
"yun": ("^", "vn"),
"za": ("z", "a"),
"zai": ("z", "ai"),
"zan": ("z", "an"),
"zang": ("z", "ang"),
"zao": ("z", "ao"),
"ze": ("z", "e"),
"zei": ("z", "ei"),
"zen": ("z", "en"),
"zeng": ("z", "eng"),
"zha": ("zh", "a"),
"zhai": ("zh", "ai"),
"zhan": ("zh", "an"),
"zhang": ("zh", "ang"),
"zhao": ("zh", "ao"),
"zhe": ("zh", "e"),
"zhei": ("zh", "ei"),
"zhen": ("zh", "en"),
"zheng": ("zh", "eng"),
"zhi": ("zh", "iii"),
"zhong": ("zh", "ong"),
"zhou": ("zh", "ou"),
"zhu": ("zh", "u"),
"zhua": ("zh", "ua"),
"zhuai": ("zh", "uai"),
"zhuan": ("zh", "uan"),
"zhuang": ("zh", "uang"),
"zhui": ("zh", "uei"),
"zhun": ("zh", "uen"),
"zhuo": ("zh", "uo"),
"zi": ("z", "ii"),
"zong": ("z", "ong"),
"zou": ("z", "ou"),
"zu": ("z", "u"),
"zuan": ("z", "uan"),
"zui": ("z", "uei"),
"zun": ("z", "uen"),
"zuo": ("z", "uo"),
}
This source diff could not be displayed because it is too large. You can view the blob instead.
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
    """Gaussian-initialize the weights of any Conv* module; other modules are untouched."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Return the 'same' padding for a dilated 1-d convolution."""
    # kernel_size * dilation - dilation == dilation * (kernel_size - 1)
    return int(dilation * (kernel_size - 1) / 2)
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension [[before, after], ...] list, last dimension first
    (the flat order expected by torch.nn.functional.pad)."""
    return [amount for pair in reversed(pad_shape) for amount in pair]
def intersperse(lst, item):
    """Return a new list with `item` before, between, and after every element of lst."""
    out = [item]
    for element in lst:
        out.append(element)
        out.append(item)
    return out
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q) between diagonal Gaussians given means and log standard deviations."""
    kl = (logs_q - logs_p) - 0.5
    # Variance ratio plus squared mean difference, both scaled by 1/var_q.
    kl = kl + 0.5 * (torch.exp(2.0 * logs_p) + (m_p - m_q) ** 2) * torch.exp(-2.0 * logs_q)
    return kl
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protected from log(0) overflows."""
    # Uniform samples are squeezed into (0.00001, 0.99999] so both logs are finite.
    u = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(u))
def rand_gumbel_like(x):
    """Gumbel sample with the same shape, dtype, and device as `x`."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
def slice_segments(x, ids_str, segment_size=4):
    """Gather a fixed-length time slice per batch element.

    x: [b, d, t]; ids_str: per-element start indices. Returns [b, d, segment_size].
    """
    out = torch.zeros_like(x[:, :, :segment_size])
    for b in range(x.size(0)):
        start = ids_str[b]
        out[b] = x[b, :, start:start + segment_size]
    return out
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Pick a random valid start per batch item and slice `segment_size` frames.

    Returns (slices [b, d, segment_size], start indices [b]).
    """
    batch, _, total = x.size()
    if x_lengths is None:
        x_lengths = total
    # Valid starts are [0, x_lengths - segment_size]; rand() * max yields that range.
    max_start = x_lengths - segment_size + 1
    starts = (torch.rand([batch]).to(device=x.device) * max_start).to(dtype=torch.long)
    return slice_segments(x, starts, segment_size), starts
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Transformer-style sinusoidal positional signal of shape [1, channels, length]."""
    position = torch.arange(length, dtype=torch.float)
    half = channels // 2
    # Geometric progression of timescales between min_timescale and max_timescale.
    log_step = math.log(float(max_timescale) / float(min_timescale)) / (
        half - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(half, dtype=torch.float) * -log_step
    )
    scaled = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled), torch.cos(scaled)], 0)
    # An odd channel count gets one zero row of padding.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    return signal.view(1, channels, length)
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to x ([b, channels, length])."""
    _, channels, length = x.size()
    pos = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + pos.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal onto x along `axis`."""
    _, channels, length = x.size()
    pos = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, pos.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
    """Lower-triangular causal mask of shape [1, 1, length, length]."""
    return torch.tril(torch.ones(length, length))[None, None]
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """WaveNet gated activation, TorchScript-compiled to fuse the element-wise ops.

    n_channels is a 1-element IntTensor; the first n_channels channels feed the
    tanh branch, the rest feed the sigmoid gate.
    """
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
def convert_pad_shape(pad_shape):
    """Flatten reversed [[before, after], ...] pairs into the flat list F.pad expects.

    NOTE(review): duplicate of the earlier convert_pad_shape in this module.
    """
    flipped = pad_shape[::-1]
    flat = []
    for pair in flipped:
        flat.extend(pair)
    return flat
def shift_1d(x):
    """Shift x ([b, c, t]) right by one step along time, zero-filling position 0."""
    # [1, 0, 0, 0, 0, 0] is convert_pad_shape([[0, 0], [0, 0], [1, 0]]) inlined:
    # pad one zero at the left of the last dim, then drop the final step.
    return F.pad(x, [1, 0, 0, 0, 0, 0])[:, :, :-1]
def sequence_mask(length, max_length=None):
    """Boolean mask [len(length), max_length], True where position < length[i]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
    """
    Expand integer per-token durations into a monotonic alignment path.
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    Returns a {0,1} tensor of shape [b, 1, t_y, t_x] marking which output frame
    belongs to which input token.
    """
    # NOTE(review): `device` is assigned but never used.
    device = duration.device
    b, _, t_y, t_x = mask.shape
    # Cumulative end frame of each token along the token axis.
    cum_duration = torch.cumsum(duration, -1)
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtracting the mask shifted by one token leaves 1s only on each token's own frames.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients elementwise into [-clip_value, clip_value] in place.

    parameters: a Tensor or an iterable of Tensors. clip_value=None skips
    clamping. Returns the total `norm_type`-norm of the (pre-clamp) gradients.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)
    total = 0
    for p in grads:
        # The norm is accumulated before clamping, so it reflects the raw gradients.
        total += p.grad.data.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    return total ** (1.0 / norm_type)
{
"train": {
"log_interval": 100,
"eval_interval": 10000,
"seed": 1234,
"epochs": 20000,
"learning_rate": 1e-4,
"betas": [0.8, 0.99],
"eps": 1e-9,
"batch_size": 8,
"fp16_run": false,
"lr_decay": 0.999875,
"segment_size": 12800,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0
},
"data": {
"training_files":"filelists/train.txt",
"validation_files":"filelists/valid.txt",
"max_wav_value": 32768.0,
"sampling_rate": 16000,
"filter_length": 1024,
"hop_length": 256,
"win_length": 1024,
"n_mel_channels": 80,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": false,
"n_speakers": 0
},
"model": {
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"upsample_rates": [8,8,2,2],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [16,16,4,4],
"n_layers_q": 3,
"use_spectral_norm": false
}
}
from models import SynthesizerTrn
from vits_pinyin import VITS_PinYin
from text import cleaned_text_to_sequence
from text.symbols import symbols
from .utils import get_hparams_from_file
from .utils import load_model
import torch
import argparse
import os
import re
from scipy.io import wavfile
import numpy as np
def save_wav(wav, path, rate):
    """Scale `wav` to ~60% of int16 full scale and write it as a PCM wav file.

    Note: `wav` is scaled in place. The 0.01 floor avoids division by zero
    on silent input.
    """
    peak = max(0.01, np.max(np.abs(wav)))
    wav *= 32767 / peak * 0.6
    wavfile.write(path, rate, wav.astype(np.int16))
example = [['天空呈现的透心的蓝,像极了当年。总在这样的时候,透过窗棂,心,在天空里无尽的游弋!柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。', 1],
['风的影子翻阅过淡蓝色的信笺,柔和的文字浅浅地漫过我安静的眸,一如几朵悠闲的云儿,忽而氤氲成汽,忽而修饰成花,铅华洗尽后的透彻和靓丽,爽爽朗朗,轻轻盈盈', 1],
['时光仿佛有穿越到了从前,在你诗情画意的眼波中,在你舒适浪漫的暇思里,我如风中的思绪徜徉广阔天际,仿佛一片沾染了快乐的羽毛,在云环影绕颤动里浸润着风的呼吸,风的诗韵,那清新的耳语,那婉约的甜蜜,那恬淡的温馨,将一腔情澜染得愈发的缠绵。', 1],]
class TTS:
    """CPU text-to-speech pipeline: pinyin/BERT front end + VITS synthesizer."""
    def __init__(self):
        # Model assets are resolved relative to this source file's directory.
        parent_dir = os.path.dirname(os.path.abspath(__file__))
        self.device = torch.device("cpu")
        # pinyin front end (grapheme -> phoneme + prosody embeddings)
        self.tts_front = VITS_PinYin(parent_dir+"/bert", self.device)
        # config
        hps = get_hparams_from_file(parent_dir + "/configs/bert_vits.json")
        # model
        self.net_g = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model)
        model_path = "/vits_bert_model.pth"
        load_model(parent_dir + model_path, self.net_g)
        self.net_g.eval()
        self.net_g.to(self.device)
        # User-facing speed multiplier -> VITS length_scale (inverse relation:
        # faster speech means shorter predicted durations).
        self.speed_map = {
            1.00:1,
            1.10:0.88,
            1.25:0.8,
            1.50:0.66,
            1.75:0.57,
            2.00:0.5,
            2.50:0.4
        }
    def tts_calback(self,text, dur_scale=1):
        """
        text : str  text to synthesize
        dur_scale : float  speed scale in [0.1, 5]; 1 is normal, 0.1 fastest, 5 slowest
        Returns the synthesized waveform as a float numpy array.
        """
        phonemes, char_embeds = self.tts_front.chinese_to_phonemes(text)
        input_ids = cleaned_text_to_sequence(phonemes)
        with torch.no_grad():
            x_tst = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)
            x_tst_lengths = torch.LongTensor([len(input_ids)]).to(self.device)
            x_tst_prosody = torch.FloatTensor(
                char_embeds).unsqueeze(0).to(self.device)
            # infer() returns (audio, attn, mask, ...); take the first waveform.
            audio = self.net_g.infer(x_tst, x_tst_lengths, x_tst_prosody, noise_scale=0.5,
                                     length_scale=dur_scale)[0][0, 0].data.cpu().float().numpy()
        del x_tst, x_tst_lengths, x_tst_prosody
        return audio
def tts(text, speed, wav_path):
    """Synthesize `text` at `speed` (a key of TTS.speed_map) and save to wav_path.

    NOTE(review): a fresh TTS instance (full model load) is built on every call —
    expensive if invoked repeatedly.
    """
    model = TTS()
    # text and speed
    # text = '你好呀哈。请问你是谁'
    # speed = 1 # (0.1,5) 0.1 fastest, 5 slowest, default=1
    st = time()
    # synthesize
    audio = model.tts_calback(text,model.speed_map[speed])
    # write the wav file at the model's 16 kHz sample rate
    save_wav(audio, wav_path, 16000)
    ed = time()
    print(f'transform time:{ed-st:.4f}')
    print(speed)
    print(model.speed_map[speed])
from time import time
if __name__ == "__main__":
    # initialize the model
    model = TTS()
    # text and speed
    text = '你好呀哈。请问你是谁'
    speed = 1  # (0.1,5) 0.1 fastest, 5 slowest, default=1
    st = time()
    # synthesize
    # NOTE(review): passes `speed` directly as dur_scale, unlike tts() above
    # which maps it through model.speed_map first.
    audio = model.tts_calback(text,speed)
    # write the wav file
    save_wav(audio, f"./vits_infer_out/bert_vits4.wav", 16000)
    ed = time()
    print(f'transform time:{ed-st:.4f}')
import copy
import math
import torch
from torch import nn
from torch.nn import functional as F
import commons
import modules
import attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
class DurationPredictor(nn.Module):
    """Conv/LayerNorm stack predicting one (log-)duration value per time step."""
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        # Projects down to a single channel: one duration per frame.
        self.proj = nn.Conv1d(filter_channels, 1, 1)
        if gin_channels != 0:
            # Optional global (e.g. speaker) conditioning added onto the input.
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)
    def forward(self, x, x_mask, g=None):
        """x: [b, in_channels, t]; x_mask: [b, 1, t]; g: optional [b, gin_channels, 1]."""
        # Inputs are detached, so gradients from this head do not reach the encoder.
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask
class TextEncoder(nn.Module):
    """Encodes token ids plus 256-dim BERT features into prior stats (m, logs)."""
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        # Maps the 256-dim BERT prosody embeddings into the hidden space.
        self.emb_bert = nn.Linear(256, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        # Produces mean and log-std stacked along the channel dimension.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    def forward(self, x, x_lengths, bert):
        """x: [b, t] token ids; x_lengths: [b]; bert: [b, t, 256].
        Returns (hidden, m, logs, x_mask)."""
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        b = self.emb_bert(bert)
        x = x + b
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
    """Normalizing flow: n_flows coupling layers, each followed by a channel Flip."""
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels
        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            # Flip alternates which half of the channels is transformed next.
            self.flows.append(modules.Flip())
    def forward(self, x, x_mask, g=None, reverse=False):
        """Forward (training) or inverse (inference) pass through all flows."""
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x
    def remove_weight_norm(self):
        # Even indices are the coupling layers; odd indices are Flip (no weights).
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()
class PosteriorEncoder(nn.Module):
    """Encodes spectrograms into posterior stats and a reparameterized sample z."""
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    def forward(self, x, x_lengths, g=None):
        """x: [b, in_channels, t]; returns (z, m, logs, x_mask)."""
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterization trick: z = m + eps * sigma, masked to valid frames.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
class Generator(torch.nn.Module):
    """Upsampling vocoder: transposed convs interleaved with multi-kernel resblocks."""
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # resblock is a "1"/"2" config string selecting the resblock variant.
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at each upsampling stage.
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
    def forward(self, x, g=None):
        """x: [b, initial_channel, t]; g: optional conditioning. Returns waveform in [-1, 1]."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's num_kernels parallel resblocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the waveform into 2-D by `period` and convolves."""
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
    def forward(self, x):
        """x: [b, 1, t] waveform. Returns (flattened logits, per-layer feature maps)."""
        fmap = []
        # 1d to 2d: reflect-pad so t becomes a multiple of the period, then fold.
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of grouped 1-D convolutions over the raw waveform."""
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
    def forward(self, x):
        """x: [b, 1, t] waveform. Returns (flattened logits, per-layer feature maps)."""
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
    """One scale discriminator plus period discriminators at periods 2/3/5/7/11."""
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]
        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)
    def forward(self, y, y_hat):
        """Run every sub-discriminator on real (y) and generated (y_hat) audio.

        Returns (real logits, fake logits, real fmaps, fake fmaps), each a list
        with one entry per sub-discriminator.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        # NOTE(review): the enumerate index `i` is unused.
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    VITS model: text prior encoder (enc_p), posterior encoder (enc_q),
    normalizing flow, duration predictor (dp), and HiFi-GAN decoder (dec).
    """
    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=False,
        **kwargs
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        # Text/prosody prior encoder.
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        # Waveform decoder (vocoder).
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # Spectrogram posterior encoder (training only).
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )
        # NOTE(review): emb_g exists only when n_speakers > 1, but infer() uses it
        # whenever n_speakers > 0 — n_speakers == 1 would raise AttributeError.
        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
    def remove_weight_norm(self):
        """Strip weight norm from all submodules for inference/export."""
        print("Removing weight norm...")
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()
    def infer(self, x, x_lengths, bert, sid=None, noise_scale=1, length_scale=1, max_len=None):
        """Synthesize audio from token ids + BERT features.

        Returns (audio, attention, y_mask, (z, z_p, m_p, logs_p)).
        """
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, bert)
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None
        # Predicted log-durations -> frame counts, scaled by length_scale.
        logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)
        # Spread per-token prior stats over the predicted output frames.
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        # Inverse flow maps the prior sample into the decoder latent space.
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)
import copy
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm
import commons
from commons import init_weights, get_padding
from transforms import piecewise_rational_quadratic_transform
# Negative slope shared by the leaky-ReLU activations in this module.
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
    """LayerNorm over the channel dimension of [b, c, t] tensors."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Move channels to the last dim so F.layer_norm normalizes them, then move back.
        normed = F.layer_norm(
            x.transpose(1, -1), (self.channels,), self.gamma, self.beta, self.eps
        )
        return normed.transpose(1, -1)
class ConvReluNorm(nn.Module):
    """n_layers of (conv -> LayerNorm -> ReLU -> dropout) with a residual projection."""
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 0."
        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero-init so the block starts as an identity (residual passes through).
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()
    def forward(self, x, x_mask):
        """x: [b, in_channels, t]; x_mask: [b, 1, t]. Residual output, masked."""
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """
    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            # Dilation grows geometrically with depth to widen the receptive field.
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))
    def forward(self, x, x_mask, g=None):
        """x: [b, channels, t]; g: optional additive conditioning of the same shape."""
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            # Depthwise conv, then pointwise 1x1, each followed by norm + GELU.
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask
class WN(torch.nn.Module):
    """WaveNet-style module: stacked dilated convolutions with gated
    tanh/sigmoid activations, residual connections and accumulated skip
    outputs, optionally conditioned on a global embedding g.

    Args:
        hidden_channels: width of the residual/skip path.
        kernel_size: odd kernel size of the dilated convolutions.
        dilation_rate: per-layer dilation is dilation_rate ** layer_index.
        n_layers: number of dilated conv layers.
        gin_channels: channels of the conditioning tensor g (0 = none).
        p_dropout: dropout applied to the gated activations.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Fix: this was accidentally stored as the 1-tuple (kernel_size,),
        # unlike every other scalar hyperparameter attribute.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One conditioning projection shared by all layers; sliced per
            # layer in forward().
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary: the final layer only feeds the skip
            # sum, so it does not need the residual half of the channels.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Run the stack; returns the masked sum of all skip connections."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice this layer's share of the shared conditioning output.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skip sum.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip the weight_norm reparameterization (e.g. before export)."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
    """HiFi-GAN residual block: three dilated convolutions, each paired with
    a dilation-1 convolution, joined by residual connections."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()

        def make_conv(dil):
            # weight-normed "same"-padded conv at the given dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dil,
                    padding=get_padding(kernel_size, dil),
                )
            )

        self.convs1 = nn.ModuleList([make_conv(d) for d in dilation[:3]])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for dilated, plain in zip(self.convs1, self.convs2):
            h = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                h = h * x_mask
            h = dilated(h)
            h = F.leaky_relu(h, LRELU_SLOPE)
            if x_mask is not None:
                h = h * x_mask
            h = plain(h)
            x = x + h
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for layer in list(self.convs1) + list(self.convs2):
            remove_weight_norm(layer)
class ResBlock2(torch.nn.Module):
    """HiFi-GAN light residual block: two dilated convolutions with
    residual connections."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation[:2]
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for conv in self.convs:
            h = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                h = h * x_mask
            h = conv(h)
            x = x + h
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for conv in self.convs:
            remove_weight_norm(conv)
class Log(nn.Module):
    """Invertible log-transform flow step: y = log(clamp(x, 1e-5)) in the
    forward direction (returning the log-determinant), x = exp(y) inverted."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return torch.exp(x) * x_mask
        y = x_mask * torch.log(torch.clamp_min(x, 1e-5))
        logdet = torch.sum(-y, [1, 2])
        return y, logdet
class Flip(nn.Module):
    """Flow step that reverses the channel dimension; volume-preserving,
    so its log-determinant is zero."""

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        logdet = torch.zeros(
            flipped.size(0), dtype=flipped.dtype, device=flipped.device
        )
        return flipped, logdet
class ElementwiseAffine(nn.Module):
    """Per-channel affine flow: y = m + exp(logs) * x with learned shift m
    and log-scale logs (both initialized to zero, i.e. identity)."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask
        y = (self.m + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
class ResidualCouplingLayer(nn.Module):
    """Affine coupling flow layer: the first half of the channels drives a
    WN network that predicts the mean (and, unless mean_only, the log-scale)
    used to transform the second half."""

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only
        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # Zero-init so the coupling starts out as the identity transform.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        first, second = torch.split(x, [self.half_channels] * 2, 1)
        hidden = self.pre(first) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask
        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        if reverse:
            second = (second - m) * torch.exp(-logs) * x_mask
            return torch.cat([first, second], 1)
        second = m + second * torch.exp(logs) * x_mask
        logdet = torch.sum(logs, [1, 2])
        return torch.cat([first, second], 1), logdet

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
class ConvFlow(nn.Module):
    """Neural spline flow layer: the first half of the channels parameterizes
    a piecewise rational-quadratic spline that is applied to the second half.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2
        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        # Per transformed channel: num_bins widths + num_bins heights +
        # (num_bins - 1) interior derivatives -> 3*num_bins - 1 parameters.
        self.proj = nn.Conv1d(
            filter_channels, self.half_channels * (num_bins * 3 - 1), 1
        )
        # Zero-init so the flow starts out (approximately) as the identity.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply (or, with reverse=True, invert) the spline on the second half.

        Returns (x, logdet) in the forward direction, x alone when reversed.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask
        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
        # Widths/heights are scaled down by sqrt(filter_channels) so the
        # zero-initialized projection yields a near-uniform spline.
        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
            self.filter_channels
        )
        unnormalized_derivatives = h[..., 2 * self.num_bins :]
        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )
        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        else:
            return x
torch
scipy
transformers
pypinyin
from text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}  # symbol -> index into `symbols`
_id_to_symbol = {i: s for i, s in enumerate(symbols)}  # index -> symbol
def cleaned_text_to_sequence(cleaned_text):
    """Convert a cleaned, whitespace-separated symbol string to symbol IDs.

    Args:
        cleaned_text: string of symbols separated by whitespace.
    Returns:
        List of integer IDs, one per symbol (looked up in `_symbol_to_id`).
    """
    return [_symbol_to_id[token] for token in cleaned_text.split()]
def sequence_to_text(sequence):
    """Convert a sequence of symbol IDs back to the concatenated string."""
    return "".join(_id_to_symbol[symbol_id] for symbol_id in sequence)
# Mapping from a toneless pinyin syllable to its (initial, final) phoneme
# pair. "^" marks the null initial for vowel-onset syllables. Finals use an
# extended scheme: "ii"/"iii" are the apical vowels of zi/zhi-type
# syllables, "v" stands for ü, and contracted finals are written in full
# (iu -> iou, ui -> uei, un -> uen).
pinyin_dict = {
    "a": ("^", "a"),
    "ai": ("^", "ai"),
    "an": ("^", "an"),
    "ang": ("^", "ang"),
    "ao": ("^", "ao"),
    "ba": ("b", "a"),
    "bai": ("b", "ai"),
    "ban": ("b", "an"),
    "bang": ("b", "ang"),
    "bao": ("b", "ao"),
    "be": ("b", "e"),
    "bei": ("b", "ei"),
    "ben": ("b", "en"),
    "beng": ("b", "eng"),
    "bi": ("b", "i"),
    "bian": ("b", "ian"),
    "biao": ("b", "iao"),
    "bie": ("b", "ie"),
    "bin": ("b", "in"),
    "bing": ("b", "ing"),
    "bo": ("b", "o"),
    "bu": ("b", "u"),
    "ca": ("c", "a"),
    "cai": ("c", "ai"),
    "can": ("c", "an"),
    "cang": ("c", "ang"),
    "cao": ("c", "ao"),
    "ce": ("c", "e"),
    "cen": ("c", "en"),
    "ceng": ("c", "eng"),
    "cha": ("ch", "a"),
    "chai": ("ch", "ai"),
    "chan": ("ch", "an"),
    "chang": ("ch", "ang"),
    "chao": ("ch", "ao"),
    "che": ("ch", "e"),
    "chen": ("ch", "en"),
    "cheng": ("ch", "eng"),
    "chi": ("ch", "iii"),
    "chong": ("ch", "ong"),
    "chou": ("ch", "ou"),
    "chu": ("ch", "u"),
    "chua": ("ch", "ua"),
    "chuai": ("ch", "uai"),
    "chuan": ("ch", "uan"),
    "chuang": ("ch", "uang"),
    "chui": ("ch", "uei"),
    "chun": ("ch", "uen"),
    "chuo": ("ch", "uo"),
    "ci": ("c", "ii"),
    "cong": ("c", "ong"),
    "cou": ("c", "ou"),
    "cu": ("c", "u"),
    "cuan": ("c", "uan"),
    "cui": ("c", "uei"),
    "cun": ("c", "uen"),
    "cuo": ("c", "uo"),
    "da": ("d", "a"),
    "dai": ("d", "ai"),
    "dan": ("d", "an"),
    "dang": ("d", "ang"),
    "dao": ("d", "ao"),
    "de": ("d", "e"),
    "dei": ("d", "ei"),
    "den": ("d", "en"),
    "deng": ("d", "eng"),
    "di": ("d", "i"),
    "dia": ("d", "ia"),
    "dian": ("d", "ian"),
    "diao": ("d", "iao"),
    "die": ("d", "ie"),
    "ding": ("d", "ing"),
    "diu": ("d", "iou"),
    "dong": ("d", "ong"),
    "dou": ("d", "ou"),
    "du": ("d", "u"),
    "duan": ("d", "uan"),
    "dui": ("d", "uei"),
    "dun": ("d", "uen"),
    "duo": ("d", "uo"),
    "e": ("^", "e"),
    "ei": ("^", "ei"),
    "en": ("^", "en"),
    "ng": ("^", "en"),
    "eng": ("^", "eng"),
    "er": ("^", "er"),
    "fa": ("f", "a"),
    "fan": ("f", "an"),
    "fang": ("f", "ang"),
    "fei": ("f", "ei"),
    "fen": ("f", "en"),
    "feng": ("f", "eng"),
    "fo": ("f", "o"),
    "fou": ("f", "ou"),
    "fu": ("f", "u"),
    "ga": ("g", "a"),
    "gai": ("g", "ai"),
    "gan": ("g", "an"),
    "gang": ("g", "ang"),
    "gao": ("g", "ao"),
    "ge": ("g", "e"),
    "gei": ("g", "ei"),
    "gen": ("g", "en"),
    "geng": ("g", "eng"),
    "gong": ("g", "ong"),
    "gou": ("g", "ou"),
    "gu": ("g", "u"),
    "gua": ("g", "ua"),
    "guai": ("g", "uai"),
    "guan": ("g", "uan"),
    "guang": ("g", "uang"),
    "gui": ("g", "uei"),
    "gun": ("g", "uen"),
    "guo": ("g", "uo"),
    "ha": ("h", "a"),
    "hai": ("h", "ai"),
    "han": ("h", "an"),
    "hang": ("h", "ang"),
    "hao": ("h", "ao"),
    "he": ("h", "e"),
    "hei": ("h", "ei"),
    "hen": ("h", "en"),
    "heng": ("h", "eng"),
    "hong": ("h", "ong"),
    "hou": ("h", "ou"),
    "hu": ("h", "u"),
    "hua": ("h", "ua"),
    "huai": ("h", "uai"),
    "huan": ("h", "uan"),
    "huang": ("h", "uang"),
    "hui": ("h", "uei"),
    "hun": ("h", "uen"),
    "huo": ("h", "uo"),
    "ji": ("j", "i"),
    "jia": ("j", "ia"),
    "jian": ("j", "ian"),
    "jiang": ("j", "iang"),
    "jiao": ("j", "iao"),
    "jie": ("j", "ie"),
    "jin": ("j", "in"),
    "jing": ("j", "ing"),
    "jiong": ("j", "iong"),
    "jiu": ("j", "iou"),
    "ju": ("j", "v"),
    "juan": ("j", "van"),
    "jue": ("j", "ve"),
    "jun": ("j", "vn"),
    "ka": ("k", "a"),
    "kai": ("k", "ai"),
    "kan": ("k", "an"),
    "kang": ("k", "ang"),
    "kao": ("k", "ao"),
    "ke": ("k", "e"),
    "kei": ("k", "ei"),
    "ken": ("k", "en"),
    "keng": ("k", "eng"),
    "kong": ("k", "ong"),
    "kou": ("k", "ou"),
    "ku": ("k", "u"),
    "kua": ("k", "ua"),
    "kuai": ("k", "uai"),
    "kuan": ("k", "uan"),
    "kuang": ("k", "uang"),
    "kui": ("k", "uei"),
    "kun": ("k", "uen"),
    "kuo": ("k", "uo"),
    "la": ("l", "a"),
    "lai": ("l", "ai"),
    "lan": ("l", "an"),
    "lang": ("l", "ang"),
    "lao": ("l", "ao"),
    "le": ("l", "e"),
    "lei": ("l", "ei"),
    "leng": ("l", "eng"),
    "li": ("l", "i"),
    "lia": ("l", "ia"),
    "lian": ("l", "ian"),
    "liang": ("l", "iang"),
    "liao": ("l", "iao"),
    "lie": ("l", "ie"),
    "lin": ("l", "in"),
    "ling": ("l", "ing"),
    "liu": ("l", "iou"),
    "lo": ("l", "o"),
    "long": ("l", "ong"),
    "lou": ("l", "ou"),
    "lu": ("l", "u"),
    "lv": ("l", "v"),
    "luan": ("l", "uan"),
    "lve": ("l", "ve"),
    "lue": ("l", "ve"),
    "lun": ("l", "uen"),
    "luo": ("l", "uo"),
    "ma": ("m", "a"),
    "mai": ("m", "ai"),
    "man": ("m", "an"),
    "mang": ("m", "ang"),
    "mao": ("m", "ao"),
    "me": ("m", "e"),
    "mei": ("m", "ei"),
    "men": ("m", "en"),
    "meng": ("m", "eng"),
    "mi": ("m", "i"),
    "mian": ("m", "ian"),
    "miao": ("m", "iao"),
    "mie": ("m", "ie"),
    "min": ("m", "in"),
    "ming": ("m", "ing"),
    "miu": ("m", "iou"),
    "mo": ("m", "o"),
    "mou": ("m", "ou"),
    "mu": ("m", "u"),
    "na": ("n", "a"),
    "nai": ("n", "ai"),
    "nan": ("n", "an"),
    "nang": ("n", "ang"),
    "nao": ("n", "ao"),
    "ne": ("n", "e"),
    "nei": ("n", "ei"),
    "nen": ("n", "en"),
    "neng": ("n", "eng"),
    "ni": ("n", "i"),
    "nia": ("n", "ia"),
    "nian": ("n", "ian"),
    "niang": ("n", "iang"),
    "niao": ("n", "iao"),
    "nie": ("n", "ie"),
    "nin": ("n", "in"),
    "ning": ("n", "ing"),
    "niu": ("n", "iou"),
    "nong": ("n", "ong"),
    "nou": ("n", "ou"),
    "nu": ("n", "u"),
    "nv": ("n", "v"),
    "nuan": ("n", "uan"),
    "nve": ("n", "ve"),
    "nue": ("n", "ve"),
    "nuo": ("n", "uo"),
    "o": ("^", "o"),
    "ou": ("^", "ou"),
    "pa": ("p", "a"),
    "pai": ("p", "ai"),
    "pan": ("p", "an"),
    "pang": ("p", "ang"),
    "pao": ("p", "ao"),
    "pe": ("p", "e"),
    "pei": ("p", "ei"),
    "pen": ("p", "en"),
    "peng": ("p", "eng"),
    "pi": ("p", "i"),
    "pian": ("p", "ian"),
    "piao": ("p", "iao"),
    "pie": ("p", "ie"),
    "pin": ("p", "in"),
    "ping": ("p", "ing"),
    "po": ("p", "o"),
    "pou": ("p", "ou"),
    "pu": ("p", "u"),
    "qi": ("q", "i"),
    "qia": ("q", "ia"),
    "qian": ("q", "ian"),
    "qiang": ("q", "iang"),
    "qiao": ("q", "iao"),
    "qie": ("q", "ie"),
    "qin": ("q", "in"),
    "qing": ("q", "ing"),
    "qiong": ("q", "iong"),
    "qiu": ("q", "iou"),
    "qu": ("q", "v"),
    "quan": ("q", "van"),
    "que": ("q", "ve"),
    "qun": ("q", "vn"),
    "ran": ("r", "an"),
    "rang": ("r", "ang"),
    "rao": ("r", "ao"),
    "re": ("r", "e"),
    "ren": ("r", "en"),
    "reng": ("r", "eng"),
    "ri": ("r", "iii"),
    "rong": ("r", "ong"),
    "rou": ("r", "ou"),
    "ru": ("r", "u"),
    "rua": ("r", "ua"),
    "ruan": ("r", "uan"),
    "rui": ("r", "uei"),
    "run": ("r", "uen"),
    "ruo": ("r", "uo"),
    "sa": ("s", "a"),
    "sai": ("s", "ai"),
    "san": ("s", "an"),
    "sang": ("s", "ang"),
    "sao": ("s", "ao"),
    "se": ("s", "e"),
    "sen": ("s", "en"),
    "seng": ("s", "eng"),
    "sha": ("sh", "a"),
    "shai": ("sh", "ai"),
    "shan": ("sh", "an"),
    "shang": ("sh", "ang"),
    "shao": ("sh", "ao"),
    "she": ("sh", "e"),
    "shei": ("sh", "ei"),
    "shen": ("sh", "en"),
    "sheng": ("sh", "eng"),
    "shi": ("sh", "iii"),
    "shou": ("sh", "ou"),
    "shu": ("sh", "u"),
    "shua": ("sh", "ua"),
    "shuai": ("sh", "uai"),
    "shuan": ("sh", "uan"),
    "shuang": ("sh", "uang"),
    "shui": ("sh", "uei"),
    "shun": ("sh", "uen"),
    "shuo": ("sh", "uo"),
    "si": ("s", "ii"),
    "song": ("s", "ong"),
    "sou": ("s", "ou"),
    "su": ("s", "u"),
    "suan": ("s", "uan"),
    "sui": ("s", "uei"),
    "sun": ("s", "uen"),
    "suo": ("s", "uo"),
    "ta": ("t", "a"),
    "tai": ("t", "ai"),
    "tan": ("t", "an"),
    "tang": ("t", "ang"),
    "tao": ("t", "ao"),
    "te": ("t", "e"),
    "tei": ("t", "ei"),
    "teng": ("t", "eng"),
    "ti": ("t", "i"),
    "tian": ("t", "ian"),
    "tiao": ("t", "iao"),
    "tie": ("t", "ie"),
    "ting": ("t", "ing"),
    "tong": ("t", "ong"),
    "tou": ("t", "ou"),
    "tu": ("t", "u"),
    "tuan": ("t", "uan"),
    "tui": ("t", "uei"),
    "tun": ("t", "uen"),
    "tuo": ("t", "uo"),
    "wa": ("^", "ua"),
    "wai": ("^", "uai"),
    "wan": ("^", "uan"),
    "wang": ("^", "uang"),
    "wei": ("^", "uei"),
    "wen": ("^", "uen"),
    "weng": ("^", "ueng"),
    "wo": ("^", "uo"),
    "wu": ("^", "u"),
    "xi": ("x", "i"),
    "xia": ("x", "ia"),
    "xian": ("x", "ian"),
    "xiang": ("x", "iang"),
    "xiao": ("x", "iao"),
    "xie": ("x", "ie"),
    "xin": ("x", "in"),
    "xing": ("x", "ing"),
    "xiong": ("x", "iong"),
    "xiu": ("x", "iou"),
    "xu": ("x", "v"),
    "xuan": ("x", "van"),
    "xue": ("x", "ve"),
    "xun": ("x", "vn"),
    "ya": ("^", "ia"),
    "yan": ("^", "ian"),
    "yang": ("^", "iang"),
    "yao": ("^", "iao"),
    "ye": ("^", "ie"),
    "yi": ("^", "i"),
    "yin": ("^", "in"),
    "ying": ("^", "ing"),
    "yo": ("^", "iou"),
    "yong": ("^", "iong"),
    "you": ("^", "iou"),
    "yu": ("^", "v"),
    "yuan": ("^", "van"),
    "yue": ("^", "ve"),
    "yun": ("^", "vn"),
    "za": ("z", "a"),
    "zai": ("z", "ai"),
    "zan": ("z", "an"),
    "zang": ("z", "ang"),
    "zao": ("z", "ao"),
    "ze": ("z", "e"),
    "zei": ("z", "ei"),
    "zen": ("z", "en"),
    "zeng": ("z", "eng"),
    "zha": ("zh", "a"),
    "zhai": ("zh", "ai"),
    "zhan": ("zh", "an"),
    "zhang": ("zh", "ang"),
    "zhao": ("zh", "ao"),
    "zhe": ("zh", "e"),
    "zhei": ("zh", "ei"),
    "zhen": ("zh", "en"),
    "zheng": ("zh", "eng"),
    "zhi": ("zh", "iii"),
    "zhong": ("zh", "ong"),
    "zhou": ("zh", "ou"),
    "zhu": ("zh", "u"),
    "zhua": ("zh", "ua"),
    "zhuai": ("zh", "uai"),
    "zhuan": ("zh", "uan"),
    "zhuang": ("zh", "uang"),
    "zhui": ("zh", "uei"),
    "zhun": ("zh", "uen"),
    "zhuo": ("zh", "uo"),
    "zi": ("z", "ii"),
    "zong": ("z", "ong"),
    "zou": ("z", "ou"),
    "zu": ("z", "u"),
    "zuan": ("z", "uan"),
    "zui": ("z", "uei"),
    "zun": ("z", "uen"),
    "zuo": ("z", "uo"),
}
# Pause/silence symbols plus prosodic break levels (#0-#3).
_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
# Pinyin initials; "^" is the null initial for vowel-onset syllables.
_initials = [
    "^",
    "b",
    "c",
    "ch",
    "d",
    "f",
    "g",
    "h",
    "j",
    "k",
    "l",
    "m",
    "n",
    "p",
    "q",
    "r",
    "s",
    "sh",
    "t",
    "x",
    "z",
    "zh",
]
# Tones 1-4 plus 5 for the neutral tone.
_tones = ["1", "2", "3", "4", "5"]
# Extended finals ("ii"/"iii" apical vowels, "v" for ü, full iou/uei/uen forms).
_finals = [
    "a",
    "ai",
    "an",
    "ang",
    "ao",
    "e",
    "ei",
    "en",
    "eng",
    "er",
    "i",
    "ia",
    "ian",
    "iang",
    "iao",
    "ie",
    "ii",
    "iii",
    "in",
    "ing",
    "iong",
    "iou",
    "o",
    "ong",
    "ou",
    "u",
    "ua",
    "uai",
    "uan",
    "uang",
    "uei",
    "uen",
    "ueng",
    "uo",
    "v",
    "van",
    "ve",
    "vn",
]
# Full symbol inventory: pauses + bare initials + every final crossed with every tone.
symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
\ No newline at end of file
import torch
from torch.nn import functional as F
import numpy as np
# Lower bounds that keep spline bin sizes and knot derivatives strictly
# positive (numerical stability of the rational-quadratic transform).
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Dispatch to the constrained spline (tails=None) or the tail-handling
    variant, returning (outputs, logabsdet) elementwise for `inputs`."""
    if tails is None:
        spline_fn, spline_kwargs = rational_quadratic_spline, {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
    return spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input, the index of the bin it falls into.

    NOTE: nudges the last bin edge upward by eps IN PLACE so inputs lying
    exactly on the top edge land in the final bin.
    """
    bin_locations[..., -1] += eps
    hits = inputs[..., None] >= bin_locations
    return hits.sum(dim=-1) - 1
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Rational-quadratic spline with identity ("linear") tails outside
    [-tail_bound, tail_bound]; inside the interval it defers to
    rational_quadratic_spline. Returns (outputs, logabsdet), both shaped
    like `inputs`.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask
    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)
    if tails == "linear":
        # Pad the derivative parameters so that the boundary derivatives
        # come out as min_derivative + softplus(constant) = 1, matching the
        # slope of the identity tails.
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant
        # Outside the interval the transform is the identity (log|det| = 0).
        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))
    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
    return outputs, logabsdet
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Monotonic piecewise rational-quadratic spline (Durkan et al.,
    "Neural Spline Flows") mapping [left, right] -> [bottom, top].

    Returns (outputs, logabsdet) elementwise for `inputs`; with
    inverse=True the inverse map is applied instead.

    Raises:
        ValueError: if inputs fall outside the domain, or the minimum bin
            width/height is infeasible for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")
    num_bins = unnormalized_widths.shape[-1]
    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")
    # Bin widths: softmax -> floor at min_bin_width -> cumulative knot
    # positions rescaled onto [left, right].
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]
    # Knot derivatives, kept strictly positive via softplus + floor.
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    # Bin heights: same construction as widths, on [bottom, top].
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]
    # Locate each input's bin: by output (height) edges when inverting,
    # by input (width) edges otherwise.
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]
    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths  # average slope of each bin
    input_delta = delta.gather(-1, bin_idx)[..., 0]
    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
    input_heights = heights.gather(-1, bin_idx)[..., 0]
    if inverse:
        # Invert the per-bin rational quadratic by solving the quadratic
        # a*r^2 + b*r + c = 0 for the bin-relative position r.
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)
        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()
        # Numerically stable quadratic root (avoids cancellation).
        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths
        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        # Inverse direction: negate the forward log|det|.
        return outputs, -logabsdet
    else:
        # theta is the input's relative position within its bin.
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)
        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        return outputs, logabsdet
import os
import glob
import sys
import argparse
import logging
import json
import subprocess
import numpy as np
from scipy.io.wavfile import read
import torch
# Set True once matplotlib has been lazily configured with the Agg backend.
MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# NOTE(review): `logger` is initially the logging *module*, not a Logger
# instance; get_logger() below rebinds this global to a real Logger.
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Parameters missing from the checkpoint keep the model's current values
    (a notice is logged instead of failing).

    Args:
        checkpoint_path: path to a torch.save'd dict with keys "model",
            "iteration", "learning_rate" and (optionally) "optimizer".
        model: module (possibly DataParallel-wrapped) to load weights into.
        optimizer: optional optimizer whose state is restored as well.
    Returns:
        (model, optimizer, learning_rate, iteration)
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Fix: was a bare `except:`, which also swallowed unrelated errors
            # (including KeyboardInterrupt).
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )
    return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Persist model weights plus optimizer state, learning rate and
    iteration count to checkpoint_path."""
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    state_dict = (
        model.module.state_dict() if hasattr(model, "module") else model.state_dict()
    )
    payload = {
        "model": state_dict,
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
def load_model(checkpoint_path, model):
    """Load only the model weights from a checkpoint (no optimizer state).

    Parameters missing from the checkpoint keep the model's current values
    (a notice is logged instead of failing).

    Args:
        checkpoint_path: path to a torch.save'd dict with a "model" key.
        model: module (possibly DataParallel-wrapped) to load weights into.
    Returns:
        The same model instance, with weights loaded.
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Fix: was a bare `except:`, which also swallowed unrelated errors.
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    return model
def save_model(model, checkpoint_path):
    """Serialize only the model weights (no optimizer/iteration metadata)."""
    if hasattr(model, 'module'):
        weights = model.module.state_dict()
    else:
        weights = model.state_dict()
    torch.save({'model': weights}, checkpoint_path)
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Write a batch of values to a TensorBoard-style summary writer.

    Args:
        writer: object exposing add_scalar/add_histogram/add_image/add_audio.
        global_step: step index recorded with every entry.
        scalars, histograms, images, audios: optional name -> value mappings.
        audio_sampling_rate: sample rate passed along with each audio clip.
    """
    # Fix: the defaults were shared mutable dicts (`={}`); use None sentinels.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the checkpoint in dir_path whose concatenated digits sort
    highest (i.e. the latest G_<step>.pth for the default pattern)."""
    candidates = sorted(
        glob.glob(os.path.join(dir_path, regex)),
        key=lambda f: int("".join(filter(str.isdigit, f))),
    )
    latest = candidates[-1]
    print(latest)
    return latest
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram as an RGB uint8 image array of shape (H, W, 3).

    Lazily configures a headless (Agg) matplotlib backend on first use so
    this is safe to call from worker processes without a display.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # Fix: np.fromstring on binary data is deprecated/removed; frombuffer is
    # the supported zero-copy equivalent.
    # NOTE(review): tostring_rgb is removed in matplotlib >= 3.10; if the
    # pinned matplotlib is newer, switch to fig.canvas.buffer_rgba()[..., :3].
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data
def plot_alignment_to_numpy(alignment, info=None):
    """Render an attention/alignment matrix as an RGB uint8 image (H, W, 3).

    Args:
        alignment: 2-D array; plotted transposed (decoder steps on x).
        info: optional extra text appended to the x-axis label.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # Fix: np.fromstring on binary data is deprecated/removed; frombuffer is
    # the supported zero-copy equivalent.
    # NOTE(review): tostring_rgb is removed in matplotlib >= 3.10; if the
    # pinned matplotlib is newer, switch to fig.canvas.buffer_rgba()[..., :3].
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data
def load_wav_to_torch(full_path):
    """Read a WAV file; return (float32 sample tensor, sampling_rate)."""
    sampling_rate, samples = read(full_path)
    return torch.FloatTensor(samples.astype(np.float32)), sampling_rate
def load_filepaths_and_text(filename, split="|"):
    """Parse a metadata file: one record per line, fields joined by `split`.

    Returns a list of field lists, e.g. [["path.wav", "text"], ...].
    """
    with open(filename, encoding="utf-8") as f:
        return [line.strip().split(split) for line in f]
def get_hparams(init=True):
    """Parse -c/--config and -m/--model from sys.argv and return HParams.

    Creates ./logs/<model> if needed. With init=True the supplied config
    file is copied into the model dir (for reproducibility); otherwise the
    previously saved copy is read back instead.
    """
    parent_dir = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default=parent_dir + "/configs/bert_vits.json",
        help="JSON file for configuration",
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        # Snapshot the config next to the checkpoints.
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)
    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
def get_hparams_from_dir(model_dir):
    """Load the config.json previously saved inside a model directory and
    return it as HParams (with .model_dir attached)."""
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        config = json.loads(f.read())
    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
def get_hparams_from_file(config_path):
    """Load an arbitrary JSON config file and return it as HParams."""
    with open(config_path, "r") as f:
        config = json.loads(f.read())
    return HParams(**config)
def check_git_hash(model_dir):
    """Record the current git HEAD hash in model_dir/githash and warn when
    it differs from the hash saved by a previous run."""
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        # Fix: logger.warn is deprecated in favor of logger.warning.
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        # Fix: close the file handles instead of leaking them.
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
def get_logger(model_dir, filename="train.log"):
    """Create model_dir (if needed) and return a DEBUG-level Logger writing
    to model_dir/filename; also rebinds the module-global `logger`."""
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    handler = logging.FileHandler(os.path.join(model_dir, filename))
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(
        logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    )
    logger.addHandler(handler)
    return logger
class HParams:
    """Dict-like hyperparameter container: keys become attributes, nested
    plain dicts are recursively wrapped as HParams."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            if type(value) is dict:
                value = HParams(**value)
            self[key] = value

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return repr(self.__dict__)
import re
from pypinyin import Style
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin
from text import pinyin_dict
from bert import TTSProsody
class MyConverter(NeutralToneWith5Mixin, DefaultConverter):
    # pypinyin converter combining the default conversion with
    # NeutralToneWith5Mixin, which (per pypinyin) represents the neutral
    # tone with an explicit "5" rather than omitting the tone number.
    pass
def is_chinese(uchar):
    """Return True when *uchar* lies in the basic CJK ideograph range
    (U+4E00 .. U+9FA5), False otherwise."""
    return u'\u4e00' <= uchar <= u'\u9fa5'
def clean_chinese(text: str):
    """Keep only Chinese characters from *text*, collapsing each run of
    non-Chinese text between them into a single ',' separator.

    Leading/trailing separators are stripped from the result.
    """
    kept = []
    for ch in text.strip():
        if is_chinese(ch):
            kept.append(ch)
        elif len(kept) > 1 and is_chinese(kept[-1]):
            # A non-Chinese run just ended: mark the break with one comma.
            # NOTE(review): the `> 1` threshold means no separator is ever
            # inserted right after the first kept character -- preserved as-is.
            kept.append(',')
    return ''.join(kept).strip(',')
class VITS_PinYin:
    """Frontend turning Chinese text into phoneme strings plus the matching
    BERT character embeddings, expanded to phoneme length."""

    def __init__(self, bert_path, device):
        self.pinyin_parser = Pinyin(MyConverter())
        self.prosody = TTSProsody(bert_path, device)

    def get_phoneme4pinyin(self, pinyins):
        """Map each TONE3-style pinyin (e.g. "ni3") to its initial/final pair.

        Returns the flat phoneme list and, per converted syllable, how many
        phonemes it contributed (always 2).  Syllables whose toneless form is
        not in ``pinyin_dict`` are silently dropped.
        """
        phones = []
        counts = []
        for syllable in pinyins:
            base = syllable[:-1]
            if base not in pinyin_dict:
                continue
            tone = syllable[-1]
            initial, final = pinyin_dict[base]
            phones.append(initial)
            phones.append(final + tone)
            counts.append(2)
        return phones, counts

    def chinese_to_phonemes(self, text):
        """Convert *text* to a space-joined phoneme string and the
        phone-expanded BERT char embeddings.

        Segments between commas get an "sp" (short pause) phoneme; the whole
        utterance is wrapped in "sil" markers, mirrored by '[PAD]' characters
        in the BERT input.
        """
        cleaned = clean_chinese(text)
        phonemes = ["sil"]
        chars = ['[PAD]']
        count_phone = [1]
        for segment in cleaned.split(","):
            if not segment:
                continue
            pinyins = self.correct_pinyin_tone3(segment)
            seg_phones, seg_counts = self.get_phoneme4pinyin(pinyins)
            phonemes += seg_phones + ["sp"]
            count_phone += seg_counts + [1]
            chars.append(segment)
            chars.append(',')
        phonemes.append("sil")
        count_phone.append(1)
        chars.append('[PAD]')
        char_seq = "".join(chars)
        embeds = self.prosody.get_char_embeds(char_seq)
        embeds = self.prosody.expand_for_phone(embeds, count_phone)
        return " ".join(phonemes), embeds

    def correct_pinyin_tone3(self, text):
        """TONE3 pinyin for *text* with the third-tone sandhi rule applied:
        when two third tones are adjacent, the first becomes a second tone."""
        pinyin_list = [
            item[0]
            for item in self.pinyin_parser.pinyin(
                text, style=Style.TONE3, strict=False, neutral_tone_with_five=True)
        ]
        if len(pinyin_list) >= 2:
            for i in range(1, len(pinyin_list)):
                try:
                    prev_tone = re.findall(r'\d', pinyin_list[i - 1])[0]
                    cur_tone = re.findall(r'\d', pinyin_list[i])[0]
                    if prev_tone == '3' and cur_tone == '3':
                        pinyin_list[i - 1] = pinyin_list[i - 1].replace('3', '2')
                except IndexError:
                    # Syllable without a digit (no tone mark) -- leave untouched.
                    pass
        return pinyin_list
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment