# Bert-VITS2-RKNN2 / rknn_run.py
# Uploaded by happyme531 ("Upload 13 files", commit 6cec077, verified).
import numpy as np
import onnxruntime as ort
from rknnlite.api.rknn_lite import RKNNLite
import numpy as np
import soundfile as sf
from transformers import AutoTokenizer
import time
import os
import re
import cn2an
from pypinyin import lazy_pinyin, Style
from typing import List
from typing import Tuple
import jieba
import jieba.posseg as psg
def convert_pad_shape(pad_shape):
    """Flatten a nested per-dimension pad specification, last dimension first.

    Example: ``[[1, 2], [3, 4]]`` becomes ``[3, 4, 1, 2]``.
    """
    return [amount for dim_pad in reversed(pad_shape) for amount in dim_pad]
def sequence_mask(length, max_length=None):
    """Return a boolean mask where row ``k`` is True for positions ``< length[k]``.

    Args:
        length: 1-D array of lengths.
        max_length: number of columns; defaults to ``length.max()``.

    Returns:
        Boolean array of shape ``(len(length), max_length)``.
    """
    limit = length.max() if max_length is None else max_length
    positions = np.arange(limit, dtype=length.dtype)
    return positions[np.newaxis, :] < length[:, np.newaxis]
def generate_path(duration, mask):
    """Expand per-token durations into a hard alignment path.

    Args:
        duration: [b, 1, t_x] per-token durations.
        mask: [b, 1, t_y, t_x]; only its shape is used.

    Returns:
        Boolean array of shape [b, 1, t_y, t_x]. Entry (y, x) is True when output
        frame y falls inside token x's span, i.e. between the cumulative duration
        of the tokens before x and the cumulative duration up to x.
    """
    batch, _, t_y, t_x = mask.shape
    # Cumulative end frame of every token, flattened across the batch.
    span_ends = np.cumsum(duration, -1).reshape(batch * t_x)
    covered = sequence_mask(span_ends, t_y).reshape(batch, t_x, t_y)
    # Shift by one token so XOR keeps only the frames that belong to each token.
    previous = np.pad(covered, ((0, 0), (1, 0), (0, 0)))[:, :-1]
    path = covered ^ previous
    return path[:, np.newaxis].transpose(0, 1, 3, 2)
