# swaram/trainer.py
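"""Text-to-speech trainer: a GPT-2 text encoder feeds a small VAE decoder
that emits fixed-length mel spectrograms, trained with an MSE + KL objective
under mixed precision and gradient accumulation."""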
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2Model
from torchaudio.transforms import MelSpectrogram, InverseMelScale, GriffinLim
import torchaudio
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import GradScaler, autocast
class TextToSpeechDataset(Dataset):
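    """Yields (token_ids, mel_spectrogram) pairs from parallel text/audio
    file lists; the two lists must be aligned index-for-index."""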
def __init__(self, text_files, audio_files, tokenizer, mel_transform, max_length=512):
self.text_files = text_files
self.audio_files = audio_files
self.tokenizer = tokenizer
self.mel_transform = mel_transform
self.max_length = max_length
def __len__(self):
return len(self.text_files)
def __getitem__(self, idx):
# Load text
with open(self.text_files[idx], 'r') as f:
text = f.read().strip()
# Tokenize text
text_tokens = self.tokenizer.encode(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors="pt"
).squeeze(0)
        # Load audio, down-mix to mono, and resample to the transform's 16 kHz rate
        waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        mel_spec = self.mel_transform(waveform)
        return text_tokens, mel_spec.squeeze(0)
def collate_fn(batch):
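    """Batch (token, mel) pairs, padding each modality to the longest item."""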
text_tokens, mel_specs = zip(*batch)
    # Pad text tokens. The dataset already pads to a fixed max_length, so
    # this branch is defensive; GPT-2 has no dedicated pad token, so the
    # eos id (50256, reused as the pad token in train_model) fills the gap.
    max_text_len = max(tokens.size(0) for tokens in text_tokens)
    text_tokens_padded = torch.stack([
        torch.cat([tokens, torch.full((max_text_len - tokens.size(0),), 50256, dtype=tokens.dtype)], dim=0)
        if tokens.size(0) < max_text_len
        else tokens[:max_text_len]
        for tokens in text_tokens
    ])
# Pad mel spectrograms
max_mel_len = max(spec.size(1) for spec in mel_specs)
mel_specs_padded = torch.stack([
F.pad(spec, (0, max_mel_len - spec.size(1)))
if spec.size(1) < max_mel_len
else spec[:, :max_mel_len]
for spec in mel_specs
])
return text_tokens_padded, mel_specs_padded
class VAEDecoder(nn.Module):
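    """Treats the projected text embedding as input to variational heads
    (mu, log_var), samples a latent via the reparameterization trick, and
    decodes it into a fixed-size (mel_channels x 80 frames) mel spectrogram."""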
def __init__(self, latent_dim, mel_channels=80):
super().__init__()
        # Variational heads: map the text latent to a posterior mean and log-variance
self.fc_mu = nn.Linear(latent_dim, latent_dim)
self.fc_var = nn.Linear(latent_dim, latent_dim)
# Decoder part
self.decoder_layers = nn.Sequential(
nn.Linear(latent_dim, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
            nn.Linear(1024, mel_channels * 80),  # fixed-length output: 80 mel frames
            nn.Unflatten(1, (mel_channels, 80))
)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, z):
mu = self.fc_mu(z)
log_var = self.fc_var(z)
# Reparameterization trick
z = self.reparameterize(mu, log_var)
# Decode
mel_spec = self.decoder_layers(z)
return mel_spec, mu, log_var
class TextToSpeechModel(nn.Module):
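    """End-to-end TTS model: GPT-2 encodes the token sequence, a masked mean
    over non-padding positions is projected into the latent space, and the
    VAE decoder emits a mel spectrogram."""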
    def __init__(self, text_encoder, vae_decoder, latent_dim=256, pad_token_id=50256):
        super().__init__()
        self.text_encoder = text_encoder
        self.vae_decoder = vae_decoder
        self.pad_token_id = pad_token_id  # GPT-2 reuses the eos id as pad
        # Projection layer to map encoder output to latent space
        self.projection = nn.Linear(text_encoder.config.hidden_size, latent_dim)
    def forward(self, text_tokens):
        # Mask padding so it neither attends nor skews the pooled embedding
        attention_mask = (text_tokens != self.pad_token_id).long()
        encoder_output = self.text_encoder(
            text_tokens, attention_mask=attention_mask
        ).last_hidden_state
        # Masked mean pooling over non-padding positions only
        mask = attention_mask.unsqueeze(-1).float()
        text_embedding = (encoder_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        # Project to latent space
        latent_z = self.projection(text_embedding)
        # Decode to mel spectrogram
        mel_spec, mu, log_var = self.vae_decoder(latent_z)
        return mel_spec, mu, log_var
def vae_loss(reconstruction, target, mu, log_var, kl_weight=0.001):
    # The decoder emits a fixed number of frames, so zero-pad or crop the
    # target's time axis to match before computing the reconstruction MSE
    if target.size(-1) < reconstruction.size(-1):
        target = F.pad(target, (0, reconstruction.size(-1) - target.size(-1)))
    else:
        target = target[..., :reconstruction.size(-1)]
    recon_loss = F.mse_loss(reconstruction, target, reduction='mean')
    # KL divergence, averaged over the batch so its scale is batch-size independent
    kl_loss = -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    return recon_loss + kl_weight * kl_loss
def train_model(num_epochs=10, accumulation_steps=16):
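    """Train the TTS model with mixed precision and gradient accumulation,
    checkpointing whenever the validation loss improves."""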
# Tokenizer and mel spectrogram transform
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# Mel spectrogram configuration
mel_transform = MelSpectrogram(
sample_rate=16000,
n_mels=80,
n_fft=1024,
hop_length=256
)
# Data preparation
text_folder = './texts'
audio_folder = './audio'
    # Load text and audio files; sort both lists so each transcript lines up
    # with its audio clip (the pairing relies on matching sort order)
    text_files = sorted(os.path.join(text_folder, f) for f in os.listdir(text_folder) if f.endswith('.txt'))
    audio_files = sorted(os.path.join(audio_folder, f) for f in os.listdir(audio_folder) if f.endswith('.wav'))
# Split dataset
train_texts, val_texts, train_audios, val_audios = train_test_split(
text_files, audio_files, test_size=0.1, random_state=42
)
# Create datasets and dataloaders
train_dataset = TextToSpeechDataset(train_texts, train_audios, tokenizer, mel_transform)
val_dataset = TextToSpeechDataset(val_texts, val_audios, tokenizer, mel_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
# Model components
text_encoder = GPT2Model.from_pretrained('gpt2')
vae_decoder = VAEDecoder(latent_dim=256)
# Combine into full model
model = TextToSpeechModel(text_encoder, vae_decoder)
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    # Gradient scaler for mixed-precision updates (a no-op on CPU)
    scaler = GradScaler(enabled=device.type == 'cuda')
best_val_loss = float('inf')
# Training loop
for epoch in range(num_epochs):
model.train()
train_loss = 0
for batch_idx, (text_tokens, mel_specs) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
text_tokens = text_tokens.to(device)
mel_specs = mel_specs.to(device)
            # Mixed-precision forward pass (disabled automatically on CPU)
            with autocast(device_type=device.type, dtype=torch.float16, enabled=device.type == 'cuda'):
reconstructed_mel, mu, log_var = model(text_tokens)
# Compute loss
loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)
            # Scale the loss down so accumulated gradients average over the
            # effective batch, then backpropagate outside the autocast region
            loss = loss / accumulation_steps
            scaler.scale(loss).backward()
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            # Track the unscaled loss so the reported average stays comparable
            train_loss += loss.item() * accumulation_steps
        # Flush gradients left over from a partial accumulation window
        if len(train_loader) % accumulation_steps != 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
# Validation
model.eval()
val_loss = 0
with torch.no_grad():
for text_tokens, mel_specs in tqdm(val_loader, desc=f"Validation {epoch+1}"):
text_tokens = text_tokens.to(device)
mel_specs = mel_specs.to(device)
reconstructed_mel, mu, log_var = model(text_tokens)
loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)
val_loss += loss.item()
# Scheduler step
scheduler.step()
        # Print epoch summary
        print(f'Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader):.4f}, '
              f'Val Loss: {val_loss/len(val_loader):.4f}')
# Model saving
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), 'best_tts_model.pth')
return model
# Optional: Inference function for generating mel spectrograms
def generate_mel_spectrogram(text, model, tokenizer, device):
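    """Tokenize `text` exactly as during training and return the predicted mel."""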
model.eval()
with torch.no_grad():
# Tokenize input text
text_tokens = tokenizer.encode(
text,
return_tensors="pt",
truncation=True,
padding='max_length',
max_length=512
).to(device)
# Generate mel spectrogram
mel_spec, _, _ = model(text_tokens)
return mel_spec
# Optional: Convert mel spectrogram back to audio
def mel_to_audio(mel_spec, sample_rate=16000, n_fft=1024, hop_length=256):
    # Invert the mel filterbank, then estimate phase with Griffin-Lim.
    # n_fft and hop_length must match the MelSpectrogram used in training;
    # InverseMelScale needs the linear bin count (n_fft // 2 + 1) and
    # GriffinLim takes STFT parameters rather than a sample rate
    inverse_mel = InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=80, sample_rate=sample_rate)
    griffin_lim = GriffinLim(n_fft=n_fft, hop_length=hop_length)
    # Convert the mel spectrogram back to a waveform
    waveform = griffin_lim(inverse_mel(mel_spec))
    return waveform
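# A minimal end-to-end sketch tying the pieces together: train, synthesize a
# mel spectrogram from text, invert it with Griffin-Lim, and write a wav file.
# The prompt text and output filename below are illustrative placeholders,
# not values fixed by this project.
if __name__ == '__main__':
    trained_model = train_model()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mel = generate_mel_spectrogram('Hello from swaram.', trained_model, tokenizer, device)
    waveform = mel_to_audio(mel.cpu())
    torchaudio.save('sample_output.wav', waveform, 16000)  # hypothetical output path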