From e8f799fe6ff19472254e203a71b74770d4886586 Mon Sep 17 00:00:00 2001 From: Guy Nicholson Date: Wed, 24 Jul 2024 22:24:56 +0100 Subject: [PATCH 1/2] Adds fixes for PyTorch 2.x Adds fixes in main() method. Adds package requirements. Update README example to use spectrogram value range [0., 1.] to mitigate crash with negative values. Changes tested on PyTorch 2.4.0. --- .gitignore | 3 ++- readme.md | 9 ++++++--- requirements.txt | 3 +++ spec_augment_pytorch.py | 22 ++++++++++++++++++---- 4 files changed, 29 insertions(+), 8 deletions(-) create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore index b121290..be833ef 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -__* \ No newline at end of file +__* +venv/ \ No newline at end of file diff --git a/readme.md b/readme.md index 24d32eb..e1f5b1b 100644 --- a/readme.md +++ b/readme.md @@ -3,7 +3,7 @@ An implementation of SpecAugment for Pytorch ## How to use -Install pytorch (version==1.6.0 is used for testing). +Install librosa, matplotlib and pyorch (version==2.4.0 used for testing). ```python @@ -14,7 +14,10 @@ p = {'W':40, 'F':29, 'mF':2, 'T':50, 'p':1.0, 'mT':2, 'batch':False} specaug_fn = SpecAugmentTorch(**p) # [batch, c, frequency, n_frame], c=1 for magnitude or mel-spec, c=2 for complex stft -complex_stft = torch.randn(1, 1, 257, 150) +complex_stft = torch.randn(1, 1, 257, 150) +complex_stft = complex_stft - torch.min(complex_stft) +complex_stft = complex_stft / torch.max(complex_stft) + complex_stft_aug = specaug_fn(complex_stft) # [b, c, f, t] visualization_spectrogram(complex_stft_aug[0][0], "blabla") ``` @@ -34,4 +37,4 @@ run command `python spec_augment_pytorch.py` to generate examples (processed wav [2] [zcaceres/spec_augment issue17](https://github.com/zcaceres/spec_augment/issues/17) -[3] [SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition](https://arxiv.org/pdf/1904.08779.pdf) \ No newline at end of file +[3] [SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition](https://arxiv.org/pdf/1904.08779.pdf) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..67322e2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +librosa +matplotlib +torch diff --git a/spec_augment_pytorch.py b/spec_augment_pytorch.py index 812da35..1d085d2 100644 --- a/spec_augment_pytorch.py +++ b/spec_augment_pytorch.py @@ -213,12 +213,28 @@ def forward(self, n2ft): print(wav1.dtype) print(wav1.shape, wav2.shape) wav = torch.from_numpy(np.stack([wav1, wav2])) - spec = torch.stft(wav, 512, 160, 512, torch.hann_window(512)).permute(0, 3, 1, 2) # [N, 2, F, T] + + # From version 1.8.0, return_complex must always be given explicitly for + # real inputs and return_complex=False has been deprecated. Strongly prefer + # return_complex=True as in a future pytorch release, this function will + # only return complex tensors. + spec = torch.stft(wav, 512, 160, 512, torch.hann_window(512), + return_complex=True) + spec = torch.view_as_real(spec).permute(0, 3, 1, 2) # [N, 2, F, T]. print(spec.shape) # [N, 2, F, T] + if len(spec.shape) != 4: + raise ValueError("Spectrogram is not 4-D, wanted shape is [N, 2, F, T].") spec_aug = aug_fn(spec) - wav_aug = torch.istft(spec_aug.permute(0, 2, 3, 1), 512, 160, 512, torch.hann_window(512), length=wav.shape[-1]) + # Changed in version 2.0: Real datatype inputs are no longer supported. + # Input must now have a complex datatype. + # https://pytorch.org/docs/stable/generated/torch.istft.html#torch-istft + spec_aug_complex = torch.view_as_complex( + spec_aug.permute(0, 2, 3, 1).contiguous()) # [N, F, T] + wav_aug = torch.istft( + spec_aug_complex, 512, 160, 512, torch.hann_window(512), + length=wav.shape[-1]) sf.write("./examples/1089-0001-SpecAug.flac", wav_aug[0], sr) sf.write("./examples/1089-0002-SpecAug.flac", wav_aug[1], sr) @@ -229,5 +245,3 @@ def forward(self, n2ft): visualization_spectrogram(mag_aug[0],"1089-0001-SpecAug") visualization_spectrogram(mag[1],"1089-0002") visualization_spectrogram(mag_aug[1],"1089-0002-SpecAug") - - From 9c2e50f684f749a3167125b2225e5e3844ffc60f Mon Sep 17 00:00:00 2001 From: Guy Nicholson Date: Wed, 24 Jul 2024 22:29:51 +0100 Subject: [PATCH 2/2] Fix typos. --- .gitignore | 2 +- readme.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index be833ef..d645689 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ __* -venv/ \ No newline at end of file +venv/ diff --git a/readme.md b/readme.md index e1f5b1b..60077d2 100644 --- a/readme.md +++ b/readme.md @@ -3,7 +3,7 @@ An implementation of SpecAugment for Pytorch ## How to use -Install librosa, matplotlib and pyorch (version==2.4.0 used for testing). +Install librosa, matplotlib and pytorch (version==2.4.0 used for testing). ```python