2 changes: 1 addition & 1 deletion asr/wenet/bin/recognize_wav.py
@@ -162,7 +162,7 @@ def main():
raise RuntimeError("One of either --model or (--checkpoint and --config) must be set.")

if model_arg_set:
reverb_model = load_model(args.model)
reverb_model = load_model(args.model, args.gpu)
else:
reverb_model = ReverbASR(
args.config,
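For context, here is a minimal sketch of the model-selection branch this hunk touches. Only the lines visible above are taken from the source; the import path, the way `model_arg_set` is computed, and the remaining `ReverbASR` arguments are assumptions.

```python
# Hedged sketch only: wiring beyond what the hunk shows is assumed.
import argparse

from wenet.cli.reverb import ReverbASR, load_model  # assumed import path


def build_model(args: argparse.Namespace) -> ReverbASR:
    model_arg_set = args.model is not None  # assumed definition
    if not model_arg_set and not (args.config and args.checkpoint):
        raise RuntimeError(
            "One of either --model or (--checkpoint and --config) must be set.")
    if model_arg_set:
        # The change above: the CLI's GPU index is now forwarded, so the model
        # is placed on the requested device at load time (gpu=-1 presumably CPU).
        return load_model(args.model, args.gpu)
    # The else branch builds ReverbASR from config + checkpoint directly;
    # its remaining constructor arguments are not shown in this hunk.
    return ReverbASR(args.config, args.checkpoint)
```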
22 changes: 11 additions & 11 deletions asr/wenet/cli/reverb.py
@@ -121,12 +121,11 @@ def compute_feats(
) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
logging.info(f"detected sample rate: {sample_rate}")
waveform = waveform.to(torch.float)
waveform = waveform.to(torch.float).to(self.device)
if sample_rate != resample_rate:
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate
)(waveform)
waveform = waveform.to(self.device)
).to(self.device)(waveform)
feats = kaldi.fbank(
waveform,
num_mel_bins=num_mel_bins,
@@ -151,18 +150,17 @@ def feats_batcher(
feats_batch = infeats[
:, b * batch_num_feats : b * batch_num_feats + batch_num_feats, :
]
feats_lengths = torch.tensor([chunk_size] * batch_size, dtype=torch.int32)
feats_lengths = torch.tensor([chunk_size] * batch_size, dtype=torch.int32, device=self.device)
if b == num_batches - 1:
# last batch can be smaller than batch size
last_batch_size = ceil(feats_batch.shape[1] / chunk_size)
last_batch_num_feats = chunk_size * last_batch_size
feats_lengths = torch.tensor(
[chunk_size] * last_batch_size, dtype=torch.int32
)
feats_lengths = torch.tensor([chunk_size] * last_batch_size, dtype=torch.int32, device=self.device)
# Apply padding if needed
pad_amt = last_batch_num_feats - feats_batch.shape[1]
if pad_amt > 0:
feats_lengths[-1] -= pad_amt
if last_batch_size == 1:
feats_lengths[-1] -= pad_amt
feats_batch = F.pad(
input=feats_batch,
pad=(0, 0, 0, pad_amt, 0, 0),
@@ -171,7 +169,7 @@ def feats_batcher(
)
yield feats_batch.reshape(
-1, chunk_size, self.test_conf["fbank_conf"]["num_mel_bins"]
), feats_lengths.to(self.device)
), feats_lengths

def transcribe_modes(
self,
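Both `compute_feats` and `feats_batcher` now create or move their tensors on `self.device` up front: the waveform and the `Resample` transform are placed on the device before `kaldi.fbank`, and the per-chunk length tensors are built directly on the device instead of being moved afterwards. The following standalone sketch of the batching logic mirrors the hunks above; `device` stands in for `self.device`, and the `(1, num_frames, num_mel_bins)` input layout is an assumption.

```python
# Standalone sketch of feats_batcher as changed above (assumptions noted inline).
from math import ceil

import torch
import torch.nn.functional as F


def feats_batcher(infeats: torch.Tensor, chunk_size: int, batch_size: int,
                  device: torch.device):
    """Yield (n_chunks, chunk_size, num_mel_bins) batches plus per-chunk lengths."""
    num_mel_bins = infeats.shape[-1]
    batch_num_feats = chunk_size * batch_size
    num_batches = ceil(infeats.shape[1] / batch_num_feats)
    for b in range(num_batches):
        feats_batch = infeats[:, b * batch_num_feats:(b + 1) * batch_num_feats, :]
        # Lengths are now built directly on the target device (no .to() later).
        feats_lengths = torch.tensor([chunk_size] * batch_size,
                                     dtype=torch.int32, device=device)
        if b == num_batches - 1:
            # The last batch can hold fewer chunks; pad the frame axis up to a
            # whole number of chunks.
            last_batch_size = ceil(feats_batch.shape[1] / chunk_size)
            feats_lengths = torch.tensor([chunk_size] * last_batch_size,
                                         dtype=torch.int32, device=device)
            pad_amt = chunk_size * last_batch_size - feats_batch.shape[1]
            if pad_amt > 0:
                if last_batch_size == 1:
                    # Mirrors the change above: the final length is shortened
                    # only when the last batch is a single chunk.
                    feats_lengths[-1] -= pad_amt
                feats_batch = F.pad(feats_batch, (0, 0, 0, pad_amt, 0, 0),
                                    mode="constant", value=0.0)
        yield feats_batch.reshape(-1, chunk_size, num_mel_bins), feats_lengths
```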
@@ -322,7 +320,8 @@ def get_output(


def load_model(
model: str
model: str,
gpu: int = -1,
):
"""Loads a reverb model. If "model" points to a path that exists,
tries to load a model using those files at "model".
@@ -353,7 +352,8 @@ def load_model(
logging.info(f"Loading the model with {config_path = } and {checkpoint_path = }")
return ReverbASR(
str(config_path),
str(checkpoint_path)
str(checkpoint_path),
gpu = gpu,
)


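On the caller side, the new `gpu` parameter simply threads through to `ReverbASR`. A hedged usage example: only the `load_model(model, gpu=-1)` signature comes from the diff; the `gpu=-1` means CPU convention and the `transcribe()` method name are assumptions.

```python
# Hedged usage sketch; transcribe() and the gpu=-1 => CPU convention are assumed.
from wenet.cli.reverb import load_model

asr_cpu = load_model("path/to/model_dir")           # default gpu=-1
asr_gpu = load_model("path/to/model_dir", gpu=0)    # place the model on cuda:0
text = asr_gpu.transcribe("utterance.wav")          # assumed method name
print(text)
```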
93 changes: 34 additions & 59 deletions asr/wenet/transformer/asr_model.py
@@ -879,102 +879,77 @@ def forward_attention_decoder(
Args:
hyps (torch.Tensor): hyps from ctc prefix beam search, already
pad sos at the beginning
(batch*beam, max_hyps_len)
hyps_lens (torch.Tensor): length of each hyp in hyps
(batch*beam)
encoder_out (torch.Tensor): corresponding encoder output
r_hyps (torch.Tensor): hyps from ctc prefix beam search, already
pad eos at the beginning, which is used for the right to left decoder
reverse_weight: used for verifying whether to use the right to left decoder,
> 0 will use.

> 0 will use.
cat_embs (torch.Tensor): category embeddings
(1, cat_emb_dim)
Returns:
torch.Tensor: decoder output
decoder_out (torch.Tensor): decoder output
(batch*beam, max_hyps_len, vocab_size)
r_decoder_out (torch.Tensor): decoder output for right to left decoder
(batch*beam, max_hyps_len, vocab_size)
"""
assert encoder_out.size(0) == 1
num_hyps = hyps.size(0)
assert hyps_lens.size(0) == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
encoder_mask = torch.ones(num_hyps,
1,
encoder_out.size(1),
dtype=torch.bool,
device=encoder_out.device)

batch_size = encoder_out.size(0)
beam_size = hyps.size(0) // batch_size
assert hyps.size(0) == batch_size * beam_size, "Number of hypotheses must be batch_size * beam_size"
assert hyps_lens.size(0) == batch_size * beam_size

# Repeat encoder output for each hypothesis in the beam, maintaining batch separation
encoder_out = encoder_out.unsqueeze(1).expand(-1, beam_size, -1, -1)
encoder_out = encoder_out.view(batch_size * beam_size, -1, encoder_out.size(-1))
encoder_mask = torch.ones(encoder_out.size(0),
1,
encoder_out.size(1),
dtype=torch.bool,
device=encoder_out.device)

# input for right to left decoder
# hyps_lens counts the <sos> token, so we need to subtract it.
r_hyps_lens = hyps_lens - 1
# hyps includes the <sos> token, so strip it to
# recover the original hyps.
r_hyps = hyps[:, 1:]
# >>> r_hyps
# >>> tensor([[ 1, 2, 3],
# >>> [ 9, 8, 4],
# >>> [ 2, -1, -1]])
# >>> r_hyps_lens
# >>> tensor([3, 3, 1])

# NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
# in `reverse_pad_list` thus we have to refine the below code.
# Issue: https://github.com/wenet-e2e/wenet/issues/1113
# Equal to:
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)

# Handle right-to-left decoding
max_len = torch.max(r_hyps_lens)
index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)
# >>> seq_mask
# >>> tensor([[ True, True, True],
# >>> [ True, True, True],
# >>> [ True, False, False]])
index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
seq_mask = seq_len_expand > index_range # (batch*beam, max_len)

index = (seq_len_expand - 1) - index_range # (batch*beam, max_len)
index = index * seq_mask
# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
r_hyps = torch.gather(r_hyps, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = torch.where(seq_mask, r_hyps, self.eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])

if self.decoder is not None:
# If using language-specific layers, handle cat_embs
if self.lsl_dec:
if verbose:
print("passing cat_emb to decoder")
print(cat_embs)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
reverse_weight,
cat_embs) # (num_hyps, max_hyps_len, vocab_size)
cat_embs) # (batch*beam, max_hyps_len, vocab_size)
else:
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
reverse_weight,
None) # (num_hyps, max_hyps_len, vocab_size)
None) # (batch*beam, max_hyps_len, vocab_size)

decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)

# right to left decoder may not be used during decoding,
# depending on the reverse_weight param.
# r_decoder_out will be 0.0 if reverse_weight is 0.0
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out,
dim=-1)
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
else:
decoder_out, r_decoder_out = None, None

return decoder_out, r_decoder_out

def onmt_attention_decoding(
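The deleted `>>>` comments documented the ONNX-friendly replacement for `reverse_pad_list` + `add_sos_eos`. Since the rewrite keeps that trick but drops the walkthrough, here is a standalone toy reproduction showing how the mask/gather/where sequence reverses each hypothesis and re-pads with `eos`; the `sos`/`eos`/`ignore_id` values are made up for the example.

```python
# Toy reproduction of the right-to-left hypothesis preparation kept above.
import torch

sos, eos, ignore_id = 10, 11, -1

# hyps as they reach forward_attention_decoder: <sos> already prepended,
# padded with ignore_id; hyps_lens counts the <sos> token.
hyps = torch.tensor([[sos, 1, 2, 3],
                     [sos, 9, 8, 4],
                     [sos, 2, ignore_id, ignore_id]])
hyps_lens = torch.tensor([4, 4, 2])

r_hyps = hyps[:, 1:]                # drop <sos>
r_hyps_lens = hyps_lens - 1         # lengths without <sos>

max_len = torch.max(r_hyps_lens)
index_range = torch.arange(0, max_len, 1)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range      # True on real tokens, False on padding
index = (seq_len_expand - 1) - index_range   # reversed positions (can be negative)
index = index * seq_mask                     # zero out the padding positions
r_hyps = torch.gather(r_hyps, 1, index)      # reverse each hypothesis
r_hyps = torch.where(seq_mask, r_hyps, eos)  # replace padding with <eos>
r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)  # re-attach <sos>

print(r_hyps)
# tensor([[10,  3,  2,  1],
#         [10,  4,  8,  9],
#         [10,  2, 11, 11]])
```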
135 changes: 93 additions & 42 deletions asr/wenet/transformer/search.py
@@ -378,73 +378,124 @@ def attention_rescoring(
device = encoder_outs.device
assert encoder_outs.shape[0] == len(ctc_prefix_results)
batch_size = encoder_outs.shape[0]
results = []

# Collect all hypotheses and their lengths
all_hyps = []
all_ctc_scores = []
beam_sizes = []
for b in range(batch_size):
all_hyps.extend(ctc_prefix_results[b].nbest)
all_ctc_scores.extend(ctc_prefix_results[b].nbest_scores)
beam_sizes.append(len(ctc_prefix_results[b].nbest))

# Pad all hypotheses together
# hyps_pad: (batch*beam, max_hyps_len)
hyps_pad = pad_sequence([torch.tensor(hyp, device=device, dtype=torch.long)
for hyp in all_hyps], True, model.ignore_id)
# hyps_lens: (batch*beam)
hyps_lens = torch.tensor([len(hyp) for hyp in all_hyps],
device=device, dtype=torch.long)
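
As illustration, the flattening step above turns every utterance's n-best list into one padded (batch*beam, max_hyps_len) tensor. A toy example with made-up token ids (beam sizes may differ per utterance, which is why `beam_sizes` is tracked for the offset bookkeeping later in the loop):

```python
# Toy illustration of the hypothesis flattening above; token ids are made up.
import torch
from torch.nn.utils.rnn import pad_sequence

ignore_id = -1
nbest_per_utt = [
    [[1, 2, 3], [1, 2]],          # utterance 0: beam of 2
    [[7, 8, 9, 4], [7, 8], [7]],  # utterance 1: beam of 3
]
all_hyps = [hyp for nbest in nbest_per_utt for hyp in nbest]
beam_sizes = [len(nbest) for nbest in nbest_per_utt]

hyps_pad = pad_sequence([torch.tensor(h, dtype=torch.long) for h in all_hyps],
                        batch_first=True, padding_value=ignore_id)
hyps_lens = torch.tensor([len(h) for h in all_hyps], dtype=torch.long)
print(hyps_pad.shape, beam_sizes)  # torch.Size([5, 4]) [2, 3]
# Later, row `offset + i` of the decoder output corresponds to hypothesis i of
# utterance b, where offset is the running sum of beam_sizes[:b].
```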

# Handle special tokens if needed
if getattr(model, 'special_tokens', None) is not None \
and "transcribe" in model.special_tokens:
prev_len = hyps_pad.size(1)
# Repeat tasks and langs for each beam
tasks = [infos["tasks"][b] for b in range(batch_size) for _ in range(beam_sizes[b])]
langs = [infos["langs"][b] for b in range(batch_size) for _ in range(beam_sizes[b])]
hyps_pad, _ = add_whisper_tokens(
model.special_tokens,
hyps_pad,
model.ignore_id,
tasks=tasks,
no_timestamp=True,
langs=langs,
use_prev=False)
cur_len = hyps_pad.size(1)
hyps_lens = hyps_lens + cur_len - prev_len
prefix_len = 4
else:
hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
prefix_len = 1

# Repeat encoder outputs for each beam
encoder_out_lens = []
encoder_outs_expanded = []
for b in range(batch_size):
beam_size = beam_sizes[b]
# encoder_out: (1, max_len, encoder_dim)
encoder_out = encoder_outs[b, :encoder_lens[b], :].unsqueeze(0)
# encoder_outs_expanded: (beam_size, max_len, encoder_dim)
encoder_outs_expanded.append(encoder_out.expand(beam_size, -1, -1))
# encoder_out_lens: (beam_size)
encoder_out_lens.extend([encoder_lens[b]] * beam_size)

# encoder_outs_expanded: (batch*beam, max_len, encoder_dim)
encoder_outs_expanded = torch.cat(encoder_outs_expanded, dim=0)

# Forward decoder with all hypotheses at once
# MDR: forward_attention_decoder will re-expand the encoder_outs_expanded
# to (batch*beam, max_len, encoder_dim). Necessary to maintain C++
# compatibility.
# decoder_out: (batch*beam, max_hyps_len, vocab_size)
# r_decoder_out: (batch*beam, max_hyps_len, vocab_size)
decoder_out, r_decoder_out = model.forward_attention_decoder(
hyps_pad, hyps_lens, encoder_outs_expanded, reverse_weight, cat_embs)

# Process results batch by batch
results = []
offset = 0
for b in range(batch_size):
beam_size = beam_sizes[b]
hyps = ctc_prefix_results[b].nbest
ctc_scores = ctc_prefix_results[b].nbest_scores
hyps_pad = pad_sequence([
torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
], True, model.ignore_id) # (beam_size, max_hyps_len)
hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
device=device,
dtype=torch.long) # (beam_size,)
if getattr(model, 'special_tokens', None) is not None \
and "transcribe" in model.special_tokens:
prev_len = hyps_pad.size(1)
hyps_pad, _ = add_whisper_tokens(
model.special_tokens,
hyps_pad,
model.ignore_id,
tasks=[infos["tasks"][b]] * len(hyps),
no_timestamp=True,
langs=[infos["langs"][b]] * len(hyps),
use_prev=False)
cur_len = hyps_pad.size(1)
hyps_lens = hyps_lens + cur_len - prev_len
prefix_len = 4
else:
hyps_pad, _ = add_sos_eos(hyps_pad, sos, eos, model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
prefix_len = 1
decoder_out, r_decoder_out = model.forward_attention_decoder(
hyps_pad, hyps_lens, encoder_out, reverse_weight, cat_embs)
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
confidences = []
tokens_confidences = []
for i, hyp in enumerate(hyps):

# Process each hypothesis in the current batch
for i in range(beam_size):
idx = offset + i
hyp = hyps[i]
score = 0.0
tc = [] # tokens confidences

# Calculate forward decoder score
for j, w in enumerate(hyp):
s = decoder_out[i][j + (prefix_len - 1)][w]
s = decoder_out[idx][j + (prefix_len - 1)][w]
score += s
tc.append(math.exp(s))
score += decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
# add right to left decoder score
score += decoder_out[idx][len(hyp) + (prefix_len - 1)][eos]

# Add right to left decoder score if needed
if reverse_weight > 0 and r_decoder_out.dim() > 0:
r_score = 0.0
for j, w in enumerate(hyp):
s = r_decoder_out[i][len(hyp) - j - 1 +
(prefix_len - 1)][w]
s = r_decoder_out[idx][len(hyp) - j - 1 + (prefix_len - 1)][w]
r_score += s
tc[j] = (tc[j] + math.exp(s)) / 2
r_score += r_decoder_out[i][len(hyp) + (prefix_len - 1)][eos]
r_score += r_decoder_out[idx][len(hyp) + (prefix_len - 1)][eos]
score = score * (1 - reverse_weight) + r_score * reverse_weight

confidences.append(math.exp(score / (len(hyp) + 1)))
# add ctc score
score += ctc_scores[i] * ctc_weight
score += all_ctc_scores[idx] * ctc_weight

if score > best_score:
best_score = score
best_index = i
tokens_confidences.append(tc)

# Add best result for current batch
results.append(
DecodeResult(hyps[best_index],
best_score,
confidence=confidences[best_index],
times=ctc_prefix_results[b].nbest_times[best_index],
tokens_confidence=tokens_confidences[best_index]))
best_score,
confidence=confidences[best_index],
times=ctc_prefix_results[b].nbest_times[best_index],
tokens_confidence=tokens_confidences[best_index]))

offset += beam_size

return results

def joint_decoding(
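To make the rescoring rule in the loop above easier to follow in isolation, here is a pure-function restatement of the per-hypothesis score, where `decoder_row` and `r_decoder_row` stand for `decoder_out[offset + i]` and `r_decoder_out[offset + i]`. This is a sketch for readability, not a drop-in replacement.

```python
# Condensed restatement of the scoring applied per hypothesis above.
from typing import List

import torch


def rescore_one(hyp: List[int],
                decoder_row: torch.Tensor,    # (max_hyps_len, vocab) log-probs
                r_decoder_row: torch.Tensor,  # same, from the right-to-left decoder
                ctc_score: float, eos: int, prefix_len: int,
                ctc_weight: float, reverse_weight: float) -> float:
    # Left-to-right attention-decoder score, offset past the <sos>/task prefix.
    score = sum(float(decoder_row[j + prefix_len - 1][w]) for j, w in enumerate(hyp))
    score += float(decoder_row[len(hyp) + prefix_len - 1][eos])
    if reverse_weight > 0:
        # The right-to-left decoder scores the hypothesis in reverse order.
        r_score = sum(float(r_decoder_row[len(hyp) - j - 1 + prefix_len - 1][w])
                      for j, w in enumerate(hyp))
        r_score += float(r_decoder_row[len(hyp) + prefix_len - 1][eos])
        score = score * (1 - reverse_weight) + r_score * reverse_weight
    # Interpolate with the CTC prefix-beam-search score.
    return score + ctc_weight * ctc_score
```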