|
import os
import re
import time
from typing import List, Tuple

import cn2an
import jieba
import jieba.posseg as psg
import numpy as np
import onnxruntime as ort
import soundfile as sf
from pypinyin import lazy_pinyin, Style
from rknnlite.api.rknn_lite import RKNNLite
from transformers import AutoTokenizer
|
|
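# convert_pad_shape flattens a nested pad spec in reverse order, matching the
# last-dim-first layout used by torch.nn.functional.pad, e.g.
# convert_pad_shape([[1, 2], [3, 4]]) -> [3, 4, 1, 2].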
|
def convert_pad_shape(pad_shape): |
|
layer = pad_shape[::-1] |
|
pad_shape = [item for sublist in layer for item in sublist] |
|
return pad_shape |
|
|
|
|
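# sequence_mask builds a boolean [len(length), max_length] mask that is True
# where the frame index is below the sequence length, e.g.
# sequence_mask(np.array([2, 3]), 4) ->
# [[True, True, False, False], [True, True, True, False]].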
|
def sequence_mask(length, max_length=None): |
|
if max_length is None: |
|
max_length = length.max() |
|
x = np.arange(max_length, dtype=length.dtype) |
|
return np.expand_dims(x, 0) < np.expand_dims(length, 1) |
|
|
|
|
|
def generate_path(duration, mask): |
|
""" |
|
duration: [b, 1, t_x] |
|
mask: [b, 1, t_y, t_x] |
|
""" |
|
|
|
b, _, t_y, t_x = mask.shape |
|
cum_duration = np.cumsum(duration, -1) |
|
|
|
cum_duration_flat = cum_duration.reshape(b * t_x) |
|
path = sequence_mask(cum_duration_flat, t_y) |
|
path = path.reshape(b, t_x, t_y) |
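    # XOR with a one-step shift along t_x converts the cumulative masks into
    # disjoint per-phone segments, i.e. a hard monotonic alignment path.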
|
path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1] |
|
path = np.expand_dims(path, 1).transpose(0, 1, 3, 2) |
|
return path |
|
|
|
|
|
class InferenceSession: |
|
    def __init__(self, path, providers=("CPUExecutionProvider",)):
        ort_config = ort.SessionOptions()
        ort_config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        ort_config.intra_op_num_threads = 4
        ort_config.inter_op_num_threads = 4
        providers = list(providers)
        self.enc = ort.InferenceSession(path["enc"], providers=providers, sess_options=ort_config)
        self.emb_g = ort.InferenceSession(path["emb_g"], providers=providers, sess_options=ort_config)
        self.dp = ort.InferenceSession(path["dp"], providers=providers, sess_options=ort_config)
        self.sdp = ort.InferenceSession(path["sdp"], providers=providers, sess_options=ort_config)
        self.flow = ort.InferenceSession(path["flow"], providers=providers, sess_options=ort_config)
        self.dec = RKNNLite(verbose=False)
        self.dec.load_rknn(path["dec"])
        self.dec.init_runtime()
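        # The light text-side modules run on CPU through onnxruntime, while the
        # compute-heavy decoder runs on the NPU through RKNNLite; the RKNN graph
        # has a fixed input length, which __call__ handles via rknn_pad_to.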
|
|
|
|
|
def __call__( |
|
self, |
|
seq, |
|
tone, |
|
language, |
|
bert_zh, |
|
bert_jp, |
|
bert_en, |
|
vqidx, |
|
sid, |
|
seed=114514, |
|
seq_noise_scale=0.8, |
|
sdp_noise_scale=0.6, |
|
length_scale=1.0, |
|
sdp_ratio=0.0, |
|
        rknn_pad_to=1024
|
): |
|
if seq.ndim == 1: |
|
seq = np.expand_dims(seq, 0) |
|
if tone.ndim == 1: |
|
tone = np.expand_dims(tone, 0) |
|
if language.ndim == 1: |
|
language = np.expand_dims(language, 0) |
|
        assert seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2
|
|
|
start_time = time.time() |
|
g = self.emb_g.run( |
|
None, |
|
{ |
|
"sid": sid.astype(np.int64), |
|
}, |
|
)[0] |
|
emb_g_time = time.time() - start_time |
|
print(f"emb_g 运行时间: {emb_g_time:.4f} 秒") |
|
|
|
g = np.expand_dims(g, -1) |
|
start_time = time.time() |
|
enc_rtn = self.enc.run( |
|
None, |
|
{ |
|
"x": seq.astype(np.int64), |
|
"t": tone.astype(np.int64), |
|
"language": language.astype(np.int64), |
|
"bert_0": bert_zh.astype(np.float32), |
|
"bert_1": bert_jp.astype(np.float32), |
|
"bert_2": bert_en.astype(np.float32), |
|
"g": g.astype(np.float32), |
|
|
|
"vqidx": vqidx.astype(np.int64), |
|
"sid": sid.astype(np.int64), |
|
}, |
|
) |
|
enc_time = time.time() - start_time |
|
print(f"enc 运行时间: {enc_time:.4f} 秒") |
|
|
|
x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3] |
|
np.random.seed(seed) |
|
zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale |
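        # Noise input consumed by the stochastic duration predictor (sdp).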
|
|
|
start_time = time.time() |
|
sdp_output = self.sdp.run( |
|
None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g} |
|
)[0] |
|
sdp_time = time.time() - start_time |
|
print(f"sdp 运行时间: {sdp_time:.4f} 秒") |
|
|
|
start_time = time.time() |
|
dp_output = self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[0] |
|
dp_time = time.time() - start_time |
|
print(f"dp 运行时间: {dp_time:.4f} 秒") |
|
|
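        # Blend the stochastic (sdp) and deterministic (dp) log-durations;
        # sdp_ratio=0.0 uses only the deterministic predictor.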
|
logw = sdp_output * (sdp_ratio) + dp_output * (1 - sdp_ratio) |
|
w = np.exp(logw) * x_mask * length_scale |
|
w_ceil = np.ceil(w) |
|
y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype( |
|
np.int64 |
|
) |
|
y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1) |
|
attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1) |
|
attn = generate_path(w_ceil, attn_mask) |
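        # attn is a [b, 1, t_y, t_x] hard alignment; the matmuls below expand
        # the phone-level statistics m_p / logs_p to frame level.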
|
m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose( |
|
0, 2, 1 |
|
) |
|
logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose( |
|
0, 2, 1 |
|
) |
|
|
|
z_p = ( |
|
m_p |
|
+ np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2]) |
|
* np.exp(logs_p) |
|
* seq_noise_scale |
|
) |
|
|
|
actual_len = z_p.shape[2] |
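        # The RKNN flow/dec graphs were compiled with a static input length, so
        # the latent sequence must be padded (or, with a warning, truncated) to
        # exactly rknn_pad_to frames.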
|
        if actual_len > rknn_pad_to:
            print("Warning: input length exceeds rknn_pad_to; it will be truncated")
            z_p = z_p[:, :, :rknn_pad_to]
            y_mask = y_mask[:, :, :rknn_pad_to]
        else:
            z_p = np.pad(z_p, ((0, 0), (0, 0), (0, rknn_pad_to - z_p.shape[2])))
            y_mask = np.pad(y_mask, ((0, 0), (0, 0), (0, rknn_pad_to - y_mask.shape[2])))
|
|
|
start_time = time.time() |
|
z = self.flow.run( |
|
None, |
|
{ |
|
"z_p": z_p.astype(np.float32), |
|
"y_mask": y_mask.astype(np.float32), |
|
"g": g, |
|
}, |
|
)[0] |
|
flow_time = time.time() - start_time |
|
print(f"flow 运行时间: {flow_time:.4f} 秒") |
|
|
|
start_time = time.time() |
|
dec_output = self.dec.inference([z.astype(np.float32), g])[0] |
|
dec_time = time.time() - start_time |
|
print(f"dec 运行时间: {dec_time:.4f} 秒") |
|
|
        return dec_output[:, :, : actual_len * 512]
|
|
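# ToneSandhi applies Mandarin tone-sandhi rules on pypinyin finals, where the
# trailing digit of each final is the tone and "5" marks the neutral tone:
# neutral-tone words, "不" and "一" sandhi, third-tone sandhi, and the
# word-merging heuristics they rely on.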
class ToneSandhi: |
|
def __init__(self): |
|
self.must_neural_tone_words = { |
|
"麻烦", |
|
"麻利", |
|
"鸳鸯", |
|
"高粱", |
|
"骨头", |
|
"骆驼", |
|
"马虎", |
|
"首饰", |
|
"馒头", |
|
"馄饨", |
|
"风筝", |
|
"难为", |
|
"队伍", |
|
"阔气", |
|
"闺女", |
|
"门道", |
|
"锄头", |
|
"铺盖", |
|
"铃铛", |
|
"铁匠", |
|
"钥匙", |
|
"里脊", |
|
"里头", |
|
"部分", |
|
"那么", |
|
"道士", |
|
"造化", |
|
"迷糊", |
|
"连累", |
|
"这么", |
|
"这个", |
|
"运气", |
|
"过去", |
|
"软和", |
|
"转悠", |
|
"踏实", |
|
"跳蚤", |
|
"跟头", |
|
"趔趄", |
|
"财主", |
|
"豆腐", |
|
"讲究", |
|
"记性", |
|
"记号", |
|
"认识", |
|
"规矩", |
|
"见识", |
|
"裁缝", |
|
"补丁", |
|
"衣裳", |
|
"衣服", |
|
"衙门", |
|
"街坊", |
|
"行李", |
|
"行当", |
|
"蛤蟆", |
|
"蘑菇", |
|
"薄荷", |
|
"葫芦", |
|
"葡萄", |
|
"萝卜", |
|
"荸荠", |
|
"苗条", |
|
"苗头", |
|
"苍蝇", |
|
"芝麻", |
|
"舒服", |
|
"舒坦", |
|
"舌头", |
|
"自在", |
|
"膏药", |
|
"脾气", |
|
"脑袋", |
|
"脊梁", |
|
"能耐", |
|
"胳膊", |
|
"胭脂", |
|
"胡萝", |
|
"胡琴", |
|
"胡同", |
|
"聪明", |
|
"耽误", |
|
"耽搁", |
|
"耷拉", |
|
"耳朵", |
|
"老爷", |
|
"老实", |
|
"老婆", |
|
"老头", |
|
"老太", |
|
"翻腾", |
|
"罗嗦", |
|
"罐头", |
|
"编辑", |
|
"结实", |
|
"红火", |
|
"累赘", |
|
"糨糊", |
|
"糊涂", |
|
"精神", |
|
"粮食", |
|
"簸箕", |
|
"篱笆", |
|
"算计", |
|
"算盘", |
|
"答应", |
|
"笤帚", |
|
"笑语", |
|
"笑话", |
|
"窟窿", |
|
"窝囊", |
|
"窗户", |
|
"稳当", |
|
"稀罕", |
|
"称呼", |
|
"秧歌", |
|
"秀气", |
|
"秀才", |
|
"福气", |
|
"祖宗", |
|
"砚台", |
|
"码头", |
|
"石榴", |
|
"石头", |
|
"石匠", |
|
"知识", |
|
"眼睛", |
|
"眯缝", |
|
"眨巴", |
|
"眉毛", |
|
"相声", |
|
"盘算", |
|
"白净", |
|
"痢疾", |
|
"痛快", |
|
"疟疾", |
|
"疙瘩", |
|
"疏忽", |
|
"畜生", |
|
"生意", |
|
"甘蔗", |
|
"琵琶", |
|
"琢磨", |
|
"琉璃", |
|
"玻璃", |
|
"玫瑰", |
|
"玄乎", |
|
"狐狸", |
|
"状元", |
|
"特务", |
|
"牲口", |
|
"牙碜", |
|
"牌楼", |
|
"爽快", |
|
"爱人", |
|
"热闹", |
|
"烧饼", |
|
"烟筒", |
|
"烂糊", |
|
"点心", |
|
"炊帚", |
|
"灯笼", |
|
"火候", |
|
"漂亮", |
|
"滑溜", |
|
"溜达", |
|
"温和", |
|
"清楚", |
|
"消息", |
|
"浪头", |
|
"活泼", |
|
"比方", |
|
"正经", |
|
"欺负", |
|
"模糊", |
|
"槟榔", |
|
"棺材", |
|
"棒槌", |
|
"棉花", |
|
"核桃", |
|
"栅栏", |
|
"柴火", |
|
"架势", |
|
"枕头", |
|
"枇杷", |
|
"机灵", |
|
"本事", |
|
"木头", |
|
"木匠", |
|
"朋友", |
|
"月饼", |
|
"月亮", |
|
"暖和", |
|
"明白", |
|
"时候", |
|
"新鲜", |
|
"故事", |
|
"收拾", |
|
"收成", |
|
"提防", |
|
"挖苦", |
|
"挑剔", |
|
"指甲", |
|
"指头", |
|
"拾掇", |
|
"拳头", |
|
"拨弄", |
|
"招牌", |
|
"招呼", |
|
"抬举", |
|
"护士", |
|
"折腾", |
|
"扫帚", |
|
"打量", |
|
"打算", |
|
"打点", |
|
"打扮", |
|
"打听", |
|
"打发", |
|
"扎实", |
|
"扁担", |
|
"戒指", |
|
"懒得", |
|
"意识", |
|
"意思", |
|
"情形", |
|
"悟性", |
|
"怪物", |
|
"思量", |
|
"怎么", |
|
"念头", |
|
"念叨", |
|
"快活", |
|
"忙活", |
|
"志气", |
|
"心思", |
|
"得罪", |
|
"张罗", |
|
"弟兄", |
|
"开通", |
|
"应酬", |
|
"庄稼", |
|
"干事", |
|
"帮手", |
|
"帐篷", |
|
"希罕", |
|
"师父", |
|
"师傅", |
|
"巴结", |
|
"巴掌", |
|
"差事", |
|
"工夫", |
|
"岁数", |
|
"屁股", |
|
"尾巴", |
|
"少爷", |
|
"小气", |
|
"小伙", |
|
"将就", |
|
"对头", |
|
"对付", |
|
"寡妇", |
|
"家伙", |
|
"客气", |
|
"实在", |
|
"官司", |
|
"学问", |
|
"学生", |
|
"字号", |
|
"嫁妆", |
|
"媳妇", |
|
"媒人", |
|
"婆家", |
|
"娘家", |
|
"委屈", |
|
"姑娘", |
|
"姐夫", |
|
"妯娌", |
|
"妥当", |
|
"妖精", |
|
"奴才", |
|
"女婿", |
|
"头发", |
|
"太阳", |
|
"大爷", |
|
"大方", |
|
"大意", |
|
"大夫", |
|
"多少", |
|
"多么", |
|
"外甥", |
|
"壮实", |
|
"地道", |
|
"地方", |
|
"在乎", |
|
"困难", |
|
"嘴巴", |
|
"嘱咐", |
|
"嘟囔", |
|
"嘀咕", |
|
"喜欢", |
|
"喇嘛", |
|
"喇叭", |
|
"商量", |
|
"唾沫", |
|
"哑巴", |
|
"哈欠", |
|
"哆嗦", |
|
"咳嗽", |
|
"和尚", |
|
"告诉", |
|
"告示", |
|
"含糊", |
|
"吓唬", |
|
"后头", |
|
"名字", |
|
"名堂", |
|
"合同", |
|
"吆喝", |
|
"叫唤", |
|
"口袋", |
|
"厚道", |
|
"厉害", |
|
"千斤", |
|
"包袱", |
|
"包涵", |
|
"匀称", |
|
"勤快", |
|
"动静", |
|
"动弹", |
|
"功夫", |
|
"力气", |
|
"前头", |
|
"刺猬", |
|
"刺激", |
|
"别扭", |
|
"利落", |
|
"利索", |
|
"利害", |
|
"分析", |
|
"出息", |
|
"凑合", |
|
"凉快", |
|
"冷战", |
|
"冤枉", |
|
"冒失", |
|
"养活", |
|
"关系", |
|
"先生", |
|
"兄弟", |
|
"便宜", |
|
"使唤", |
|
"佩服", |
|
"作坊", |
|
"体面", |
|
"位置", |
|
"似的", |
|
"伙计", |
|
"休息", |
|
"什么", |
|
"人家", |
|
"亲戚", |
|
"亲家", |
|
"交情", |
|
"云彩", |
|
"事情", |
|
"买卖", |
|
"主意", |
|
"丫头", |
|
"丧气", |
|
"两口", |
|
"东西", |
|
"东家", |
|
"世故", |
|
"不由", |
|
"不在", |
|
"下水", |
|
"下巴", |
|
"上头", |
|
"上司", |
|
"丈夫", |
|
"丈人", |
|
"一辈", |
|
"那个", |
|
"菩萨", |
|
"父亲", |
|
"母亲", |
|
"咕噜", |
|
"邋遢", |
|
"费用", |
|
"冤家", |
|
"甜头", |
|
"介绍", |
|
"荒唐", |
|
"大人", |
|
"泥鳅", |
|
"幸福", |
|
"熟悉", |
|
"计划", |
|
"扑腾", |
|
"蜡烛", |
|
"姥爷", |
|
"照顾", |
|
"喉咙", |
|
"吉他", |
|
"弄堂", |
|
"蚂蚱", |
|
"凤凰", |
|
"拖沓", |
|
"寒碜", |
|
"糟蹋", |
|
"倒腾", |
|
"报复", |
|
"逻辑", |
|
"盘缠", |
|
"喽啰", |
|
"牢骚", |
|
"咖喱", |
|
"扫把", |
|
"惦记", |
|
} |
|
self.must_not_neural_tone_words = { |
|
"男子", |
|
"女子", |
|
"分子", |
|
"原子", |
|
"量子", |
|
"莲子", |
|
"石子", |
|
"瓜子", |
|
"电子", |
|
"人人", |
|
"虎虎", |
|
} |
|
self.punc = ":,;。?!“”‘’':,;.?!" |
|
|
def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]: |
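        # Reduplicated nouns/verbs/adjectives (e.g. "妈妈") take a neutral tone
        # on the repeated syllable, unless listed in must_not_neural_tone_words.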
|
|
|
for j, item in enumerate(word): |
|
if ( |
|
j - 1 >= 0 |
|
and item == word[j - 1] |
|
and pos[0] in {"n", "v", "a"} |
|
and word not in self.must_not_neural_tone_words |
|
): |
|
finals[j] = finals[j][:-1] + "5" |
|
ge_idx = word.find("个") |
|
if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶": |
|
finals[-1] = finals[-1][:-1] + "5" |
|
elif len(word) >= 1 and word[-1] in "的地得": |
|
finals[-1] = finals[-1][:-1] + "5" |
|
|
elif ( |
|
len(word) > 1 |
|
and word[-1] in "们子" |
|
and pos in {"r", "n"} |
|
and word not in self.must_not_neural_tone_words |
|
): |
|
finals[-1] = finals[-1][:-1] + "5" |
|
|
|
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: |
|
finals[-1] = finals[-1][:-1] + "5" |
|
|
|
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开": |
|
finals[-1] = finals[-1][:-1] + "5" |
|
|
|
elif ( |
|
ge_idx >= 1 |
|
and ( |
|
word[ge_idx - 1].isnumeric() |
|
or word[ge_idx - 1] in "几有两半多各整每做是" |
|
) |
|
) or word == "个": |
|
finals[ge_idx] = finals[ge_idx][:-1] + "5" |
|
else: |
|
if ( |
|
word in self.must_neural_tone_words |
|
or word[-2:] in self.must_neural_tone_words |
|
): |
|
finals[-1] = finals[-1][:-1] + "5" |
|
|
|
word_list = self._split_word(word) |
|
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] |
|
for i, word in enumerate(word_list): |
|
|
|
if ( |
|
word in self.must_neural_tone_words |
|
or word[-2:] in self.must_neural_tone_words |
|
): |
|
finals_list[i][-1] = finals_list[i][-1][:-1] + "5" |
|
finals = sum(finals_list, []) |
|
return finals |
|
|
|
def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]: |
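        # "不" rules: neutral tone in the middle of an X不X pattern ("看不懂"),
        # tone 2 before a tone-4 syllable ("不怕").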
|
|
|
if len(word) == 3 and word[1] == "不": |
|
finals[1] = finals[1][:-1] + "5" |
|
else: |
|
for i, char in enumerate(word): |
|
|
|
if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4": |
|
finals[i] = finals[i][:-1] + "2" |
|
return finals |
|
|
|
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: |
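        # "一" rules: unchanged inside numerals, neutral between reduplicated
        # verbs ("看一看"), tone 1 in ordinals ("第一"), tone 2 before a tone-4
        # syllable, otherwise tone 4.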
|
|
|
if word.find("一") != -1 and all( |
|
[item.isnumeric() for item in word if item != "一"] |
|
): |
|
return finals |
|
|
|
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]: |
|
finals[1] = finals[1][:-1] + "5" |
|
|
|
elif word.startswith("第一"): |
|
finals[1] = finals[1][:-1] + "1" |
|
else: |
|
for i, char in enumerate(word): |
|
if char == "一" and i + 1 < len(word): |
|
|
|
if finals[i + 1][-1] == "4": |
|
finals[i] = finals[i][:-1] + "2" |
|
|
|
else: |
|
|
|
if word[i + 1] not in self.punc: |
|
finals[i] = finals[i][:-1] + "4" |
|
return finals |
|
|
|
def _split_word(self, word: str) -> List[str]: |
|
word_list = jieba.cut_for_search(word) |
|
        word_list = sorted(word_list, key=len)
|
first_subword = word_list[0] |
|
first_begin_idx = word.find(first_subword) |
|
if first_begin_idx == 0: |
|
second_subword = word[len(first_subword) :] |
|
new_word_list = [first_subword, second_subword] |
|
else: |
|
second_subword = word[: -len(first_subword)] |
|
new_word_list = [second_subword, first_subword] |
|
return new_word_list |
|
|
|
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: |
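        # Third-tone sandhi: within a run of tone-3 syllables, every syllable
        # except the last shifts to tone 2, split along jieba's sub-word
        # boundaries.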
|
if len(word) == 2 and self._all_tone_three(finals): |
|
finals[0] = finals[0][:-1] + "2" |
|
elif len(word) == 3: |
|
word_list = self._split_word(word) |
|
if self._all_tone_three(finals): |
|
|
|
if len(word_list[0]) == 2: |
|
finals[0] = finals[0][:-1] + "2" |
|
finals[1] = finals[1][:-1] + "2" |
|
|
|
elif len(word_list[0]) == 1: |
|
finals[1] = finals[1][:-1] + "2" |
|
else: |
|
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] |
|
if len(finals_list) == 2: |
|
for i, sub in enumerate(finals_list): |
|
|
|
if self._all_tone_three(sub) and len(sub) == 2: |
|
finals_list[i][0] = finals_list[i][0][:-1] + "2" |
|
|
|
elif ( |
|
i == 1 |
|
and not self._all_tone_three(sub) |
|
and finals_list[i][0][-1] == "3" |
|
and finals_list[0][-1][-1] == "3" |
|
): |
|
finals_list[0][-1] = finals_list[0][-1][:-1] + "2" |
|
finals = sum(finals_list, []) |
|
|
|
elif len(word) == 4: |
|
finals_list = [finals[:2], finals[2:]] |
|
finals = [] |
|
for sub in finals_list: |
|
if self._all_tone_three(sub): |
|
sub[0] = sub[0][:-1] + "2" |
|
finals += sub |
|
|
|
return finals |
|
|
|
def _all_tone_three(self, finals: List[str]) -> bool: |
|
return all(x[-1] == "3" for x in finals) |
|
|
def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: |
|
new_seg = [] |
|
last_word = "" |
|
for word, pos in seg: |
|
if last_word == "不": |
|
word = last_word + word |
|
if word != "不": |
|
new_seg.append((word, pos)) |
|
last_word = word[:] |
|
if last_word == "不": |
|
new_seg.append((last_word, "d")) |
|
last_word = "" |
|
return new_seg |
|
|
|
|
def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: |
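        # Pass 1: merge "V 一 V" reduplications (e.g. 听/一/听 -> 听一听);
        # pass 2: attach a dangling "一" to the word that follows it.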
|
new_seg = [] |
|
|
|
for i, (word, pos) in enumerate(seg): |
|
if ( |
|
i - 1 >= 0 |
|
and word == "一" |
|
and i + 1 < len(seg) |
|
and seg[i - 1][0] == seg[i + 1][0] |
|
and seg[i - 1][1] == "v" |
|
): |
|
new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0] |
|
else: |
|
if ( |
|
i - 2 >= 0 |
|
and seg[i - 1][0] == "一" |
|
and seg[i - 2][0] == word |
|
and pos == "v" |
|
): |
|
continue |
|
else: |
|
new_seg.append([word, pos]) |
|
seg = new_seg |
|
new_seg = [] |
|
|
|
for i, (word, pos) in enumerate(seg): |
|
if new_seg and new_seg[-1][0] == "一": |
|
new_seg[-1][0] = new_seg[-1][0] + word |
|
else: |
|
new_seg.append([word, pos]) |
|
return new_seg |
|
|
|
|
|
def _merge_continuous_three_tones( |
|
self, seg: List[Tuple[str, str]] |
|
) -> List[Tuple[str, str]]: |
|
new_seg = [] |
|
sub_finals_list = [ |
|
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) |
|
for (word, pos) in seg |
|
] |
|
assert len(sub_finals_list) == len(seg) |
|
merge_last = [False] * len(seg) |
|
for i, (word, pos) in enumerate(seg): |
|
if ( |
|
i - 1 >= 0 |
|
and self._all_tone_three(sub_finals_list[i - 1]) |
|
and self._all_tone_three(sub_finals_list[i]) |
|
and not merge_last[i - 1] |
|
): |
|
|
|
if ( |
|
not self._is_reduplication(seg[i - 1][0]) |
|
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3 |
|
): |
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] |
|
merge_last[i] = True |
|
else: |
|
new_seg.append([word, pos]) |
|
else: |
|
new_seg.append([word, pos]) |
|
|
|
return new_seg |
|
|
|
def _is_reduplication(self, word: str) -> bool: |
|
return len(word) == 2 and word[0] == word[1] |
|
|
|
|
|
def _merge_continuous_three_tones_2( |
|
self, seg: List[Tuple[str, str]] |
|
) -> List[Tuple[str, str]]: |
|
new_seg = [] |
|
sub_finals_list = [ |
|
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) |
|
for (word, pos) in seg |
|
] |
|
assert len(sub_finals_list) == len(seg) |
|
merge_last = [False] * len(seg) |
|
for i, (word, pos) in enumerate(seg): |
|
if ( |
|
i - 1 >= 0 |
|
and sub_finals_list[i - 1][-1][-1] == "3" |
|
and sub_finals_list[i][0][-1] == "3" |
|
and not merge_last[i - 1] |
|
): |
|
|
|
if ( |
|
not self._is_reduplication(seg[i - 1][0]) |
|
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3 |
|
): |
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] |
|
merge_last[i] = True |
|
else: |
|
new_seg.append([word, pos]) |
|
else: |
|
new_seg.append([word, pos]) |
|
return new_seg |
|
|
|
def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: |
|
new_seg = [] |
|
for i, (word, pos) in enumerate(seg): |
|
if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#": |
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] |
|
else: |
|
new_seg.append([word, pos]) |
|
return new_seg |
|
|
|
def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: |
|
new_seg = [] |
|
for i, (word, pos) in enumerate(seg): |
|
if new_seg and word == new_seg[-1][0]: |
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] |
|
else: |
|
new_seg.append([word, pos]) |
|
return new_seg |
|
|
|
def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: |
|
seg = self._merge_bu(seg) |
|
try: |
|
seg = self._merge_yi(seg) |
|
        except Exception:
|
print("_merge_yi failed") |
|
seg = self._merge_reduplication(seg) |
|
seg = self._merge_continuous_three_tones(seg) |
|
seg = self._merge_continuous_three_tones_2(seg) |
|
seg = self._merge_er(seg) |
|
return seg |
|
|
|
def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]: |
|
finals = self._bu_sandhi(word, finals) |
|
finals = self._yi_sandhi(word, finals) |
|
finals = self._neural_sandhi(word, pos, finals) |
|
finals = self._three_sandhi(word, finals) |
|
return finals |
|
|
|
|
|
punctuation = ["!", "?", "…", ",", ".", "'", "-"] |
|
pu_symbols = punctuation + ["SP", "UNK"] |
|
pad = "_" |
|
|
|
|
|
zh_symbols = [ |
|
"E", |
|
"En", |
|
"a", |
|
"ai", |
|
"an", |
|
"ang", |
|
"ao", |
|
"b", |
|
"c", |
|
"ch", |
|
"d", |
|
"e", |
|
"ei", |
|
"en", |
|
"eng", |
|
"er", |
|
"f", |
|
"g", |
|
"h", |
|
"i", |
|
"i0", |
|
"ia", |
|
"ian", |
|
"iang", |
|
"iao", |
|
"ie", |
|
"in", |
|
"ing", |
|
"iong", |
|
"ir", |
|
"iu", |
|
"j", |
|
"k", |
|
"l", |
|
"m", |
|
"n", |
|
"o", |
|
"ong", |
|
"ou", |
|
"p", |
|
"q", |
|
"r", |
|
"s", |
|
"sh", |
|
"t", |
|
"u", |
|
"ua", |
|
"uai", |
|
"uan", |
|
"uang", |
|
"ui", |
|
"un", |
|
"uo", |
|
"v", |
|
"van", |
|
"ve", |
|
"vn", |
|
"w", |
|
"x", |
|
"y", |
|
"z", |
|
"zh", |
|
"AA", |
|
"EE", |
|
"OO", |
|
] |
|
num_zh_tones = 6 |
|
|
|
|
|
ja_symbols = [ |
|
"N", |
|
"a", |
|
"a:", |
|
"b", |
|
"by", |
|
"ch", |
|
"d", |
|
"dy", |
|
"e", |
|
"e:", |
|
"f", |
|
"g", |
|
"gy", |
|
"h", |
|
"hy", |
|
"i", |
|
"i:", |
|
"j", |
|
"k", |
|
"ky", |
|
"m", |
|
"my", |
|
"n", |
|
"ny", |
|
"o", |
|
"o:", |
|
"p", |
|
"py", |
|
"q", |
|
"r", |
|
"ry", |
|
"s", |
|
"sh", |
|
"t", |
|
"ts", |
|
"ty", |
|
"u", |
|
"u:", |
|
"w", |
|
"y", |
|
"z", |
|
"zy", |
|
] |
|
num_ja_tones = 2 |
|
|
|
|
|
en_symbols = [ |
|
"aa", |
|
"ae", |
|
"ah", |
|
"ao", |
|
"aw", |
|
"ay", |
|
"b", |
|
"ch", |
|
"d", |
|
"dh", |
|
"eh", |
|
"er", |
|
"ey", |
|
"f", |
|
"g", |
|
"hh", |
|
"ih", |
|
"iy", |
|
"jh", |
|
"k", |
|
"l", |
|
"m", |
|
"n", |
|
"ng", |
|
"ow", |
|
"oy", |
|
"p", |
|
"r", |
|
"s", |
|
"sh", |
|
"t", |
|
"th", |
|
"uh", |
|
"uw", |
|
"V", |
|
"w", |
|
"y", |
|
"z", |
|
"zh", |
|
] |
|
num_en_tones = 4 |
|
|
|
|
|
normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols)) |
|
symbols = [pad] + normal_symbols + pu_symbols |
|
sil_phonemes_ids = [symbols.index(i) for i in pu_symbols] |
|
|
|
|
|
num_tones = num_zh_tones + num_ja_tones + num_en_tones |
|
|
|
|
|
language_id_map = {"ZH": 0, "JP": 1, "EN": 2} |
|
num_languages = len(language_id_map)
|
|
|
language_tone_start_map = { |
|
"ZH": 0, |
|
"JP": num_zh_tones, |
|
"EN": num_zh_tones + num_ja_tones, |
|
} |
|
|
|
current_file_path = os.path.dirname(__file__) |
|
with open(
    os.path.join(current_file_path, "opencpop-strict.txt"), encoding="utf-8"
) as f:
    pinyin_to_symbol_map = {
        line.split("\t")[0]: line.strip().split("\t")[1] for line in f
    }
|
|
rep_map = { |
|
":": ",", |
|
";": ",", |
|
",": ",", |
|
"。": ".", |
|
"!": "!", |
|
"?": "?", |
|
"\n": ".", |
|
"·": ",", |
|
"、": ",", |
|
"...": "…", |
|
"$": ".", |
|
"“": "'", |
|
"”": "'", |
|
'"': "'", |
|
"‘": "'", |
|
"’": "'", |
|
"(": "'", |
|
")": "'", |
|
"(": "'", |
|
")": "'", |
|
"《": "'", |
|
"》": "'", |
|
"【": "'", |
|
"】": "'", |
|
"[": "'", |
|
"]": "'", |
|
"—": "-", |
|
"~": "-", |
|
"~": "-", |
|
"「": "'", |
|
"」": "'", |
|
} |
|
|
|
tone_modifier = ToneSandhi() |
|
|
|
|
|
def replace_punctuation(text): |
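    # "嗯" and "呣" lack regular pinyin readings, so near-homophones are
    # substituted first; then punctuation is mapped through rep_map and every
    # character that is neither CJK nor known punctuation is dropped.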
|
text = text.replace("嗯", "恩").replace("呣", "母") |
|
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) |
|
|
|
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) |
|
|
|
replaced_text = re.sub( |
|
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text |
|
) |
|
|
|
return replaced_text |
|
|
|
|
|
def g2p(text): |
|
pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) |
|
sentences = [i for i in re.split(pattern, text) if i.strip() != ""] |
|
phones, tones, word2ph = _g2p(sentences) |
|
assert sum(word2ph) == len(phones) |
|
assert len(word2ph) == len(text) |
|
phones = ["_"] + phones + ["_"] |
|
tones = [0] + tones + [0] |
|
word2ph = [1] + word2ph + [1] |
|
return phones, tones, word2ph |
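# Note: g2p expects text already passed through text_normalize; the returned
# phones/tones are wrapped with "_"/tone-0 sentinels and word2ph maps every
# input character (plus both sentinels) to its phone count.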
|
|
|
|
|
def _get_initials_finals(word): |
|
initials = [] |
|
finals = [] |
|
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) |
|
orig_finals = lazy_pinyin( |
|
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3 |
|
) |
|
for c, v in zip(orig_initials, orig_finals): |
|
initials.append(c) |
|
finals.append(v) |
|
return initials, finals |
|
|
|
|
|
def _g2p(segments): |
|
phones_list = [] |
|
tones_list = [] |
|
word2ph = [] |
|
for seg in segments: |
|
|
|
seg = re.sub("[a-zA-Z]+", "", seg) |
|
seg_cut = psg.lcut(seg) |
|
initials = [] |
|
finals = [] |
|
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) |
|
for word, pos in seg_cut: |
|
if pos == "eng": |
|
continue |
|
sub_initials, sub_finals = _get_initials_finals(word) |
|
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) |
|
initials.append(sub_initials) |
|
finals.append(sub_finals) |
|
|
|
|
|
initials = sum(initials, []) |
|
finals = sum(finals, []) |
|
|
|
for c, v in zip(initials, finals): |
|
raw_pinyin = c + v |
|
|
|
|
|
if c == v: |
|
assert c in punctuation |
|
phone = [c] |
|
tone = "0" |
|
word2ph.append(1) |
|
else: |
|
v_without_tone = v[:-1] |
|
tone = v[-1] |
|
|
|
pinyin = c + v_without_tone |
|
assert tone in "12345" |
|
|
|
if c: |
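                    # Normalize pypinyin finals to the opencpop-strict spellings
                    # ("uei" -> "ui", "iou" -> "iu", "uen" -> "un") before the
                    # symbol-map lookup.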
|
|
|
v_rep_map = { |
|
"uei": "ui", |
|
"iou": "iu", |
|
"uen": "un", |
|
} |
|
if v_without_tone in v_rep_map.keys(): |
|
pinyin = c + v_rep_map[v_without_tone] |
|
else: |
|
|
|
pinyin_rep_map = { |
|
"ing": "ying", |
|
"i": "yi", |
|
"in": "yin", |
|
"u": "wu", |
|
} |
|
if pinyin in pinyin_rep_map.keys(): |
|
pinyin = pinyin_rep_map[pinyin] |
|
else: |
|
single_rep_map = { |
|
"v": "yu", |
|
"e": "e", |
|
"i": "y", |
|
"u": "w", |
|
} |
|
if pinyin[0] in single_rep_map.keys(): |
|
pinyin = single_rep_map[pinyin[0]] + pinyin[1:] |
|
|
|
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) |
|
phone = pinyin_to_symbol_map[pinyin].split(" ") |
|
word2ph.append(len(phone)) |
|
|
|
phones_list += phone |
|
tones_list += [int(tone)] * len(phone) |
|
return phones_list, tones_list, word2ph |
|
|
|
|
|
def text_normalize(text): |
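    # Convert Arabic numerals to Chinese numerals with cn2an (e.g. "42" ->
    # "四十二"), then normalize punctuation and strip unsupported characters.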
|
numbers = re.findall(r"\d+(?:\.?\d+)?", text) |
|
for number in numbers: |
|
text = text.replace(number, cn2an.an2cn(number), 1) |
|
text = replace_punctuation(text) |
|
return text |
|
|
|
def get_bert_feature( |
|
text, |
|
word2ph, |
|
style_text=None, |
|
style_weight=0.7, |
|
): |
|
global bert_model |
|
|
|
|
|
    inputs = tokenizer(
        text, return_tensors="np", padding="max_length", truncation=True, max_length=256
    )
|
|
|
|
|
    start_time = time.time()
    res = bert_model.inference(
        [inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"]]
    )
    bert_time = time.time() - start_time
    print(f"bert runtime: {bert_time:.4f} s")
|
|
|
|
|
res = res[0][0] |
|
|
|
if style_text: |
|
        assert False, "style_text is not supported in this script"
|
|
assert len(word2ph) == len(text) + 2 |
|
word2phone = word2ph |
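    # Expand the character-level BERT features to phone level: each character's
    # feature vector is repeated word2ph[i] times.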
|
phone_level_feature = [] |
|
for i in range(len(word2phone)): |
|
        if style_text:
            # Unreachable: style_text is rejected by the assert above; kept for
            # parity with the upstream implementation.
            repeat_feature = res[i].repeat(word2phone[i], 1) * (1 - style_weight)
|
else: |
|
repeat_feature = np.tile(res[i], (word2phone[i], 1)) |
|
phone_level_feature.append(repeat_feature) |
|
|
|
phone_level_feature = np.concatenate(phone_level_feature, axis=0) |
|
|
|
return phone_level_feature.T |
|
|
|
def clean_text(text, language): |
|
norm_text = text_normalize(text) |
|
phones, tones, word2ph = g2p(norm_text) |
|
return norm_text, phones, tones, word2ph |
|
|
|
|
|
def clean_text_bert(text, language): |
|
norm_text = text_normalize(text) |
|
phones, tones, word2ph = g2p(norm_text) |
|
bert = get_bert_feature(norm_text, word2ph) |
|
return phones, tones, bert |
|
|
|
_symbol_to_id = {s: i for i, s in enumerate(symbols)} |
|
|
|
def cleaned_text_to_sequence(cleaned_text, tones, language): |
|
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text. |
|
Args: |
|
text: string to convert to a sequence |
|
Returns: |
|
List of integers corresponding to the symbols in the text |
|
""" |
|
phones = [_symbol_to_id[symbol] for symbol in cleaned_text] |
|
tone_start = language_tone_start_map[language] |
|
tones = [i + tone_start for i in tones] |
|
lang_id = language_id_map[language] |
|
    lang_ids = [lang_id for _ in phones]
|
return phones, tones, lang_ids |
|
|
|
def text_to_sequence(text, language): |
|
norm_text, phones, tones, word2ph = clean_text(text, language) |
|
return cleaned_text_to_sequence(phones, tones, language) |
|
|
|
def intersperse(lst, item): |
|
result = [item] * (len(lst) * 2 + 1) |
|
result[1::2] = lst |
|
return result |
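# Example: intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]; used to insert
# blank tokens between phones when add_blank is enabled.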
|
|
|
def get_text(text, language_str, style_text=None, style_weight=0.7, add_blank=False): |
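    # Build model inputs for one segment: phone/tone/language IDs (optionally
    # interleaved with blanks) plus per-language BERT features; the two unused
    # language slots are zero-filled [1024, len(phone)] arrays.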
|
|
|
norm_text, phone, tone, word2ph = clean_text(text, language_str) |
|
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) |
|
|
|
if add_blank: |
|
phone = intersperse(phone, 0) |
|
tone = intersperse(tone, 0) |
|
language = intersperse(language, 0) |
|
for i in range(len(word2ph)): |
|
word2ph[i] = word2ph[i] * 2 |
|
word2ph[0] += 1 |
|
bert_ori = get_bert_feature( |
|
norm_text, word2ph, style_text, style_weight |
|
) |
|
del word2ph |
|
assert bert_ori.shape[-1] == len(phone), phone |
|
|
|
if language_str == "ZH": |
|
bert = bert_ori |
|
ja_bert = np.zeros((1024, len(phone))) |
|
en_bert = np.zeros((1024, len(phone))) |
|
elif language_str == "JP": |
|
bert = np.zeros((1024, len(phone))) |
|
ja_bert = bert_ori |
|
en_bert = np.zeros((1024, len(phone))) |
|
elif language_str == "EN": |
|
bert = np.zeros((1024, len(phone))) |
|
ja_bert = np.zeros((1024, len(phone))) |
|
en_bert = bert_ori |
|
else: |
|
raise ValueError("language_str should be ZH, JP or EN") |
|
|
|
assert bert.shape[-1] == len( |
|
phone |
|
), f"Bert seq len {bert.shape[-1]} != {len(phone)}" |
|
phone = np.array(phone) |
|
tone = np.array(tone) |
|
language = np.array(language) |
|
return bert, ja_bert, en_bert, phone, tone, language |
|
|
|
if __name__ == "__main__": |
|
name = "lx" |
|
model_prefix = f"onnx/{name}/{name}_" |
|
bert_path = "./bert/chinese-roberta-wwm-ext-large" |
|
flow_dec_input_len = 1024 |
|
model_sample_rate = 44100 |
|
|
|
text = "我个人认为,这个意大利面就应该拌42号混凝土,因为这个螺丝钉的长度,它很容易会直接影响到挖掘机的扭矩你知道吧。你往里砸的时候,一瞬间它就会产生大量的高能蛋白,俗称ufo,会严重影响经济的发展,甚至对整个太平洋以及充电器都会造成一定的核污染。你知道啊?再者说,根据这个勾股定理,你可以很容易地推断出人工饲养的东条英机,它是可以捕获野生的三角函数的。所以说这个秦始皇的切面是否具有放射性啊,特朗普的N次方是否含有沉淀物,都不影响这个沃尔玛跟维尔康在南极会合。" |
|
|
|
    # bert_model and tokenizer are module-level globals read by get_bert_feature.
|
tokenizer = AutoTokenizer.from_pretrained(bert_path) |
|
bert_model = RKNNLite(verbose=False) |
|
bert_model.load_rknn(bert_path + "/model.rknn") |
|
bert_model.init_runtime() |
|
model = InferenceSession({ |
|
"enc": model_prefix + "enc_p.onnx", |
|
"emb_g": model_prefix + "emb.onnx", |
|
"dp": model_prefix + "dp.onnx", |
|
"sdp": model_prefix + "sdp.onnx", |
|
"flow": model_prefix + "flow.onnx", |
|
"dec": model_prefix + "dec.rknn", |
|
}) |
|
|
|
|
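    # Split on sentence-final punctuation and synthesize chunk by chunk so each
    # segment stays within the fixed flow/dec input length.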
|
text_seg = re.split(r'(?<=[。!?;])', text) |
|
output_acc = np.array([0.0]) |
|
|
|
    for text in text_seg:
        if not text.strip():
            continue
|
bert, ja_bert, en_bert, phone, tone, language = get_text(text, "ZH", add_blank=True) |
|
bert = np.transpose(bert) |
|
ja_bert = np.transpose(ja_bert) |
|
en_bert = np.transpose(en_bert) |
|
|
|
sid = np.array([0]) |
|
vqidx = np.array([0]) |
|
|
|
        output = model(
            phone, tone, language, bert, ja_bert, en_bert, vqidx, sid,
            rknn_pad_to=flow_dec_input_len,
            seed=114514,
            seq_noise_scale=0.8,
            sdp_noise_scale=0.6,
            length_scale=1,
            sdp_ratio=0,
        )[0, 0]
|
output_acc = np.concatenate([output_acc, output]) |
|
print(f"已生成长度: {len(output_acc) / model_sample_rate:.2f} 秒") |
|
|
|
sf.write('output.wav', output_acc, model_sample_rate) |
|
print("已生成output.wav") |