|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
from typing import List, Tuple |
|
import os |
|
import numpy as np |
|
import torch |
|
from text.symbol_table import SymbolTable |
|
from text import text_to_sequence |
|
|
|
|
|
''' |
|
TextToken: map text to id |
|
''' |
|
|
|
|
|
class TextTokenCollator:
    """Map text tokens to integer ids and pad batches of token sequences.

    The vocabulary is laid out as
    ``[pad_symbol, bos_symbol (optional), eos_symbol (optional), *sorted(text_tokens)]``
    and every token's id is its position in that list.

    NOTE: tokens are not de-duplicated. If a symbol appears more than once
    (e.g. ``text_tokens`` already contains ``pad_symbol``), ``token2idx``
    keeps the index of its last occurrence.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        """Build the token <-> id lookup tables.

        Args:
            text_tokens: Vocabulary symbols; they are sorted before being
                appended after the special symbols.
            add_eos: Append ``eos_symbol`` to every encoded sequence.
            add_bos: Prepend ``bos_symbol`` to every encoded sequence.
            pad_symbol: Symbol used to right-pad batched sequences.
            bos_symbol: Beginning-of-sequence symbol.
            eos_symbol: End-of-sequence symbol.
        """
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Special symbols come first so their ids do not depend on the
        # contents of the vocabulary.
        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def _wrap(self, tokens) -> list:
        """Return ``tokens`` as a list with the configured BOS/EOS symbols added."""
        prefix = [self.bos_symbol] if self.add_bos else []
        suffix = [self.eos_symbol] if self.add_eos else []
        return prefix + list(tokens) + suffix

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of token sequences into a right-padded id tensor.

        Args:
            tokens_list: Batch of token sequences; each element is iterated
                token by token.

        Returns:
            A tuple ``(tokens, tokens_lens)``: ``tokens`` is an int64 tensor
            of shape ``(len(tokens_list), max_len)`` padded with the id of
            ``pad_symbol``; ``tokens_lens`` is an ``IntTensor`` holding each
            sequence's unpadded length (BOS/EOS included). An empty batch
            yields a ``(0, 0)`` tensor and an empty length tensor.

        Raises:
            AssertionError: If a sequence contains a token that is not in
                the vocabulary.
        """
        seqs = []
        for tokens in tokens_list:
            # Materialize once so one-shot iterables survive validation.
            items = list(tokens)
            unknown = [s for s in items if s not in self.token2idx]
            if unknown:
                # Raised explicitly (not via `assert`) so the check is not
                # stripped under `python -O`; the type matches the original.
                raise AssertionError(
                    f"Tokens missing from the vocabulary: {unknown!r}"
                )
            seqs.append(self._wrap(items))

        seq_lens = [len(seq) for seq in seqs]
        # `default=0` keeps an empty batch from raising inside max().
        max_len = max(seq_lens, default=0)

        pad_id = self.token2idx[self.pad_symbol]
        padded_ids = [
            [self.token2idx[token] for token in seq] + [pad_id] * (max_len - len(seq))
            for seq in seqs
        ]

        # The reshape is a no-op for non-empty batches and guarantees a 2-D
        # (0, 0) result for an empty one.
        tokens = torch.from_numpy(
            np.array(padded_ids, dtype=np.int64).reshape(len(seqs), max_len)
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, text):
        """Encode a single token sequence.

        Args:
            text: Iterable of tokens (a string is split into characters).

        Returns:
            A tuple ``(token_ids, token_lens)``: the list of ids including
            the configured BOS/EOS symbols, and its length.

        Raises:
            KeyError: If a token is not in the vocabulary.
        """
        seq = self._wrap(text)
        token_ids = [self.token2idx[token] for token in seq]
        token_lens = len(seq)

        return token_ids, token_lens
|
|
|
|
|
def get_text_token_collater(
    text_tokens_file: str,
) -> Tuple[TextTokenCollator, dict]:
    """Build a ``TextTokenCollator`` from a symbol-table file.

    Args:
        text_tokens_file: Path of the file handed to
            ``SymbolTable.from_file``.

    Returns:
        A tuple ``(collater, token2idx)``: the collator (created with both
        BOS and EOS enabled) and its token-to-id mapping.
    """
    symbol_table = SymbolTable.from_file(Path(text_tokens_file))
    collater = TextTokenCollator(
        symbol_table.symbols, add_bos=True, add_eos=True
    )
    return collater, collater.token2idx
|
|
|
|
|
class phoneIDCollation:
    """Convert phone sequences to id sequences according to ``cfg``.

    When ``cfg.preprocess.phone_extractor`` is ``'lexicon'`` the conversion
    is delegated to ``text_to_sequence``; otherwise a collator built from a
    symbols-dictionary file is used.
    """

    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        """Prepare the token collator for non-lexicon phone extractors.

        Args:
            cfg: Configuration object; reads ``cfg.preprocess.phone_extractor``
                and, when no file is given, ``cfg.preprocess.processed_dir``
                and ``cfg.preprocess.symbols_dict``.
            dataset: Dataset name used to locate the symbols dictionary when
                ``symbols_dict_file`` is not given.
            symbols_dict_file: Explicit path of the symbols dictionary.

        Raises:
            AssertionError: If neither ``symbols_dict_file`` nor ``dataset``
                is provided for a non-lexicon extractor.
        """
        if cfg.preprocess.phone_extractor != "lexicon":
            if symbols_dict_file is None:
                assert dataset is not None, (
                    "either `dataset` or `symbols_dict_file` must be provided"
                )
                symbols_dict_file = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.symbols_dict,
                )
            # Only the collator is kept; the token-to-id mapping that is
            # returned alongside it is not needed here.
            # (Attribute spelling is kept as-is for existing readers.)
            self.text_token_colloator, _ = get_text_token_collater(
                symbols_dict_file
            )

    def get_phone_id_sequence(self, cfg, phones_seq):
        """Return the id sequence for ``phones_seq``.

        Args:
            cfg: Configuration object; reads ``cfg.preprocess.phone_extractor``
                and, in lexicon mode, ``cfg.preprocess.text_cleaners``.
            phones_seq: Sequence of phone symbols.

        Returns:
            The result of ``text_to_sequence`` on the space-joined phones in
            lexicon mode; otherwise the id list produced by the collator
            (its length output is discarded).
        """
        if cfg.preprocess.phone_extractor == "lexicon":
            phones_seq = " ".join(phones_seq)
            sequence = text_to_sequence(phones_seq, cfg.preprocess.text_cleaners)
        else:
            sequence, _ = self.text_token_colloator(phones_seq)
        return sequence
|
|