class InferenceSession:
    """Bert-VITS2 synthesizer split across ONNX Runtime and RKNNLite.

    The ``enc``, ``emb_g`` (fed the ``sid`` input), ``dp``, ``sdp`` and ``flow``
    stages run on onnxruntime. The ``dec`` stage runs through RKNNLite.
    """

    def __init__(self, path, Providers=None):
        """Load every sub-model.

        Args:
            path: mapping with the model file for each stage: ``"enc"``,
                ``"emb_g"``, ``"dp"``, ``"sdp"``, ``"flow"`` (ONNX files) and
                ``"dec"`` (an RKNN file).
            Providers: onnxruntime execution providers for the ONNX stages.
                ``None`` (the default) means ``["CPUExecutionProvider"]``.
        """
        # Avoid a shared mutable default argument.
        providers = (
            ["CPUExecutionProvider"] if Providers is None else list(Providers)
        )
        ort_config = ort.SessionOptions()
        ort_config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        ort_config.intra_op_num_threads = 4
        ort_config.inter_op_num_threads = 4

        def load_onnx(key):
            return ort.InferenceSession(
                path[key], providers=providers, sess_options=ort_config
            )

        self.enc = load_onnx("enc")
        self.emb_g = load_onnx("emb_g")
        self.dp = load_onnx("dp")
        self.sdp = load_onnx("sdp")
        # The flow model is faster on ONNX Runtime than on RKNN, so it stays on ONNX.
        # (RKNN alternative: RKNNLite + load_rknn(path["flow"]) +
        #  init_runtime(core_mask=RKNNLite.NPU_CORE_1).)
        self.flow = load_onnx("flow")
        self.dec = RKNNLite(verbose=False)
        self.dec.load_rknn(path["dec"])
        self.dec.init_runtime()

    @staticmethod
    def _timed(label, fn):
        """Call ``fn()``, print its wall-clock run time under ``label`` and return its result."""
        start_time = time.time()
        result = fn()
        elapsed = time.time() - start_time
        print(f"{label} 运行时间: {elapsed:.4f} 秒")
        return result

    def __call__(
        self,
        seq,
        tone,
        language,
        bert_zh,
        bert_jp,
        bert_en,
        vqidx,
        sid,
        seed=114514,
        seq_noise_scale=0.8,
        sdp_noise_scale=0.6,
        length_scale=1.0,
        sdp_ratio=0.0,
        rknn_pad_to=1024,
    ):
        """Run the full pipeline and return the decoder output.

        Args:
            seq, tone, language: id sequences, either 1-D or already batched (2-D).
            bert_zh, bert_jp, bert_en: feature arrays fed to the encoder as
                ``bert_0``/``bert_1``/``bert_2``.
            vqidx, sid: passed to the encoder; ``sid`` also feeds ``emb_g``.
            seed: seed for the sdp input noise and the latent noise.
            seq_noise_scale: scale of the noise added around ``m_p``.
            sdp_noise_scale: scale of the sdp input noise.
            length_scale: multiplier applied to the predicted durations.
            sdp_ratio: blend weight of the sdp vs. dp log-durations.
            rknn_pad_to: fixed frame count that flow/dec inputs are padded or
                truncated to.

        Returns:
            Decoder output sliced to ``actual_len * 512`` samples on the last axis.

        Raises:
            ValueError: if seq/tone/language are not 1-D or 2-D.
        """
        # Promote unbatched inputs to a batch of one.
        if seq.ndim == 1:
            seq = np.expand_dims(seq, 0)
        if tone.ndim == 1:
            tone = np.expand_dims(tone, 0)
        if language.ndim == 1:
            language = np.expand_dims(language, 0)
        # The previous `assert (a, b, c)` asserted a non-empty tuple and never fired.
        if not (seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2):
            raise ValueError("seq, tone and language must be 1-D or 2-D arrays")

        g = self._timed(
            "emb_g", lambda: self.emb_g.run(None, {"sid": sid.astype(np.int64)})[0]
        )
        g = np.expand_dims(g, -1)

        enc_rtn = self._timed(
            "enc",
            lambda: self.enc.run(
                None,
                {
                    "x": seq.astype(np.int64),
                    "t": tone.astype(np.int64),
                    "language": language.astype(np.int64),
                    "bert_0": bert_zh.astype(np.float32),
                    "bert_1": bert_jp.astype(np.float32),
                    "bert_2": bert_en.astype(np.float32),
                    "g": g.astype(np.float32),
                    # For version 2.3 models, remove the following two inputs.
                    "vqidx": vqidx.astype(np.int64),
                    "sid": sid.astype(np.int64),
                },
            ),
        )
        x, m_p, logs_p, x_mask = enc_rtn[:4]

        # Private RNG: yields the same draws as seeding NumPy's global RNG,
        # without mutating global state for the caller.
        rng = np.random.RandomState(seed)
        zinput = rng.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
        sdp_output = self._timed(
            "sdp",
            lambda: self.sdp.run(
                None,
                {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g},
            )[0],
        )
        dp_output = self._timed(
            "dp", lambda: self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[0]
        )

        # Blend both predictors, then turn log-durations into integer frame counts.
        logw = sdp_output * sdp_ratio + dp_output * (1 - sdp_ratio)
        w = np.exp(logw) * x_mask * length_scale
        w_ceil = np.ceil(w)
        y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
            np.int64
        )
        y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
        attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
        attn = generate_path(w_ceil, attn_mask)
        # [b, t', t] @ [b, t, d] -> [b, t', d] -> [b, d, t']
        m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(0, 2, 1)
        logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
            0, 2, 1
        )
        z_p = m_p + rng.randn(*m_p.shape) * np.exp(logs_p) * seq_noise_scale

        # flow/dec take a fixed length: truncate or zero-pad to rknn_pad_to.
        actual_len = z_p.shape[2]
        if actual_len > rknn_pad_to:
            print("警告, 输入长度超过 rknn_pad_to, 将被截断")
            z_p = z_p[:, :, :rknn_pad_to]
            y_mask = y_mask[:, :, :rknn_pad_to]
        else:
            z_p = np.pad(z_p, ((0, 0), (0, 0), (0, rknn_pad_to - z_p.shape[2])))
            y_mask = np.pad(
                y_mask, ((0, 0), (0, 0), (0, rknn_pad_to - y_mask.shape[2]))
            )

        z = self._timed(
            "flow",
            lambda: self.flow.run(
                None,
                {
                    "z_p": z_p.astype(np.float32),
                    "y_mask": y_mask.astype(np.float32),
                    "g": g,
                },
            )[0],
        )
        dec_output = self._timed(
            "dec", lambda: self.dec.inference([z.astype(np.float32), g])[0]
        )
        # Drop audio produced from padding frames: 512 output samples per frame
        # (NOTE(review): assumed decoder upsampling factor; confirm per model).
        samples_per_frame = 512
        return dec_output[:, :, : actual_len * samples_per_frame]
