diff --git a/.gitignore b/.gitignore index b121290..d645689 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -__* \ No newline at end of file +__* +venv/ diff --git a/readme.md b/readme.md index 24d32eb..60077d2 100644 --- a/readme.md +++ b/readme.md @@ -3,7 +3,7 @@ An implementation of SpecAugment for Pytorch ## How to use -Install pytorch (version==1.6.0 is used for testing). +Install librosa, matplotlib and pytorch (version==2.4.0 used for testing). ```python @@ -14,7 +14,10 @@ p = {'W':40, 'F':29, 'mF':2, 'T':50, 'p':1.0, 'mT':2, 'batch':False} specaug_fn = SpecAugmentTorch(**p) # [batch, c, frequency, n_frame], c=1 for magnitude or mel-spec, c=2 for complex stft -complex_stft = torch.randn(1, 1, 257, 150) +complex_stft = torch.randn(1, 1, 257, 150) +complex_stft = complex_stft - torch.min(complex_stft) +complex_stft = complex_stft / torch.max(complex_stft) + complex_stft_aug = specaug_fn(complex_stft) # [b, c, f, t] visualization_spectrogram(complex_stft_aug[0][0], "blabla") ``` @@ -34,4 +37,4 @@ run command `python spec_augment_pytorch.py` to generate examples (processed wav [2] [zcaceres/spec_augment issue17](https://github.com/zcaceres/spec_augment/issues/17) -[3] [SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition](https://arxiv.org/pdf/1904.08779.pdf) \ No newline at end of file +[3] [SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition](https://arxiv.org/pdf/1904.08779.pdf) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..67322e2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +librosa +matplotlib +torch diff --git a/spec_augment_pytorch.py b/spec_augment_pytorch.py index 812da35..1d085d2 100644 --- a/spec_augment_pytorch.py +++ b/spec_augment_pytorch.py @@ -213,12 +213,28 @@ def forward(self, n2ft): print(wav1.dtype) print(wav1.shape, wav2.shape) wav = torch.from_numpy(np.stack([wav1, wav2])) - spec = torch.stft(wav, 512, 160, 512, torch.hann_window(512)).permute(0, 3, 1, 2) # [N, 2, F, T] + + # From version 1.8.0, return_complex must always be given explicitly for + # real inputs and return_complex=False has been deprecated. Strongly prefer + # return_complex=True as in a future pytorch release, this function will + # only return complex tensors. + spec = torch.stft(wav, 512, 160, 512, torch.hann_window(512), + return_complex=True) + spec = torch.view_as_real(spec).permute(0, 3, 1, 2) # [N, 2, F, T]. print(spec.shape) # [N, 2, F, T] + if len(spec.shape) != 4: + raise ValueError("Spectrogram is not 4-D, wanted shape is [N, 2, F, T].") spec_aug = aug_fn(spec) - wav_aug = torch.istft(spec_aug.permute(0, 2, 3, 1), 512, 160, 512, torch.hann_window(512), length=wav.shape[-1]) + # Changed in version 2.0: Real datatype inputs are no longer supported. + # Input must now have a complex datatype. + # https://pytorch.org/docs/stable/generated/torch.istft.html#torch-istft + spec_aug_complex = torch.view_as_complex( + spec_aug.permute(0, 2, 3, 1).contiguous()) # [N, F, T] + wav_aug = torch.istft( + spec_aug_complex, 512, 160, 512, torch.hann_window(512), + length=wav.shape[-1]) sf.write("./examples/1089-0001-SpecAug.flac", wav_aug[0], sr) sf.write("./examples/1089-0002-SpecAug.flac", wav_aug[1], sr) @@ -229,5 +245,3 @@ def forward(self, n2ft): visualization_spectrogram(mag_aug[0],"1089-0001-SpecAug") visualization_spectrogram(mag[1],"1089-0002") visualization_spectrogram(mag_aug[1],"1089-0002-SpecAug") - -