Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions nemo/collections/tts/losses/audio_codec_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,89 @@ def forward(self, disc_scores_real, disc_scores_gen):
loss /= len(disc_scores_real)

return loss


class MMDLoss(Loss):
"""
Maximum mean discrepancy (MMD) loss, as defined in https://arxiv.org/abs/2406.02315

Args:
num_codebooks: Number of codebooks.
codebok_dim: Dimension of a single codebook code.
kernel_radii: List of radii for Gaussian kernels
loss_scale: Scaling factor to apply to output loss.
"""

def __init__(self, num_codebooks, codebook_dim, kernel_radii=(0.1, 1, 5, 10, 20, 50), loss_scale=1.0):
super().__init__()
self.num_codebooks = num_codebooks
self.codebook_dim = codebook_dim
self.kernel_radii = kernel_radii
self.loss_scale = loss_scale

@staticmethod
def _exp_kernel(dxx, r):
return torch.exp((-0.5 / r) * dxx).sum()

@staticmethod
def _shuffle_codebooks(x):
N, K, _ = x.size()
x_shuffled = torch.zeros_like(x)
for k in range(K):
batch_perm = torch.randperm(N, device=x.device)
x_shuffled[:, k, :] = x[batch_perm, k, :]
return x_shuffled

@property
def input_types(self):
return {
"codes": [NeuralType(('B', 'D', 'T'), VoidType())],
}

@property
def output_types(self):
return {
"loss": NeuralType(elements_type=LossType())
}

@typecheck()
def forward(self, codes):
B, D, T = codes.size()
N = B * T

# [B, K, C, T]
x = codes.reshape(B, self.num_codebooks, self.codebook_dim, T)
# [N, K, C]
x = rearrange(x, 'B K C T -> (B T) K C')
x_mean = x.mean(dim=(0,), keepdim=True)
x_stdev = torch.sqrt(x.var(dim=(0,), keepdim=True) + 1e-8)
x = (x - x_mean) / x_stdev
y = self._shuffle_codebooks(x)

# [N, D]
x = x.reshape([N, D])
y = y.reshape([N, D])

# [N, N]
xx = torch.mm(x, x.t())
yy = torch.mm(y, y.t())
zz = torch.mm(x, y.t())

rx = xx.diag().unsqueeze(0).expand_as(xx)
ry = yy.diag().unsqueeze(0).expand_as(yy)

dxx = rx.t() + rx - 2.0 * xx
dyy = ry.t() + ry - 2.0 * yy
dxy = rx.t() + ry - 2.0 * zz

loss = 0.0
coeff = -2.0 / N**2
denom = N * (N - 1)
for r in self.kernel_radii:
loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dxx, r) - N) / denom
loss += coeff * torch.utils.checkpoint.checkpoint(self._exp_kernel, dxy, r)
loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dyy, r) - N) / denom

loss = loss.clamp(min=0)
loss = self.loss_scale * loss
return loss
18 changes: 15 additions & 3 deletions nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
else:
raise ValueError(f'Unknown feature loss type {feature_loss_type}.')

if "mmd_loss" in cfg:
self.mmd_loss_fn = instantiate(cfg.mmd_loss)
self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0)
else:
self.mmd_loss_fn = None
self.mmd_loss_scale = None

# Codebook loss setup
if self.vector_quantizer:
self.commit_loss_scale = cfg.get("commit_loss_scale", 1.0)
Expand Down Expand Up @@ -470,7 +477,7 @@ def _process_batch(self, batch):
# [B, T]
audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)

return audio, audio_len, audio_gen, commit_loss
return audio, audio_len, audio_gen, commit_loss, encoded

@property
def disc_update_prob(self) -> float:
Expand All @@ -487,7 +494,7 @@ def should_update_disc(self, batch_idx) -> bool:
def training_step(self, batch, batch_idx):
optim_gen, optim_disc = self.optimizers()

audio, audio_len, audio_gen, commit_loss = self._process_batch(batch)
audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch)

metrics = {
"global_step": self.global_step,
Expand Down Expand Up @@ -547,6 +554,11 @@ def training_step(self, batch, batch_idx):
metrics["g_loss_commit"] = commit_loss
generator_losses.append(self.commit_loss_scale * commit_loss)

if self.mmd_loss_scale:
loss_mmd = self.mmd_loss_fn(codes=codes)
metrics["g_loss_mmd"] = loss_mmd
generator_losses.append(self.mmd_loss_scale * loss_mmd)

# compute embeddings for speaker consistency loss
if self.use_scl_loss:
# concate generated and GT waveforms
Expand Down Expand Up @@ -592,7 +604,7 @@ def on_train_epoch_end(self):
self.update_lr("epoch")

def validation_step(self, batch, batch_idx):
audio, audio_len, audio_gen, _ = self._process_batch(batch)
audio, audio_len, audio_gen, _, _ = self._process_batch(batch)

loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
loss_stft = self.stft_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
Expand Down