class ToneSandhi:
    """Mandarin tone-sandhi rules applied to jieba-segmented words.

    Finals are pypinyin FINALS_TONE3 strings whose last character is the tone
    digit ("5" marks the neutral tone).
    """

    def __init__(self):
        # Words whose final syllable is read with the neutral tone.
        self.must_neural_tone_words = {
            "麻烦", "麻利", "鸳鸯", "高粱", "骨头", "骆驼", "马虎", "首饰", "馒头", "馄饨",
            "风筝", "难为", "队伍", "阔气", "闺女", "门道", "锄头", "铺盖", "铃铛", "铁匠",
            "钥匙", "里脊", "里头", "部分", "那么", "道士", "造化", "迷糊", "连累", "这么",
            "这个", "运气", "过去", "软和", "转悠", "踏实", "跳蚤", "跟头", "趔趄", "财主",
            "豆腐", "讲究", "记性", "记号", "认识", "规矩", "见识", "裁缝", "补丁", "衣裳",
            "衣服", "衙门", "街坊", "行李", "行当", "蛤蟆", "蘑菇", "薄荷", "葫芦", "葡萄",
            "萝卜", "荸荠", "苗条", "苗头", "苍蝇", "芝麻", "舒服", "舒坦", "舌头", "自在",
            "膏药", "脾气", "脑袋", "脊梁", "能耐", "胳膊", "胭脂", "胡萝", "胡琴", "胡同",
            "聪明", "耽误", "耽搁", "耷拉", "耳朵", "老爷", "老实", "老婆", "老头", "老太",
            "翻腾", "罗嗦", "罐头", "编辑", "结实", "红火", "累赘", "糨糊", "糊涂", "精神",
            "粮食", "簸箕", "篱笆", "算计", "算盘", "答应", "笤帚", "笑语", "笑话", "窟窿",
            "窝囊", "窗户", "稳当", "稀罕", "称呼", "秧歌", "秀气", "秀才", "福气", "祖宗",
            "砚台", "码头", "石榴", "石头", "石匠", "知识", "眼睛", "眯缝", "眨巴", "眉毛",
            "相声", "盘算", "白净", "痢疾", "痛快", "疟疾", "疙瘩", "疏忽", "畜生", "生意",
            "甘蔗", "琵琶", "琢磨", "琉璃", "玻璃", "玫瑰", "玄乎", "狐狸", "状元", "特务",
            "牲口", "牙碜", "牌楼", "爽快", "爱人", "热闹", "烧饼", "烟筒", "烂糊", "点心",
            "炊帚", "灯笼", "火候", "漂亮", "滑溜", "溜达", "温和", "清楚", "消息", "浪头",
            "活泼", "比方", "正经", "欺负", "模糊", "槟榔", "棺材", "棒槌", "棉花", "核桃",
            "栅栏", "柴火", "架势", "枕头", "枇杷", "机灵", "本事", "木头", "木匠", "朋友",
            "月饼", "月亮", "暖和", "明白", "时候", "新鲜", "故事", "收拾", "收成", "提防",
            "挖苦", "挑剔", "指甲", "指头", "拾掇", "拳头", "拨弄", "招牌", "招呼", "抬举",
            "护士", "折腾", "扫帚", "打量", "打算", "打点", "打扮", "打听", "打发", "扎实",
            "扁担", "戒指", "懒得", "意识", "意思", "情形", "悟性", "怪物", "思量", "怎么",
            "念头", "念叨", "快活", "忙活", "志气", "心思", "得罪", "张罗", "弟兄", "开通",
            "应酬", "庄稼", "干事", "帮手", "帐篷", "希罕", "师父", "师傅", "巴结", "巴掌",
            "差事", "工夫", "岁数", "屁股", "尾巴", "少爷", "小气", "小伙", "将就", "对头",
            "对付", "寡妇", "家伙", "客气", "实在", "官司", "学问", "学生", "字号", "嫁妆",
            "媳妇", "媒人", "婆家", "娘家", "委屈", "姑娘", "姐夫", "妯娌", "妥当", "妖精",
            "奴才", "女婿", "头发", "太阳", "大爷", "大方", "大意", "大夫", "多少", "多么",
            "外甥", "壮实", "地道", "地方", "在乎", "困难", "嘴巴", "嘱咐", "嘟囔", "嘀咕",
            "喜欢", "喇嘛", "喇叭", "商量", "唾沫", "哑巴", "哈欠", "哆嗦", "咳嗽", "和尚",
            "告诉", "告示", "含糊", "吓唬", "后头", "名字", "名堂", "合同", "吆喝", "叫唤",
            "口袋", "厚道", "厉害", "千斤", "包袱", "包涵", "匀称", "勤快", "动静", "动弹",
            "功夫", "力气", "前头", "刺猬", "刺激", "别扭", "利落", "利索", "利害", "分析",
            "出息", "凑合", "凉快", "冷战", "冤枉", "冒失", "养活", "关系", "先生", "兄弟",
            "便宜", "使唤", "佩服", "作坊", "体面", "位置", "似的", "伙计", "休息", "什么",
            "人家", "亲戚", "亲家", "交情", "云彩", "事情", "买卖", "主意", "丫头", "丧气",
            "两口", "东西", "东家", "世故", "不由", "不在", "下水", "下巴", "上头", "上司",
            "丈夫", "丈人", "一辈", "那个", "菩萨", "父亲", "母亲", "咕噜", "邋遢", "费用",
            "冤家", "甜头", "介绍", "荒唐", "大人", "泥鳅", "幸福", "熟悉", "计划", "扑腾",
            "蜡烛", "姥爷", "照顾", "喉咙", "吉他", "弄堂", "蚂蚱", "凤凰", "拖沓", "寒碜",
            "糟蹋", "倒腾", "报复", "逻辑", "盘缠", "喽啰", "牢骚", "咖喱", "扫把", "惦记",
        }
        # Words that must NOT receive the neutral tone despite matching a rule.
        self.must_not_neural_tone_words = {
            "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子",
            "人人", "虎虎",
        }
        # Punctuation that blocks the "一" -> yi4 rule. Holds both full-width and
        # ASCII forms; the ASCII-only copy had collapsed into duplicate characters.
        self.punc = ":,;。?!“”‘’':,;.?!"

    # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
    # e.g.
    # word: "家里"
    # pos: "s"
    # finals: ['ia1', 'i3']
    def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
        """Apply neutral-tone ("5") rules to ``finals`` for one word and return them."""
        # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
        for j, item in enumerate(word):
            if (
                j - 1 >= 0
                and item == word[j - 1]
                and pos[0] in {"n", "v", "a"}
                and word not in self.must_not_neural_tone_words
            ):
                finals[j] = finals[j][:-1] + "5"
        ge_idx = word.find("个")
        if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
            finals[-1] = finals[-1][:-1] + "5"
        elif len(word) >= 1 and word[-1] in "的地得":
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 走了, 看着, 去过
        # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
        #     finals[-1] = finals[-1][:-1] + "5"
        elif (
            len(word) > 1
            and word[-1] in "们子"
            and pos in {"r", "n"}
            and word not in self.must_not_neural_tone_words
        ):
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 桌上, 地下, 家里
        elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 上来, 下去
        elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
            finals[-1] = finals[-1][:-1] + "5"
        # "个" used as a measure word
        elif (
            ge_idx >= 1
            and (
                word[ge_idx - 1].isnumeric()
                or word[ge_idx - 1] in "几有两半多各整每做是"
            )
        ) or word == "个":
            finals[ge_idx] = finals[ge_idx][:-1] + "5"
        else:
            if (
                word in self.must_neural_tone_words
                or word[-2:] in self.must_neural_tone_words
            ):
                finals[-1] = finals[-1][:-1] + "5"

        word_list = self._split_word(word)
        finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
        for i, sub_word in enumerate(word_list):
            # conventional neutral tone in Chinese
            if (
                sub_word in self.must_neural_tone_words
                or sub_word[-2:] in self.must_neural_tone_words
            ):
                finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
        finals = sum(finals_list, [])
        return finals

    def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
        """Apply the "不" rules: neutral in X不Y, tone 2 before a tone-4 syllable."""
        # e.g. 看不懂
        if len(word) == 3 and word[1] == "不":
            finals[1] = finals[1][:-1] + "5"
        else:
            for i, char in enumerate(word):
                # "不" before tone4 should be bu2, e.g. 不怕
                if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
                    finals[i] = finals[i][:-1] + "2"
        return finals

    def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
        """Apply the "一" rules (number sequences, X一X, ordinals, tone 2/4 changes)."""
        # "一" in number sequences, e.g. 一零零, 二一零
        if word.find("一") != -1 and all(
            [item.isnumeric() for item in word if item != "一"]
        ):
            return finals
        # "一" between reduplication words should be yi5, e.g. 看一看
        elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
            finals[1] = finals[1][:-1] + "5"
        # when "一" is ordinal word, it should be yi1
        elif word.startswith("第一"):
            finals[1] = finals[1][:-1] + "1"
        else:
            for i, char in enumerate(word):
                if char == "一" and i + 1 < len(word):
                    # "一" before tone4 should be yi2, e.g. 一段
                    if finals[i + 1][-1] == "4":
                        finals[i] = finals[i][:-1] + "2"
                    # "一" before non-tone4 should be yi4, e.g. 一天
                    else:
                        # If "一" is followed by punctuation, it stays tone 1.
                        if word[i + 1] not in self.punc:
                            finals[i] = finals[i][:-1] + "4"
        return finals

    def _split_word(self, word: str) -> List[str]:
        """Split ``word`` in two around its shortest jieba search-mode sub-word."""
        word_list = jieba.cut_for_search(word)
        word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
        first_subword = word_list[0]
        first_begin_idx = word.find(first_subword)
        if first_begin_idx == 0:
            second_subword = word[len(first_subword) :]
            new_word_list = [first_subword, second_subword]
        else:
            second_subword = word[: -len(first_subword)]
            new_word_list = [second_subword, first_subword]
        return new_word_list

    def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
        """Apply third-tone sandhi (3-3 -> 2-3) for 2-, 3- and 4-character words."""
        if len(word) == 2 and self._all_tone_three(finals):
            finals[0] = finals[0][:-1] + "2"
        elif len(word) == 3:
            word_list = self._split_word(word)
            if self._all_tone_three(finals):
                # disyllabic + monosyllabic, e.g. 蒙古/包
                if len(word_list[0]) == 2:
                    finals[0] = finals[0][:-1] + "2"
                    finals[1] = finals[1][:-1] + "2"
                # monosyllabic + disyllabic, e.g. 纸/老虎
                elif len(word_list[0]) == 1:
                    finals[1] = finals[1][:-1] + "2"
            else:
                finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
                if len(finals_list) == 2:
                    for i, sub in enumerate(finals_list):
                        # e.g. 所有/人
                        if self._all_tone_three(sub) and len(sub) == 2:
                            finals_list[i][0] = finals_list[i][0][:-1] + "2"
                        # e.g. 好/喜欢
                        elif (
                            i == 1
                            and not self._all_tone_three(sub)
                            and finals_list[i][0][-1] == "3"
                            and finals_list[0][-1][-1] == "3"
                        ):
                            finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
                    finals = sum(finals_list, [])
        # split idiom into two words whose length is 2
        elif len(word) == 4:
            finals_list = [finals[:2], finals[2:]]
            finals = []
            for sub in finals_list:
                if self._all_tone_three(sub):
                    sub[0] = sub[0][:-1] + "2"
                finals += sub

        return finals

    def _all_tone_three(self, finals: List[str]) -> bool:
        """True when every final carries tone 3 (vacuously True for an empty list)."""
        return all(x[-1] == "3" for x in finals)

    # merge "不" and the word behind it
    # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
    def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Attach each standalone "不" to the following word (keeps a trailing "不")."""
        new_seg = []
        last_word = ""
        for word, pos in seg:
            if last_word == "不":
                word = last_word + word
            if word != "不":
                new_seg.append((word, pos))
            last_word = word[:]
        if last_word == "不":
            new_seg.append((last_word, "d"))
            last_word = ""
        return new_seg

    # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
    # function 2: merge single "一" and the word behind it
    # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
    # e.g.
    # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
    # output seg: [['听一听', 'v']]
    def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Merge "X一X" verb patterns, then glue any lone "一" to the next word.

        Returns a list of ``[word, pos]`` lists.
        """
        new_seg = []
        # function 1
        # The right-hand X is skipped exactly when it was merged. This avoids both
        # indexing new_seg with seg positions (stale after earlier merges) and
        # keying the skip on a different POS test than the merge (which could
        # drop or duplicate words).
        skip_next = False
        for i, (word, pos) in enumerate(seg):
            if skip_next:
                skip_next = False
                continue
            if (
                i - 1 >= 0
                and word == "一"
                and i + 1 < len(seg)
                and seg[i - 1][0] == seg[i + 1][0]
                and seg[i - 1][1] == "v"
            ):
                new_seg[-1][0] = new_seg[-1][0] + "一" + seg[i + 1][0]
                skip_next = True
            else:
                new_seg.append([word, pos])
        seg = new_seg
        new_seg = []
        # function 2
        for word, pos in seg:
            if new_seg and new_seg[-1][0] == "一":
                new_seg[-1][0] = new_seg[-1][0] + word
            else:
                new_seg.append([word, pos])
        return new_seg

    # the first and the second words are all_tone_three
    def _merge_continuous_three_tones(
        self, seg: List[Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
        """Merge adjacent all-tone-3 words when the result has at most 3 characters."""
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)
        for i, (word, pos) in enumerate(seg):
            if (
                i - 1 >= 0
                and self._all_tone_three(sub_finals_list[i - 1])
                and self._all_tone_three(sub_finals_list[i])
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
                if (
                    not self._is_reduplication(seg[i - 1][0])
                    and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
                ):
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:
                    new_seg.append([word, pos])
            else:
                new_seg.append([word, pos])
        return new_seg

    def _is_reduplication(self, word: str) -> bool:
        """True for a two-character word made of the same character twice."""
        return len(word) == 2 and word[0] == word[1]

    # the last char of first word and the first char of second word is tone_three
    def _merge_continuous_three_tones_2(
        self, seg: List[Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
        """Merge words that meet on a tone-3 boundary when the result has at most 3 characters."""
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)
        for i, (word, pos) in enumerate(seg):
            if (
                i - 1 >= 0
                and sub_finals_list[i - 1][-1][-1] == "3"
                and sub_finals_list[i][0][-1] == "3"
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
                if (
                    not self._is_reduplication(seg[i - 1][0])
                    and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
                ):
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:
                    new_seg.append([word, pos])
            else:
                new_seg.append([word, pos])
        return new_seg

    def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Attach a standalone "儿" to the preceding word (unless that word is "#")."""
        new_seg = []
        for i, (word, pos) in enumerate(seg):
            if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
                new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
            else:
                new_seg.append([word, pos])
        return new_seg

    def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Merge a word into the previous entry when the two are identical."""
        new_seg = []
        for i, (word, pos) in enumerate(seg):
            if new_seg and word == new_seg[-1][0]:
                new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
            else:
                new_seg.append([word, pos])
        return new_seg

    def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Run all segment-merging passes before tone modification."""
        seg = self._merge_bu(seg)
        try:
            seg = self._merge_yi(seg)
        except Exception:
            # Best-effort pass: keep the unmerged segments (no longer a bare except,
            # so KeyboardInterrupt/SystemExit still propagate).
            print("_merge_yi failed")
        seg = self._merge_reduplication(seg)
        seg = self._merge_continuous_three_tones(seg)
        seg = self._merge_continuous_three_tones_2(seg)
        seg = self._merge_er(seg)
        return seg

    def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
        """Apply 不, 一, neutral-tone and third-tone sandhi, in that order."""
        finals = self._bu_sandhi(word, finals)
        finals = self._yi_sandhi(word, finals)
        finals = self._neural_sandhi(word, pos, finals)
        finals = self._three_sandhi(word, finals)
        return finals
# Punctuation kept by the text front end, plus the silence / unknown markers.
punctuation = ["!", "?", "…", ",", ".", "'", "-"]
pu_symbols = punctuation + ["SP", "UNK"]
pad = "_"

# Chinese phone inventory.
zh_symbols = (
    "E En a ai an ang ao b c ch d e ei en eng er f g h i i0 ia ian iang iao ie "
    "in ing iong ir iu j k l m n o ong ou p q r s sh t u ua uai uan uang ui un "
    "uo v van ve vn w x y z zh AA EE OO"
).split()
num_zh_tones = 6

# Japanese phone inventory.
ja_symbols = (
    "N a a: b by ch d dy e e: f g gy h hy i i: j k ky m my n ny o o: p py q r "
    "ry s sh t ts ty u u: w y z zy"
).split()
num_ja_tones = 2

# English phone inventory.
en_symbols = (
    "aa ae ah ao aw ay b ch d dh eh er ey f g hh ih iy jh k l m n ng ow oy p r "
    "s sh t th uh uw V w y z zh"
).split()
num_en_tones = 4

# Full symbol table: pad first, then the de-duplicated sorted phones, then punctuation.
normal_symbols = sorted({*zh_symbols, *ja_symbols, *en_symbols})
symbols = [pad, *normal_symbols, *pu_symbols]
sil_phonemes_ids = [symbols.index(p) for p in pu_symbols]

# Tones of all languages share one id space; each language gets an offset.
num_tones = sum((num_zh_tones, num_ja_tones, num_en_tones))

language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
num_languages = len(language_id_map)
language_tone_start_map = {
    "ZH": 0,
    "JP": num_zh_tones,
    "EN": num_zh_tones + num_ja_tones,
}
current_file_path = os.path.dirname(__file__)


def _load_pinyin_to_symbol_map(table_path):
    """Read the tab-separated pinyin -> phone-symbols table at ``table_path``.

    Each non-blank line is ``<pinyin>\\t<symbols>``. Later duplicate keys win.
    The file is closed deterministically and decoded as UTF-8 regardless of
    the platform locale. Blank lines are skipped instead of raising IndexError.
    """
    mapping = {}
    with open(table_path, encoding="utf-8") as table:
        for line in table:
            if not line.strip():
                continue
            mapping[line.split("\t")[0]] = line.strip().split("\t")[1]
    return mapping


pinyin_to_symbol_map = _load_pinyin_to_symbol_map(
    os.path.join(current_file_path, "opencpop-strict.txt")
)
# Punctuation normalization table used by replace_punctuation(): every key is
# rewritten to one of the symbols in `punctuation`. Anything not mapped (and
# not a CJK character or kept punctuation) is later deleted. Both ASCII and
# full-width forms are listed: the previous table had duplicate ASCII keys
# where the full-width variants belonged, so full-width Chinese punctuation was
# silently stripped instead of becoming a pause symbol.
rep_map = {
    ":": ",",
    ":": ",",
    ";": ",",
    ";": ",",
    ",": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "!": "!",
    "?": "?",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    '"': "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}
# Shared ToneSandhi instance used by _g2p() to merge segments and adjust tones.
tone_modifier = ToneSandhi()
def replace_punctuation(text):
    """Normalize punctuation via ``rep_map`` and drop every other non-CJK character.

    After this, only CJK ideographs (U+4E00..U+9FA5) and the symbols in
    ``punctuation`` remain.
    """
    for src, dst in (("嗯", "恩"), ("呣", "母")):
        text = text.replace(src, dst)
    rep_pattern = re.compile("|".join(map(re.escape, rep_map)))
    text = rep_pattern.sub(lambda match: rep_map[match.group()], text)
    disallowed = r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+"
    return re.sub(disallowed, "", text)
def g2p(text):
    """Convert normalized Chinese text to (phones, tones, word2ph).

    The text is split into sentences after each punctuation symbol. All three
    outputs are wrapped with a boundary entry ("_" / tone 0 / count 1) on both
    ends.
    """
    split_after_punct = r"(?<=[{0}])\s*".format("".join(punctuation))
    sentences = [
        piece for piece in re.split(split_after_punct, text) if piece.strip() != ""
    ]
    phones, tones, word2ph = _g2p(sentences)
    assert sum(word2ph) == len(phones)
    # This can fail for some inputs (one word2ph entry per character expected);
    # callers may wrap g2p in try/except.
    assert len(word2ph) == len(text)
    return ["_", *phones, "_"], [0, *tones, 0], [1, *word2ph, 1]
def _get_initials_finals(word):
    """Return per-character pinyin initials and tone-numbered finals for ``word``."""
    raw_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
    raw_finals = lazy_pinyin(
        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
    )
    paired = list(zip(raw_initials, raw_finals))
    return [ini for ini, _ in paired], [fin for _, fin in paired]
# Final spellings to rewrite when the syllable has an initial.
_FINAL_WITH_INITIAL_REP = {"uei": "ui", "iou": "iu", "uen": "un"}
# Whole-syllable spellings for zero-initial syllables.
_ZERO_INITIAL_SYLLABLE_REP = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"}
# Leading-letter rewrites for the remaining zero-initial syllables.
_ZERO_INITIAL_PREFIX_REP = {"v": "yu", "e": "e", "i": "y", "u": "w"}


def _pinyin_table_key(initial, final_without_tone):
    """Return the ``pinyin_to_symbol_map`` key for one pypinyin initial/final pair."""
    if initial:
        return initial + _FINAL_WITH_INITIAL_REP.get(
            final_without_tone, final_without_tone
        )
    if final_without_tone in _ZERO_INITIAL_SYLLABLE_REP:
        return _ZERO_INITIAL_SYLLABLE_REP[final_without_tone]
    if final_without_tone[0] in _ZERO_INITIAL_PREFIX_REP:
        return _ZERO_INITIAL_PREFIX_REP[final_without_tone[0]] + final_without_tone[1:]
    return final_without_tone


def _g2p(segments):
    """Convert sentence segments to flat phone, tone and per-character phone-count lists.

    Latin letters are removed first. Each syllable maps to its phones via
    ``pinyin_to_symbol_map``, and each punctuation character maps to itself
    with tone 0.
    """
    phones_list, tones_list, word2ph = [], [], []
    for seg in segments:
        # Drop all English words from the sentence.
        seg = re.sub("[a-zA-Z]+", "", seg)
        seg_cut = tone_modifier.pre_merge_for_modify(psg.lcut(seg))
        initials, finals = [], []
        for word, pos in seg_cut:
            if pos == "eng":
                continue
            sub_initials, sub_finals = _get_initials_finals(word)
            initials.extend(sub_initials)
            finals.extend(tone_modifier.modified_tone(word, pos, sub_finals))

        for c, v in zip(initials, finals):
            raw_pinyin = c + v
            if c == v:
                # pypinyin passes punctuation through unchanged in both styles.
                assert c in punctuation
                phone, tone = [c], "0"
            else:
                tone = v[-1]
                assert tone in "12345"
                pinyin = _pinyin_table_key(c, v[:-1])
                assert pinyin in pinyin_to_symbol_map, (pinyin, seg, raw_pinyin)
                phone = pinyin_to_symbol_map[pinyin].split(" ")
            word2ph.append(len(phone))
            phones_list += phone
            tones_list += [int(tone)] * len(phone)
    return phones_list, tones_list, word2ph
def text_normalize(text):
    """Spell out Arabic numbers in Chinese (via cn2an), then normalize punctuation."""
    spelled = re.sub(
        r"\d+(?:\.?\d+)?", lambda match: cn2an.an2cn(match.group()), text
    )
    return replace_punctuation(spelled)
def get_bert_feature(
    text,
    word2ph,
    style_text=None,
    style_weight=0.7,
):
    """Return phone-level BERT features for ``text``.

    Runs the global RKNN ``bert_model`` on the output of the global ``tokenizer``
    (padded/truncated to 256 tokens). Row ``i`` of the hidden states is then
    repeated ``word2ph[i]`` times.

    Args:
        text: normalized text. NOTE(review): assumes the tokenizer yields one
            token per character plus [CLS]/[SEP]; confirm for other BERT models.
        word2ph: phone count per token, ``len(text) + 2`` entries.
        style_text: unsupported with the RKNN backend; must be falsy.
        style_weight: kept for interface compatibility; unused.

    Returns:
        Array of shape ``(hidden_size, sum(word2ph))``.

    Raises:
        NotImplementedError: if ``style_text`` is given.
        ValueError: if ``word2ph`` does not match ``text`` or exceeds the
            256-token window.
    """
    # Validate before the expensive NPU call. The old `assert False` TODO was
    # stripped under -O and fell through to torch-style code that is wrong for NumPy.
    if style_text:
        raise NotImplementedError("style_text is not supported by the RKNN BERT backend")
    if len(word2ph) != len(text) + 2:
        raise ValueError(
            f"word2ph has {len(word2ph)} entries, expected len(text) + 2 = {len(text) + 2}"
        )

    inputs = tokenizer(
        text, return_tensors="np", padding="max_length", truncation=True, max_length=256
    )
    start_time = time.time()
    res = bert_model.inference(
        [inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"]]
    )
    bert_time = time.time() - start_time
    print(f"bert 运行时间: {bert_time:.4f} 秒")

    hidden = res[0][0]
    if len(word2ph) > hidden.shape[0]:
        raise ValueError(
            f"text too long for BERT window: {len(word2ph)} tokens > {hidden.shape[0]}"
        )
    # Equivalent to concatenating np.tile(hidden[i], (word2ph[i], 1)) over i.
    phone_level_feature = np.repeat(hidden[: len(word2ph)], word2ph, axis=0)
    return phone_level_feature.T
def clean_text(text, language):
    """Normalize ``text`` and convert it to phones.

    ``language`` is accepted for interface compatibility but is not used; the
    Chinese front end is always applied.

    Returns:
        (norm_text, phones, tones, word2ph)
    """
    normalized = text_normalize(text)
    return (normalized, *g2p(normalized))
def clean_text_bert(text, language):
    """Normalize ``text`` and return (phones, tones, bert_features).

    ``language`` is accepted for interface compatibility but is not used.
    """
    normalized = text_normalize(text)
    phones, tones, word2ph = g2p(normalized)
    features = get_bert_feature(normalized, word2ph)
    return phones, tones, features
# Phone symbol -> integer id (its position in `symbols`).
_symbol_to_id = dict(zip(symbols, range(len(symbols))))
def cleaned_text_to_sequence(cleaned_text, tones, language):
    """Convert cleaned phone symbols and tones into model input id sequences.

    Args:
        cleaned_text: sequence of phone symbols (keys of ``_symbol_to_id``).
        tones: per-phone tone numbers, local to ``language``.
        language: language code, a key of ``language_tone_start_map`` and
            ``language_id_map`` ("ZH", "JP" or "EN").

    Returns:
        (phone_ids, tone_ids, language_ids): three lists with one entry per
        phone. Tone ids are shifted by the language's tone offset.

    Raises:
        KeyError: for an unknown phone symbol or language code.
    """
    phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
    tone_start = language_tone_start_map[language]
    tones = [t + tone_start for t in tones]
    lang_ids = [language_id_map[language]] * len(phones)
    return phones, tones, lang_ids
def text_to_sequence(text, language):
    """Normalize ``text`` and return (phone_ids, tone_ids, language_ids)."""
    _, phones, tones, _ = clean_text(text, language)
    return cleaned_text_to_sequence(phones, tones, language)
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``.

    Example: ``intersperse([1, 2], 0) == [0, 1, 0, 2, 0]``.
    """
    result = [item]
    for element in lst:
        result.extend((element, item))
    return result
def get_text(text, language_str, style_text=None, style_weight=0.7, add_blank=False):
    """Build encoder inputs for ``text``.

    NOTE(review): ``clean_text`` always runs the Chinese front end, so "JP"/"EN"
    only change the tone/language ids and which BERT slot is filled.

    Args:
        text: raw input text.
        language_str: "ZH", "JP" or "EN".
        style_text, style_weight: forwarded to ``get_bert_feature``.
        add_blank: interleave id 0 between all phones/tones/language ids.

    Returns:
        (bert, ja_bert, en_bert, phone, tone, language). The three BERT arrays
        have shape ``(1024, n_phones)``; the slot for ``language_str`` holds the
        real features and the others are zeros. The rest are 1-D id arrays.

    Raises:
        ValueError: if ``language_str`` is not ZH, JP or EN. This is now checked
            up front; previously an unknown code failed earlier with KeyError,
            so this error was unreachable.
    """
    if language_str not in language_id_map:
        raise ValueError("language_str should be ZH, JP or EN")

    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
    if add_blank:
        phone = intersperse(phone, 0)
        tone = intersperse(tone, 0)
        language = intersperse(language, 0)
        # Each character now also owns the blank after it; the leading blank
        # is attributed to the first entry.
        word2ph = [count * 2 for count in word2ph]
        word2ph[0] += 1
    bert_ori = get_bert_feature(norm_text, word2ph, style_text, style_weight)
    assert bert_ori.shape[-1] == len(phone), phone

    n_phones = len(phone)

    def _slot(lang):
        # The encoder takes one feature slot per language; unused slots are zeros.
        return bert_ori if lang == language_str else np.zeros((1024, n_phones))

    bert, ja_bert, en_bert = _slot("ZH"), _slot("JP"), _slot("EN")
    return bert, ja_bert, en_bert, np.array(phone), np.array(tone), np.array(language)
if __name__ == "__main__":
    # Demo: synthesize `text` with speaker 0 and write it to output.wav.
    name = "lx"
    model_prefix = f"onnx/{name}/{name}_"
    bert_path = "./bert/chinese-roberta-wwm-ext-large"
    # flow/dec inputs are padded (or truncated) to this many frames.
    flow_dec_input_len = 1024
    model_sample_rate = 44100
    # text = "不必说碧绿的菜畦,光滑的石井栏,高大的皂荚树,紫红的桑葚;也不必说鸣蝉在树叶里长吟,肥胖的黄蜂伏在菜花上,轻捷的叫天子(云雀)忽然从草间直窜向云霄里去了。单是周围的短短的泥墙根一带,就有无限趣味。油蛉在这里低唱, 蟋蟀们在这里弹琴。翻开断砖来,有时会遇见蜈蚣;还有斑蝥,倘若用手指按住它的脊梁,便会“啪”的一声,从后窍喷出一阵烟雾。何首乌藤和木莲藤缠络着,木莲有莲房一般的果实,何首乌有臃肿的根。有人说,何首乌根是有像人形的,吃了便可以成仙,我于是常常拔它起来,牵连不断地拔起来,也曾因此弄坏了泥墙,却从来没有见过有一块根像人样。如果不怕刺,还可以摘到覆盆子,像小珊瑚珠攒成的小球,又酸又甜,色味都比桑葚要好得远。"
    text = "我个人认为,这个意大利面就应该拌42号混凝土,因为这个螺丝钉的长度,它很容易会直接影响到挖掘机的扭矩你知道吧。你往里砸的时候,一瞬间它就会产生大量的高能蛋白,俗称ufo,会严重影响经济的发展,甚至对整个太平洋以及充电器都会造成一定的核污染。你知道啊?再者说,根据这个勾股定理,你可以很容易地推断出人工饲养的东条英机,它是可以捕获野生的三角函数的。所以说这个秦始皇的切面是否具有放射性啊,特朗普的N次方是否含有沉淀物,都不影响这个沃尔玛跟维尔康在南极会合。"

    # These are module globals read by get_bert_feature(). A `global` statement
    # at module scope is a no-op, so none is needed.
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    bert_model = RKNNLite(verbose=False)
    bert_model.load_rknn(bert_path + "/model.rknn")
    bert_model.init_runtime()
    model = InferenceSession(
        {
            "enc": model_prefix + "enc_p.onnx",
            "emb_g": model_prefix + "emb.onnx",
            "dp": model_prefix + "dp.onnx",
            "sdp": model_prefix + "sdp.onnx",
            "flow": model_prefix + "flow.onnx",
            "dec": model_prefix + "dec.rknn",
        }
    )

    # Split after sentence-final punctuation (ASCII and full-width forms).
    text_seg = re.split(r"(?<=[。!?;!?;])", text)
    # Keep the original single leading zero sample; concatenate once at the end
    # instead of re-copying the growing buffer every sentence.
    chunks = [np.array([0.0])]
    total_samples = 1
    for sentence in text_seg:
        # re.split leaves an empty piece after a trailing delimiter; without this
        # check it would still run the whole pipeline on an empty string.
        if not sentence.strip():
            continue
        bert, ja_bert, en_bert, phone, tone, language = get_text(
            sentence, "ZH", add_blank=True
        )
        bert = np.transpose(bert)
        ja_bert = np.transpose(ja_bert)
        en_bert = np.transpose(en_bert)
        sid = np.array([0])
        vqidx = np.array([0])
        output = model(
            phone,
            tone,
            language,
            bert,
            ja_bert,
            en_bert,
            vqidx,
            sid,
            rknn_pad_to=flow_dec_input_len,
            seed=114514,
            seq_noise_scale=0.8,
            sdp_noise_scale=0.6,
            length_scale=1,
            sdp_ratio=0,
        )[0, 0]
        chunks.append(output)
        total_samples += len(output)
        print(f"已生成长度: {total_samples / model_sample_rate:.2f} 秒")
    output_acc = np.concatenate(chunks)
    sf.write("output.wav", output_acc, model_sample_rate)
    print("已生成output.wav")