Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__*
__*
venv/
9 changes: 6 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ An implementation of SpecAugment for Pytorch

## How to use

Install pytorch (version==1.6.0 is used for testing).
Install librosa, matplotlib and pytorch (version==2.4.0 is used for testing).


```python
Expand All @@ -14,7 +14,10 @@ p = {'W':40, 'F':29, 'mF':2, 'T':50, 'p':1.0, 'mT':2, 'batch':False}
specaug_fn = SpecAugmentTorch(**p)

# [batch, c, frequency, n_frame], c=1 for magnitude or mel-spec, c=2 for complex stft
complex_stft = torch.randn(1, 1, 257, 150)
complex_stft = torch.randn(1, 1, 257, 150)
complex_stft = complex_stft - torch.min(complex_stft)
complex_stft = complex_stft / torch.max(complex_stft)

complex_stft_aug = specaug_fn(complex_stft) # [b, c, f, t]
visualization_spectrogram(complex_stft_aug[0][0], "blabla")
```
Expand All @@ -34,4 +37,4 @@ run command `python spec_augment_pytorch.py` to generate examples (processed wav

[2] [zcaceres/spec_augment issue17](https://github.com/zcaceres/spec_augment/issues/17)

[3] [SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition](https://arxiv.org/pdf/1904.08779.pdf)
[3] [SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition](https://arxiv.org/pdf/1904.08779.pdf)
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
librosa
matplotlib
torch
22 changes: 18 additions & 4 deletions spec_augment_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,28 @@ def forward(self, n2ft):
print(wav1.dtype)
print(wav1.shape, wav2.shape)
wav = torch.from_numpy(np.stack([wav1, wav2]))
spec = torch.stft(wav, 512, 160, 512, torch.hann_window(512)).permute(0, 3, 1, 2) # [N, 2, F, T]

# From version 1.8.0, return_complex must always be given explicitly for
# real inputs and return_complex=False has been deprecated. Strongly prefer
# return_complex=True as in a future pytorch release, this function will
# only return complex tensors.
spec = torch.stft(wav, 512, 160, 512, torch.hann_window(512),
return_complex=True)
spec = torch.view_as_real(spec).permute(0, 3, 1, 2) # [N, 2, F, T].
print(spec.shape) # [N, 2, F, T]
if len(spec.shape) != 4:
raise ValueError("Spectrogram is not 4-D, wanted shape is [N, 2, F, T].")

spec_aug = aug_fn(spec)

wav_aug = torch.istft(spec_aug.permute(0, 2, 3, 1), 512, 160, 512, torch.hann_window(512), length=wav.shape[-1])
# Changed in version 2.0: Real datatype inputs are no longer supported.
# Input must now have a complex datatype.
# https://pytorch.org/docs/stable/generated/torch.istft.html#torch-istft
spec_aug_complex = torch.view_as_complex(
spec_aug.permute(0, 2, 3, 1).contiguous()) # [N, F, T]
wav_aug = torch.istft(
spec_aug_complex, 512, 160, 512, torch.hann_window(512),
length=wav.shape[-1])
sf.write("./examples/1089-0001-SpecAug.flac", wav_aug[0], sr)
sf.write("./examples/1089-0002-SpecAug.flac", wav_aug[1], sr)

Expand All @@ -229,5 +245,3 @@ def forward(self, n2ft):
visualization_spectrogram(mag_aug[0],"1089-0001-SpecAug")
visualization_spectrogram(mag[1],"1089-0002")
visualization_spectrogram(mag_aug[1],"1089-0002-SpecAug")