diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py
index 6db3e30595c6..ba423827fc0b 100644
--- a/nemo/collections/tts/losses/audio_codec_loss.py
+++ b/nemo/collections/tts/losses/audio_codec_loss.py
@@ -512,3 +512,89 @@ def forward(self, disc_scores_real, disc_scores_gen):
         loss /= len(disc_scores_real)

         return loss
+
+
+class MMDLoss(Loss):
+    """
+    Maximum mean discrepancy (MMD) loss, as defined in https://arxiv.org/abs/2406.02315
+
+    Args:
+        num_codebooks: Number of codebooks.
+        codebook_dim: Dimension of a single codebook code.
+        kernel_radii: List of radii for Gaussian kernels.
+        loss_scale: Scaling factor to apply to output loss.
+    """
+
+    def __init__(self, num_codebooks, codebook_dim, kernel_radii=(0.1, 1, 5, 10, 20, 50), loss_scale=1.0):
+        super().__init__()
+        self.num_codebooks = num_codebooks
+        self.codebook_dim = codebook_dim
+        self.kernel_radii = kernel_radii
+        self.loss_scale = loss_scale
+
+    @staticmethod
+    def _exp_kernel(dxx, r):
+        return torch.exp((-0.5 / r) * dxx).sum()
+
+    @staticmethod
+    def _shuffle_codebooks(x):
+        N, K, _ = x.size()
+        x_shuffled = torch.zeros_like(x)
+        for k in range(K):
+            batch_perm = torch.randperm(N, device=x.device)
+            x_shuffled[:, k, :] = x[batch_perm, k, :]
+        return x_shuffled
+
+    @property
+    def input_types(self):
+        return {
+            "codes": [NeuralType(('B', 'D', 'T'), VoidType())],
+        }
+
+    @property
+    def output_types(self):
+        return {
+            "loss": NeuralType(elements_type=LossType())
+        }
+
+    @typecheck()
+    def forward(self, codes):
+        B, D, T = codes.size()
+        N = B * T
+
+        # [B, K, C, T]
+        x = codes.reshape(B, self.num_codebooks, self.codebook_dim, T)
+        # [N, K, C]
+        x = rearrange(x, 'B K C T -> (B T) K C')
+        x_mean = x.mean(dim=(0,), keepdim=True)
+        x_stdev = torch.sqrt(x.var(dim=(0,), keepdim=True) + 1e-8)
+        x = (x - x_mean) / x_stdev
+        y = self._shuffle_codebooks(x)
+
+        # [N, D]
+        x = x.reshape([N, D])
+        y = y.reshape([N, D])
+
+        # [N, N]
+        xx = torch.mm(x, x.t())
+        yy = torch.mm(y, y.t())
+        zz = torch.mm(x, y.t())
+
+        rx = xx.diag().unsqueeze(0).expand_as(xx)
+        ry = yy.diag().unsqueeze(0).expand_as(yy)
+
+        dxx = rx.t() + rx - 2.0 * xx
+        dyy = ry.t() + ry - 2.0 * yy
+        dxy = rx.t() + ry - 2.0 * zz
+
+        loss = 0.0
+        coeff = -2.0 / N**2
+        denom = N * (N - 1)
+        for r in self.kernel_radii:
+            loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dxx, r) - N) / denom
+            loss += coeff * torch.utils.checkpoint.checkpoint(self._exp_kernel, dxy, r)
+            loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dyy, r) - N) / denom
+
+        loss = loss.clamp(min=0)
+        loss = self.loss_scale * loss
+        return loss
\ No newline at end of file
diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py
index 33b9a80125b7..b33693b510b8 100644
--- a/nemo/collections/tts/models/audio_codec.py
+++ b/nemo/collections/tts/models/audio_codec.py
@@ -151,6 +151,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         else:
             raise ValueError(f'Unknown feature loss type {feature_loss_type}.')

+        if "mmd_loss" in cfg:
+            self.mmd_loss_fn = instantiate(cfg.mmd_loss)
+            self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0)
+        else:
+            self.mmd_loss_fn = None
+            self.mmd_loss_scale = None
+
         # Codebook loss setup
         if self.vector_quantizer:
             self.commit_loss_scale = cfg.get("commit_loss_scale", 1.0)
@@ -470,7 +477,7 @@ def _process_batch(self, batch):
         # [B, T]
         audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)

-        return audio, audio_len, audio_gen, commit_loss
+        return audio, audio_len, audio_gen, commit_loss, encoded

     @property
     def disc_update_prob(self) -> float:
@@ -487,7 +494,7 @@ def should_update_disc(self, batch_idx) -> bool:
     def training_step(self, batch, batch_idx):
         optim_gen, optim_disc = self.optimizers()

-        audio, audio_len, audio_gen, commit_loss = self._process_batch(batch)
+        audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch)

         metrics = {
             "global_step": self.global_step,
@@ -547,6 +554,11 @@ def training_step(self, batch, batch_idx):
             metrics["g_loss_commit"] = commit_loss
             generator_losses.append(self.commit_loss_scale * commit_loss)

+        if self.mmd_loss_scale:
+            loss_mmd = self.mmd_loss_fn(codes=codes)
+            metrics["g_loss_mmd"] = loss_mmd
+            generator_losses.append(self.mmd_loss_scale * loss_mmd)
+
         # compute embeddings for speaker consistency loss
         if self.use_scl_loss:
             # concate generated and GT waveforms
@@ -592,7 +604,7 @@ def on_train_epoch_end(self):
         self.update_lr("epoch")

     def validation_step(self, batch, batch_idx):
-        audio, audio_len, audio_gen, _ = self._process_batch(batch)
+        audio, audio_len, audio_gen, _, _ = self._process_batch(batch)

         loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
         loss_stft = self.stft_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
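Reviewer note (not part of the diff): below is a minimal sketch of exercising the new MMDLoss in isolation, assuming this branch is installed. The num_codebooks, codebook_dim, batch-size, and frame-count values are illustrative assumptions, not taken from the change. The codes tensor follows the ('B', 'D', 'T') input type declared by MMDLoss.input_types, with D = num_codebooks * codebook_dim, i.e. the same layout as the encoded tensor that _process_batch now returns and training_step forwards via self.mmd_loss_fn(codes=codes).

import torch

from nemo.collections.tts.losses.audio_codec_loss import MMDLoss

# Illustrative hyperparameters; in training these come from the model config (cfg.mmd_loss).
num_codebooks = 8
codebook_dim = 32
batch_size, num_frames = 4, 50

loss_fn = MMDLoss(num_codebooks=num_codebooks, codebook_dim=codebook_dim, loss_scale=1.0)

# Random stand-in for the encoder output: [B, D, T] with D = num_codebooks * codebook_dim.
codes = torch.randn(batch_size, num_codebooks * codebook_dim, num_frames, requires_grad=True)

loss = loss_fn(codes=codes)  # scalar tensor, clamped to be non-negative
loss.backward()
print(loss.item())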