import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from sklearn.model_selection import train_test_split
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from torchaudio.transforms import GriffinLim, InverseMelScale, MelSpectrogram
from tqdm import tqdm
from transformers import GPT2Model, GPT2Tokenizer


class TextToSpeechDataset(Dataset):
    """Pairs transcript files with audio files and yields (token_ids, mel_spectrogram)."""

    def __init__(self, text_files, audio_files, tokenizer, mel_transform, max_length=512):
        self.text_files = text_files
        self.audio_files = audio_files
        self.tokenizer = tokenizer
        self.mel_transform = mel_transform
        self.max_length = max_length

    def __len__(self):
        return len(self.text_files)

    def __getitem__(self, idx):
        with open(self.text_files[idx], 'r') as f:
            text = f.read().strip()

        # Tokenize and pad/truncate to a fixed length so batches line up.
        text_tokens = self.tokenizer.encode(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt",
        ).squeeze(0)

        waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        # Resample if the file's rate differs from the mel transform's rate.
        if sample_rate != self.mel_transform.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, self.mel_transform.sample_rate
            )
        mel_spec = self.mel_transform(waveform)

        return text_tokens, mel_spec.squeeze(0)

def collate_fn(batch):
    text_tokens, mel_specs = zip(*batch)

    # Pad (or crop) token sequences to the longest in the batch.
    max_text_len = max(tokens.size(0) for tokens in text_tokens)
    text_tokens_padded = torch.stack([
        torch.cat([tokens, torch.zeros(max_text_len - tokens.size(0), dtype=tokens.dtype)], dim=0)
        if tokens.size(0) < max_text_len
        else tokens[:max_text_len]
        for tokens in text_tokens
    ])

    # Pad (or crop) mel spectrograms along the time axis.
    max_mel_len = max(spec.size(1) for spec in mel_specs)
    mel_specs_padded = torch.stack([
        F.pad(spec, (0, max_mel_len - spec.size(1)))
        if spec.size(1) < max_mel_len
        else spec[:, :max_mel_len]
        for spec in mel_specs
    ])

    return text_tokens_padded, mel_specs_padded


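def _collate_fn_shape_check():
    # Illustrative sanity check, not part of the original script: two items
    # with different mel lengths are padded to a common time axis.
    a = (torch.zeros(512, dtype=torch.long), torch.zeros(80, 100))
    b = (torch.zeros(512, dtype=torch.long), torch.zeros(80, 60))
    tokens, mels = collate_fn([a, b])
    assert tokens.shape == (2, 512) and mels.shape == (2, 80, 100)

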
class VAEDecoder(nn.Module):
    def __init__(self, latent_dim, mel_channels=80):
        super().__init__()

        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)

        # Decodes a latent vector into a fixed-size mel spectrogram of
        # mel_channels x 80 frames.
        self.decoder_layers = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, mel_channels * 80),
            nn.Unflatten(1, (mel_channels, 80)),
        )

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, z):
        mu = self.fc_mu(z)
        log_var = self.fc_var(z)
        z = self.reparameterize(mu, log_var)
        mel_spec = self.decoder_layers(z)
        return mel_spec, mu, log_var


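def _vae_decoder_shape_check():
    # Illustrative sanity check, not part of the original script: the decoder
    # maps a (B, latent_dim) batch to a fixed (B, mel_channels, 80) mel.
    dec = VAEDecoder(latent_dim=256)
    mel, mu, log_var = dec(torch.randn(4, 256))
    assert mel.shape == (4, 80, 80)
    assert mu.shape == log_var.shape == (4, 256)

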
class TextToSpeechModel(nn.Module):
    def __init__(self, text_encoder, vae_decoder, latent_dim=256):
        super().__init__()
        self.text_encoder = text_encoder
        self.vae_decoder = vae_decoder

        # Project the text encoder's hidden size down to the VAE latent size.
        self.projection = nn.Linear(text_encoder.config.hidden_size, latent_dim)

    def forward(self, text_tokens):
        encoder_output = self.text_encoder(text_tokens).last_hidden_state

        # Mean-pool over the sequence dimension (padding tokens included).
        text_embedding = encoder_output.mean(dim=1)

        latent_z = self.projection(text_embedding)
        mel_spec, mu, log_var = self.vae_decoder(latent_z)
        return mel_spec, mu, log_var


def vae_loss(reconstruction, target, mu, log_var):
    # The decoder emits a fixed number of frames, so pad or crop the target's
    # time axis to match before computing the reconstruction loss.
    num_frames = reconstruction.size(-1)
    if target.size(-1) < num_frames:
        target = F.pad(target, (0, num_frames - target.size(-1)))
    else:
        target = target[..., :num_frames]

    recon_loss = F.mse_loss(reconstruction, target, reduction='mean')

    # KL divergence, averaged over the batch to match the mean-reduced
    # reconstruction term.
    kl_loss = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)).mean()

    return recon_loss + 0.001 * kl_loss


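def _vae_loss_shape_check():
    # Illustrative, not part of the original script: a shorter target is
    # zero-padded to the reconstruction's 80 frames before the MSE, and the
    # KL term vanishes for a standard-normal posterior (mu = log_var = 0).
    recon = torch.zeros(2, 80, 80)
    target = torch.zeros(2, 80, 50)
    mu = torch.zeros(2, 256)
    log_var = torch.zeros(2, 256)
    assert vae_loss(recon, target, mu, log_var).item() == 0.0

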
def train_model(num_epochs=10, accumulation_steps=16):
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 has no pad token by default; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    mel_transform = MelSpectrogram(
        sample_rate=16000,
        n_mels=80,
        n_fft=1024,
        hop_length=256,
    )

    text_folder = './texts'
    audio_folder = './audio'

    # Sort both listings so text/audio pairs stay aligned by file name;
    # os.listdir returns files in arbitrary order.
    text_files = sorted(os.path.join(text_folder, f) for f in os.listdir(text_folder) if f.endswith('.txt'))
    audio_files = sorted(os.path.join(audio_folder, f) for f in os.listdir(audio_folder) if f.endswith('.wav'))

    train_texts, val_texts, train_audios, val_audios = train_test_split(
        text_files, audio_files, test_size=0.1, random_state=42
    )

    train_dataset = TextToSpeechDataset(train_texts, train_audios, tokenizer, mel_transform)
    val_dataset = TextToSpeechDataset(val_texts, val_audios, tokenizer, mel_transform)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

    text_encoder = GPT2Model.from_pretrained('gpt2')
    vae_decoder = VAEDecoder(latent_dim=256)
    model = TextToSpeechModel(text_encoder, vae_decoder)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # Mixed precision only makes sense on CUDA; disable the scaler on CPU.
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()

        for batch_idx, (text_tokens, mel_specs) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            text_tokens = text_tokens.to(device)
            mel_specs = mel_specs.to(device)

            with autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == 'cuda')):
                reconstructed_mel, mu, log_var = model(text_tokens)
                loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)

            # Scale the loss down so gradients accumulate to a full-batch average.
            loss = loss / accumulation_steps
            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Undo the accumulation scaling when logging.
            train_loss += loss.item() * accumulation_steps

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for text_tokens, mel_specs in tqdm(val_loader, desc=f"Validation {epoch+1}"):
                text_tokens = text_tokens.to(device)
                mel_specs = mel_specs.to(device)

                reconstructed_mel, mu, log_var = model(text_tokens)
                loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)
                val_loss += loss.item()

        scheduler.step()

        print(f'Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}')

        # Checkpoint on the best validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_tts_model.pth')

    return model


def generate_mel_spectrogram(text, model, tokenizer, device):
    model.eval()
    with torch.no_grad():
        text_tokens = tokenizer.encode(
            text,
            return_tensors="pt",
            truncation=True,
            padding='max_length',
            max_length=512,
        ).to(device)

        mel_spec, _, _ = model(text_tokens)

    return mel_spec


def mel_to_audio(mel_spec, sample_rate=16000, n_fft=1024, hop_length=256):
    # Invert the mel scale to a linear spectrogram, then estimate phase with
    # Griffin-Lim. InverseMelScale requires n_stft (= n_fft // 2 + 1), and
    # GriffinLim takes STFT parameters, not a sample rate; both must match
    # the forward MelSpectrogram used during training.
    inverse_mel = InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=80, sample_rate=sample_rate)
    griffin_lim = GriffinLim(n_fft=n_fft, hop_length=hop_length)

    waveform = griffin_lim(inverse_mel(mel_spec))
    return waveform


if __name__ == '__main__':
    trained_model = train_model()
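
    # Illustrative end-to-end sketch, not part of the original script: the
    # sentence and output file name below are placeholders.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    mel = generate_mel_spectrogram('Hello world.', trained_model.to(device), tokenizer, device)
    # mel has shape (1, 80, 80); drop the batch dim and move to CPU for vocoding.
    audio = mel_to_audio(mel.squeeze(0).cpu())
    torchaudio.save('sample_output.wav', audio.unsqueeze(0), 16000)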