From bab2ffb0ace60ca209c0c34d4c45a237ec1c16cd Mon Sep 17 00:00:00 2001 From: IgorSwat Date: Wed, 11 Feb 2026 10:55:37 +0100 Subject: [PATCH 1/6] Add whisper kv-cache & fix demo app permissions --- .../models/speech_to_text/SpeechToText.cpp | 3 +- .../models/speech_to_text/asr/ASR.cpp | 33 ++++++++++++++----- .../models/speech_to_text/asr/ASR.h | 2 +- .../src/constants/modelUrls.ts | 6 ++-- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp index eef6d562c..1e2fa8ebc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp @@ -40,7 +40,8 @@ SpeechToText::encode(std::span waveform) const { std::shared_ptr SpeechToText::decode(std::span tokens, std::span encoderOutput) const { - std::vector decoderOutput = this->asr->decode(tokens, encoderOutput); + std::vector decoderOutput = + this->asr->decode(tokens, 0, encoderOutput); return std::make_shared(decoderOutput); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp index 2ed41ff22..455061567 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -42,11 +43,15 @@ GenerationResult ASR::generate(std::span waveform, float temperature, std::vector encoderOutput = this->encode(waveform); std::vector sequenceIds = this->getInitialSequence(options); + std::vector cachedTokens = sequenceIds; const size_t initialSequenceLenght = sequenceIds.size(); std::vector scores; - 
while (std::cmp_less_equal(sequenceIds.size(), ASR::kMaxDecodeLength)) { - std::vector logits = this->decode(sequenceIds, encoderOutput); + uint64_t startPos = 0; + while (std::cmp_less_equal(startPos + sequenceIds.size(), + ASR::kMaxDecodeLength)) { + std::vector logits = + this->decode(sequenceIds, startPos, encoderOutput); // intentionally comparing float to float // temperatures are predefined, so this is safe @@ -74,7 +79,10 @@ GenerationResult ASR::generate(std::span waveform, float temperature, nextProb = probs[nextId]; } - sequenceIds.push_back(nextId); + // Move the startPos pointer by the amount of tokens we processed + startPos += sequenceIds.size(); + sequenceIds = {nextId}; + cachedTokens.push_back(nextId); scores.push_back(nextProb); if (nextId == this->endOfTranscriptionToken) { @@ -82,8 +90,9 @@ GenerationResult ASR::generate(std::span waveform, float temperature, } } - return {.tokens = std::vector( - sequenceIds.cbegin() + initialSequenceLenght, sequenceIds.cend()), + return {.tokens = std::vector(cachedTokens.cbegin() + + initialSequenceLenght, + cachedTokens.cend()), .scores = scores}; } @@ -318,13 +327,19 @@ std::vector ASR::encode(std::span waveform) const { return {dataPtr, dataPtr + outputNumel}; } -std::vector ASR::decode(std::span tokens, +std::vector ASR::decode(std::span tokens, uint64_t startPos, std::span encoderOutput) const { std::vector tokenShape = {1, static_cast(tokens.size())}; - auto tokensLong = std::vector(tokens.begin(), tokens.end()); + std::vector positionShape = {static_cast(tokens.size())}; auto tokenTensor = executorch::extension::make_tensor_ptr( - tokenShape, tokensLong.data(), ScalarType::Long); + tokenShape, tokens.data(), ScalarType::Long); + + // Populate cache position vector + std::vector cachePositions(tokens.size()); + std::iota(cachePositions.begin(), cachePositions.end(), startPos); + auto positionTensor = executorch::extension::make_tensor_ptr( + positionShape, cachePositions.data(), ScalarType::Long); 
const auto encoderOutputSize = static_cast(encoderOutput.size()); std::vector encShape = {1, ASR::kNumFrames, @@ -333,7 +348,7 @@ std::vector ASR::decode(std::span tokens, std::move(encShape), encoderOutput.data(), ScalarType::Float); const auto decoderResult = - this->decoder->forward({tokenTensor, encoderTensor}); + this->decoder->forward({tokenTensor, positionTensor, encoderTensor}); if (!decoderResult.ok()) { throw RnExecutorchError(decoderResult.error(), diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h index 16a2f45e6..7bc2e0e0f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h @@ -17,7 +17,7 @@ class ASR { transcribe(std::span waveform, const types::DecodingOptions &options) const; std::vector encode(std::span waveform) const; - std::vector decode(std::span tokens, + std::vector decode(std::span tokens, uint64_t startPos, std::span encoderOutput) const; private: diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 6e76e52b7..3a9eb0ce6 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -418,8 +418,10 @@ export const STYLE_TRANSFER_UDNIE = { // S2T const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`; -const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`; -const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`; +// const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`; 
+// const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`; +const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/resolve/kv-cache/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`; +const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/resolve/kv-cache/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`; const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`; const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`; From ea943e4f19b38a6f39af10f0158fb84851155fee Mon Sep 17 00:00:00 2001 From: IgorSwat Date: Tue, 17 Feb 2026 15:19:27 +0100 Subject: [PATCH 2/6] Refactor STT native implementation --- .../host_objects/JsiConversions.h | 6 +- .../common/rnexecutorch/models/BaseModel.h | 2 +- .../models/speech_to_text/SpeechToText.cpp | 81 ++-- .../models/speech_to_text/SpeechToText.h | 36 +- .../models/speech_to_text/asr/ASR.h | 65 --- .../models/speech_to_text/common/schema/ASR.h | 39 ++ .../speech_to_text/common/schema/OnlineASR.h | 44 ++ .../{ => common}/types/DecodingOptions.h | 4 +- .../common/types/GenerationResult.h | 14 + .../common/types/ProcessResult.h | 14 + .../{ => common}/types/Segment.h | 8 +- .../speech_to_text/common/types/Token.h | 9 + .../{ => common}/types/TranscriptionResult.h | 4 +- .../models/speech_to_text/common/types/Word.h | 13 + .../stream/HypothesisBuffer.cpp | 82 ---- .../speech_to_text/stream/HypothesisBuffer.h | 25 - .../stream/OnlineASRProcessor.cpp | 96 ---- .../stream/OnlineASRProcessor.h | 32 -- .../speech_to_text/types/GenerationResult.h | 12 - .../speech_to_text/types/ProcessResult.h | 12 - .../models/speech_to_text/types/Word.h | 13 - .../speech_to_text/{asr => whisper}/ASR.cpp | 442 ++++++++++-------- .../models/speech_to_text/whisper/ASR.h | 167 +++++++ 
.../models/speech_to_text/whisper/Constants.h | 35 ++ .../whisper/HypothesisBuffer.cpp | 86 ++++ .../speech_to_text/whisper/HypothesisBuffer.h | 64 +++ .../speech_to_text/whisper/OnlineASR.cpp | 113 +++++ .../models/speech_to_text/whisper/OnlineASR.h | 77 +++ .../src/constants/modelUrls.ts | 163 ++++--- .../useSpeechToText.ts | 8 +- packages/react-native-executorch/src/index.ts | 6 +- .../SpeechToTextModule.ts | 17 +- .../react-native-executorch/src/types/stt.ts | 15 +- 33 files changed, 1092 insertions(+), 712 deletions(-) delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h rename packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/{ => common}/types/DecodingOptions.h (73%) create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/GenerationResult.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/ProcessResult.h rename packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/{ => common}/types/Segment.h (51%) create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h rename packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/{ => common}/types/TranscriptionResult.h (69%) create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h delete mode 100644 
packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h rename packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/{asr => whisper}/ASR.cpp (57%) create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index df9abbdef..635c0bb96 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -18,11 +18,11 @@ #include #include #include -#include -#include +#include +#include #include -using namespace rnexecutorch::models::speech_to_text::types; +using namespace rnexecutorch::models::speech_to_text; namespace 
rnexecutorch::jsi_conversion { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index c40fa2569..1db62c466 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -3,12 +3,12 @@ #include #include -#include "rnexecutorch/metaprogramming/ConstructorHelpers.h" #include #include #include #include #include +#include namespace rnexecutorch { namespace models { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp index 1e2fa8ebc..931bfb2b9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp @@ -1,39 +1,38 @@ #include #include "SpeechToText.h" +#include "common/types/TranscriptionResult.h" +#include "whisper/ASR.h" +#include "whisper/OnlineASR.h" #include #include -#include namespace rnexecutorch::models::speech_to_text { -using namespace ::executorch::extension; -using namespace asr; -using namespace types; -using namespace stream; - -SpeechToText::SpeechToText(const std::string &encoderSource, - const std::string &decoderSource, +SpeechToText::SpeechToText(const std::string &modelName, + const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker) - : callInvoker(std::move(callInvoker)), - encoder(std::make_unique(encoderSource, this->callInvoker)), - decoder(std::make_unique(decoderSource, this->callInvoker)), - tokenizer(std::make_unique(tokenizerSource, - this->callInvoker)), - asr(std::make_unique(this->encoder.get(), this->decoder.get(), - this->tokenizer.get())), - processor(std::make_unique(this->asr.get())), - 
isStreaming(false), readyToProcess(false) {} - -void SpeechToText::unload() noexcept { - this->encoder->unload(); - this->decoder->unload(); + : callInvoker_(std::move(callInvoker)), isStreaming_(false), + readyToProcess_(false) { + // Switch between the ASR implementations based on model name + if (modelName == "whisper") { + transcriber_ = std::make_unique(modelSource, tokenizerSource, + callInvoker_); + streamer_ = std::make_unique( + static_cast(transcriber_.get())); + } else { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::InvalidConfig, + "[SpeechToText]: Invalid model name: " + modelName); + } } +void SpeechToText::unload() noexcept { transcriber_->unload(); } + std::shared_ptr SpeechToText::encode(std::span waveform) const { - std::vector encoderOutput = this->asr->encode(waveform); + std::vector encoderOutput = transcriber_->encode(waveform); return std::make_shared(encoderOutput); } @@ -41,7 +40,7 @@ std::shared_ptr SpeechToText::decode(std::span tokens, std::span encoderOutput) const { std::vector decoderOutput = - this->asr->decode(tokens, 0, encoderOutput); + transcriber_->decode(tokens, encoderOutput); return std::make_shared(decoderOutput); } @@ -49,7 +48,7 @@ TranscriptionResult SpeechToText::transcribe(std::span waveform, std::string languageOption, bool verbose) const { DecodingOptions options(languageOption, verbose); - std::vector segments = this->asr->transcribe(waveform, options); + std::vector segments = transcriber_->transcribe(waveform, options); std::string fullText; for (const auto &segment : segments) { @@ -71,8 +70,7 @@ TranscriptionResult SpeechToText::transcribe(std::span waveform, } size_t SpeechToText::getMemoryLowerBound() const noexcept { - return this->encoder->getMemoryLowerBound() + - this->decoder->getMemoryLowerBound(); + return transcriber_->getMemoryLowerBound(); } namespace { @@ -106,7 +104,7 @@ TranscriptionResult wordsToResult(const std::vector &words, void 
SpeechToText::stream(std::shared_ptr callback, std::string languageOption, bool verbose) { - if (this->isStreaming) { + if (isStreaming_) { throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress, "Streaming is already in progress!"); } @@ -116,7 +114,7 @@ void SpeechToText::stream(std::shared_ptr callback, const TranscriptionResult &nonCommitted, bool isDone) { // This moves execution to the JS thread - this->callInvoker->invokeAsync( + callInvoker_->invokeAsync( [callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) { jsi::Value jsiCommitted = rnexecutorch::jsi_conversion::getJsiValue(committed, rt); @@ -128,17 +126,16 @@ void SpeechToText::stream(std::shared_ptr callback, }); }; - this->isStreaming = true; + isStreaming_ = true; DecodingOptions options(languageOption, verbose); - while (this->isStreaming) { - if (!this->readyToProcess || - this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) { + while (isStreaming_) { + if (!readyToProcess_ || !streamer_->ready()) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); continue; } - ProcessResult res = this->processor->processIter(options); + ProcessResult res = streamer_->process(options); TranscriptionResult cRes = wordsToResult(res.committed, languageOption, verbose); @@ -146,28 +143,28 @@ void SpeechToText::stream(std::shared_ptr callback, wordsToResult(res.nonCommitted, languageOption, verbose); nativeCallback(cRes, ncRes, false); - this->readyToProcess = false; + readyToProcess_ = false; } - std::vector finalWords = this->processor->finish(); + std::vector finalWords = streamer_->finish(); TranscriptionResult finalRes = wordsToResult(finalWords, languageOption, verbose); nativeCallback(finalRes, {}, true); - this->resetStreamState(); + resetStreamState(); } -void SpeechToText::streamStop() { this->isStreaming = false; } +void SpeechToText::streamStop() { isStreaming_ = false; } void SpeechToText::streamInsert(std::span waveform) { - 
this->processor->insertAudioChunk(waveform); - this->readyToProcess = true; + streamer_->insertAudioChunk(waveform); + readyToProcess_ = true; } void SpeechToText::resetStreamState() { - this->isStreaming = false; - this->readyToProcess = false; - this->processor = std::make_unique(this->asr.get()); + isStreaming_ = false; + readyToProcess_ = false; + streamer_->reset(); } } // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h index f9156040d..fbd53b6db 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h @@ -1,19 +1,21 @@ #pragma once -#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h" -#include #include #include #include +#include "common/schema/ASR.h" +#include "common/schema/OnlineASR.h" +#include "common/types/TranscriptionResult.h" + namespace rnexecutorch { namespace models::speech_to_text { class SpeechToText { public: - explicit SpeechToText(const std::string &encoderSource, - const std::string &decoderSource, + explicit SpeechToText(const std::string &modelName, + const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker); @@ -25,9 +27,9 @@ class SpeechToText { "Registered non-void function")]] std::shared_ptr decode(std::span tokens, std::span encoderOutput) const; [[nodiscard("Registered non-void function")]] - types::TranscriptionResult transcribe(std::span waveform, - std::string languageOption, - bool verbose) const; + TranscriptionResult transcribe(std::span waveform, + std::string languageOption, + bool verbose) const; [[nodiscard("Registered non-void function")]] std::vector transcribeStringOnly(std::span waveform, @@ -42,20 +44,18 @@ class SpeechToText { 
void streamInsert(std::span waveform); private: - std::shared_ptr callInvoker; - std::unique_ptr encoder; - std::unique_ptr decoder; - std::unique_ptr tokenizer; - std::unique_ptr asr; + // Helper functions + void resetStreamState(); - // Stream - std::unique_ptr processor; - bool isStreaming; - bool readyToProcess; + std::shared_ptr callInvoker_; - constexpr static int32_t kMinAudioSamples = 16000; // 1 second + // ASR-like module (both static transcription & streaming) + std::unique_ptr transcriber_ = nullptr; - void resetStreamState(); + // Online ASR-like module (streaming only) + std::unique_ptr streamer_ = nullptr; + bool isStreaming_ = false; + bool readyToProcess_ = true; }; } // namespace models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h deleted file mode 100644 index 7bc2e0e0f..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include "rnexecutorch/TokenizerModule.h" -#include "rnexecutorch/models/BaseModel.h" -#include "rnexecutorch/models/speech_to_text/types/DecodingOptions.h" -#include "rnexecutorch/models/speech_to_text/types/GenerationResult.h" -#include "rnexecutorch/models/speech_to_text/types/Segment.h" - -namespace rnexecutorch::models::speech_to_text::asr { - -class ASR { -public: - explicit ASR(const models::BaseModel *encoder, - const models::BaseModel *decoder, - const TokenizerModule *tokenizer); - std::vector - transcribe(std::span waveform, - const types::DecodingOptions &options) const; - std::vector encode(std::span waveform) const; - std::vector decode(std::span tokens, uint64_t startPos, - std::span encoderOutput) const; - -private: - const models::BaseModel *encoder; - const models::BaseModel *decoder; - const TokenizerModule *tokenizer; - - uint64_t startOfTranscriptionToken; - 
uint64_t endOfTranscriptionToken; - uint64_t timestampBeginToken; - - // Time precision used by Whisper timestamps: each token spans 0.02 seconds - constexpr static float kTimePrecision = 0.02f; - // The maximum number of tokens the decoder can generate per chunk - constexpr static int32_t kMaxDecodeLength = 128; - // Maximum duration of each audio chunk to process (in seconds) - // It is intentionally set to 29 since otherwise only the last chunk would be - // correctly transcribe due to the model's positional encoding limit - constexpr static int32_t kChunkSize = 29; - // Sampling rate expected by Whisper and the model's audio pipeline (16 kHz) - constexpr static int32_t kSamplingRate = 16000; - // Minimum allowed chunk length before processing (in audio samples) - constexpr static int32_t kMinChunkSamples = 1 * 16000; - // Number of mel frames output by the encoder (derived from input spectrogram) - constexpr static int32_t kNumFrames = 1500; - - std::vector - getInitialSequence(const types::DecodingOptions &options) const; - types::GenerationResult generate(std::span waveform, float temperature, - const types::DecodingOptions &options) const; - std::vector - generateWithFallback(std::span waveform, - const types::DecodingOptions &options) const; - std::vector - calculateWordLevelTimestamps(std::span generatedTokens, - const std::span waveform, - float avgLogProb, float temperature, - float compressionRatio) const; - std::vector - estimateWordLevelTimestampsLinear(std::span tokens, - uint64_t start, uint64_t end) const; - float getCompressionRatio(const std::string &text) const; -}; - -} // namespace rnexecutorch::models::speech_to_text::asr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h new file mode 100644 index 000000000..0390a7dfc --- /dev/null +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include + +#include "../types/DecodingOptions.h" +#include "../types/Segment.h" +#include + +namespace rnexecutorch::models::speech_to_text::schema { + +/** + * @brief Abstract base class for Automatic Speech Recognition (ASR) models. + * + * Provides a unified interface for speech-to-text models like Whisper, allowing + * for transcription of raw audio waveforms into text segments, as well as + * access to lower-level model components like encoding and decoding. + */ +class ASR { +public: + virtual ~ASR() = default; + + std::vector virtual transcribe( + std::span waveform, const DecodingOptions &options) const = 0; + + virtual std::vector encode(std::span waveform) const = 0; + + virtual std::vector decode(std::span tokens, + std::span encoderOutput, + uint64_t startPos = 0) const = 0; + + // Standard ExecuTorch model methods for compatibility with the rest of the + // API. + virtual void unload() noexcept = 0; + virtual std::size_t getMemoryLowerBound() const noexcept = 0; +}; + +} // namespace rnexecutorch::models::speech_to_text::schema \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h new file mode 100644 index 000000000..b746cf2d9 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include "../types/DecodingOptions.h" +#include "../types/ProcessResult.h" +#include "../types/Word.h" + +namespace rnexecutorch::models::speech_to_text::schema { + +/** + * @brief Abstract base class for Online (streaming) Automatic Speech + * Recognition. 
+ * + * Provides an interface for processing audio in chunks, allowing for real-time + * transcription results. Implementations of this interface typically maintain + * an internal audio buffer and a hypothesis buffer for incremental decoding. + * + * Requires 5 main methods to be implemented: + * - insertAudioChunk(): for expanding the collected audio (typically adding to + * a buffer) + * - ready(): returns a boolean flag indicating whether the module is ready to + * process the next iteration + * - process(): for processing the next chunk of audio (next iteration) + * - finish(): called to finish the live transcription mode + * - reset(): resets the streaming state + */ +class OnlineASR { +public: + virtual ~OnlineASR() = default; + + virtual void insertAudioChunk(std::span audio) = 0; + + virtual bool ready() const = 0; + + virtual ProcessResult process(const DecodingOptions &options) = 0; + + virtual std::vector finish() = 0; + + virtual void reset() = 0; +}; + +} // namespace rnexecutorch::models::speech_to_text::schema diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h similarity index 73% rename from packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h rename to packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h index 99774cf52..de0830826 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h @@ -3,7 +3,7 @@ #include #include -namespace rnexecutorch::models::speech_to_text::types { +namespace rnexecutorch::models::speech_to_text { struct DecodingOptions { explicit DecodingOptions(const std::string &language, bool verbose = 
false) @@ -14,4 +14,4 @@ struct DecodingOptions { bool verbose; }; -} // namespace rnexecutorch::models::speech_to_text::types +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/GenerationResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/GenerationResult.h new file mode 100644 index 000000000..a68566f0e --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/GenerationResult.h @@ -0,0 +1,14 @@ +#pragma once + +#include "Token.h" +#include +#include + +namespace rnexecutorch::models::speech_to_text { + +struct GenerationResult { + std::vector tokens; + std::vector scores; +}; + +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/ProcessResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/ProcessResult.h new file mode 100644 index 000000000..493b5b05d --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/ProcessResult.h @@ -0,0 +1,14 @@ +#pragma once + +#include "Word.h" +#include +#include + +namespace rnexecutorch::models::speech_to_text { + +struct ProcessResult { + std::vector committed; + std::vector nonCommitted; +}; + +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Segment.h similarity index 51% rename from packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h rename to packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Segment.h index b673cbe6e..878589dce 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Segment.h @@ -1,13 +1,15 @@ #pragma once +#include "Token.h" #include "Word.h" +#include #include -namespace rnexecutorch::models::speech_to_text::types { +namespace rnexecutorch::models::speech_to_text { struct Segment { std::vector words; - std::vector tokens; // Raw token IDs + std::vector tokens; // Raw token IDs float start; float end; float avgLogprob; @@ -15,4 +17,4 @@ struct Segment { float compressionRatio; }; -} // namespace rnexecutorch::models::speech_to_text::types +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h new file mode 100644 index 000000000..17c4a4091 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace rnexecutorch::models::speech_to_text { + +using Token = uint64_t; + +} // namespace rnexecutorch::models::speech_to_text \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/TranscriptionResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h similarity index 69% rename from packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/TranscriptionResult.h rename to packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h index 5fe3868da..994cdb15e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/TranscriptionResult.h +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h @@ -3,7 +3,7 @@ #include #include -namespace rnexecutorch::models::speech_to_text::types { +namespace rnexecutorch::models::speech_to_text { struct TranscriptionResult { std::string text; @@ -13,4 +13,4 @@ struct TranscriptionResult { std::vector segments; // Populated only if verbose=true }; -} // namespace rnexecutorch::models::speech_to_text::types \ No newline at end of file +} // namespace rnexecutorch::models::speech_to_text \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h new file mode 100644 index 000000000..c3e0b5d4c --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace rnexecutorch::models::speech_to_text { + +struct Word { + std::string content; + float start; + float end; +}; + +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp deleted file mode 100644 index 31806c126..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp +++ /dev/null @@ -1,82 +0,0 @@ -#include "HypothesisBuffer.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -using namespace types; - -void HypothesisBuffer::insert(std::span newWords, float offset) { - this->fresh.clear(); - for (const auto &word : newWords) { - const float newStart = word.start + offset; - if (newStart > lastCommittedTime - 0.5f) { - this->fresh.emplace_back(word.content, newStart, word.end + offset); - } - } - - 
if (!this->fresh.empty() && !this->committedInBuffer.empty()) { - const float a = this->fresh.front().start; - if (std::fabs(a - lastCommittedTime) < 1.0f) { - const size_t cn = this->committedInBuffer.size(); - const size_t nn = this->fresh.size(); - const std::size_t maxCheck = std::min({cn, nn, 5}); - for (size_t i = 1; i <= maxCheck; i++) { - std::string c; - for (auto it = this->committedInBuffer.cend() - i; - it != this->committedInBuffer.cend(); ++it) { - if (!c.empty()) { - c += ' '; - } - c += it->content; - } - - std::string tail; - auto it = this->fresh.cbegin(); - for (size_t k = 0; k < i; k++, it++) { - if (!tail.empty()) { - tail += ' '; - } - tail += it->content; - } - - if (c == tail) { - this->fresh.erase(this->fresh.begin(), this->fresh.begin() + i); - break; - } - } - } - } -} - -std::deque HypothesisBuffer::flush() { - std::deque commit; - - while (!this->fresh.empty() && !this->buffer.empty()) { - if (this->fresh.front().content != this->buffer.front().content) { - break; - } - commit.push_back(this->fresh.front()); - this->buffer.pop_front(); - this->fresh.pop_front(); - } - - if (!commit.empty()) { - lastCommittedTime = commit.back().end; - } - - this->buffer = std::move(this->fresh); - this->fresh.clear(); - this->committedInBuffer.insert(this->committedInBuffer.end(), commit.begin(), - commit.end()); - return commit; -} - -void HypothesisBuffer::popCommitted(float time) { - while (!this->committedInBuffer.empty() && - this->committedInBuffer.front().end <= time) { - this->committedInBuffer.pop_front(); - } -} - -std::deque HypothesisBuffer::complete() const { return this->buffer; } - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h deleted file mode 100644 index cfa11fd66..000000000 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include -#include - -#include "rnexecutorch/models/speech_to_text/types/Word.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -class HypothesisBuffer { -public: - void insert(std::span newWords, float offset); - std::deque flush(); - void popCommitted(float time); - std::deque complete() const; - -private: - float lastCommittedTime = 0.0f; - - std::deque committedInBuffer; - std::deque buffer; - std::deque fresh; -}; - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp deleted file mode 100644 index 3137d274b..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include - -#include "OnlineASRProcessor.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -using namespace asr; -using namespace types; - -OnlineASRProcessor::OnlineASRProcessor(const ASR *asr) : asr(asr) {} - -void OnlineASRProcessor::insertAudioChunk(std::span audio) { - audioBuffer.insert(audioBuffer.end(), audio.begin(), audio.end()); -} - -ProcessResult OnlineASRProcessor::processIter(const DecodingOptions &options) { - std::vector res = asr->transcribe(audioBuffer, options); - - std::vector tsw; - for (const auto &segment : res) { - for (const auto &word : segment.words) { - tsw.push_back(word); - } - } - - this->hypothesisBuffer.insert(tsw, this->bufferTimeOffset); - std::deque flushed = this->hypothesisBuffer.flush(); - this->committed.insert(this->committed.end(), flushed.begin(), flushed.end()); - - constexpr int32_t chunkThresholdSec = 15; - if (static_cast(audioBuffer.size()) / - 
OnlineASRProcessor::kSamplingRate > - chunkThresholdSec) { - chunkCompletedSegment(res); - } - - auto move_to_vector = [](auto& container) { - return std::vector(std::make_move_iterator(container.begin()), - std::make_move_iterator(container.end())); - }; - - std::deque nonCommittedWords = this->hypothesisBuffer.complete(); - - return { move_to_vector(flushed), move_to_vector(nonCommittedWords) }; -} - -void OnlineASRProcessor::chunkCompletedSegment(std::span res) { - if (this->committed.empty()) - return; - - std::vector ends(res.size()); - std::ranges::transform(res, ends.begin(), [](const Segment &seg) { - return seg.words.back().end; - }); - - const float t = this->committed.back().end; - - if (ends.size() > 1) { - float e = ends[ends.size() - 2] + this->bufferTimeOffset; - while (ends.size() > 2 && e > t) { - ends.pop_back(); - e = ends[ends.size() - 2] + this->bufferTimeOffset; - } - if (e <= t) { - chunkAt(e); - } - } -} - -void OnlineASRProcessor::chunkAt(float time) { - this->hypothesisBuffer.popCommitted(time); - - const float cutSeconds = time - this->bufferTimeOffset; - auto startIndex = - static_cast(cutSeconds * OnlineASRProcessor::kSamplingRate); - - if (startIndex < audioBuffer.size()) { - audioBuffer.erase(audioBuffer.begin(), audioBuffer.begin() + startIndex); - } else { - audioBuffer.clear(); - } - - this->bufferTimeOffset = time; -} - -std::vector OnlineASRProcessor::finish() { - std::deque bufferDeq = this->hypothesisBuffer.complete(); - std::vector buffer(std::make_move_iterator(bufferDeq.begin()), - std::make_move_iterator(bufferDeq.end())); - - this->bufferTimeOffset += static_cast(audioBuffer.size()) / - OnlineASRProcessor::kSamplingRate; - return buffer; -} - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h 
deleted file mode 100644 index 98944bdbe..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include "rnexecutorch/models/speech_to_text/asr/ASR.h" -#include "rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h" -#include "rnexecutorch/models/speech_to_text/types/ProcessResult.h" -#include "rnexecutorch/models/speech_to_text/types/Word.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -class OnlineASRProcessor { -public: - explicit OnlineASRProcessor(const asr::ASR *asr); - - void insertAudioChunk(std::span audio); - types::ProcessResult processIter(const types::DecodingOptions &options); - std::vector finish(); - - std::vector audioBuffer; - -private: - const asr::ASR *asr; - constexpr static int32_t kSamplingRate = 16000; - - HypothesisBuffer hypothesisBuffer; - float bufferTimeOffset = 0.0f; - std::vector committed; - - void chunkCompletedSegment(std::span res); - void chunkAt(float time); -}; - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h deleted file mode 100644 index 83bc80dd7..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -namespace rnexecutorch::models::speech_to_text::types { - -struct GenerationResult { - std::vector tokens; - std::vector scores; -}; - -} // namespace rnexecutorch::models::speech_to_text::types diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h deleted file mode 100644 index 
681495e2a..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -namespace rnexecutorch::models::speech_to_text::types { - -struct ProcessResult { - std::vector committed; - std::vector nonCommitted; -}; - -} // namespace rnexecutorch::models::speech_to_text::types diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h deleted file mode 100644 index 98c72f273..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace rnexecutorch::models::speech_to_text::types { - -struct Word { - std::string content; - float start; - float end; -}; - -} // namespace rnexecutorch::models::speech_to_text::types diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp similarity index 57% rename from packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp rename to packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp index 455061567..c9ec1d322 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp @@ -1,57 +1,259 @@ +#include #include #include -#include #include "ASR.h" -#include "executorch/extension/tensor/tensor_ptr.h" -#include "rnexecutorch/data_processing/Numerical.h" -#include "rnexecutorch/data_processing/gzip.h" +#include "Constants.h" +#include #include +#include +#include + +namespace rnexecutorch::models::speech_to_text::whisper { + +using 
executorch::runtime::etensor::ScalarType; + +ASR::ASR(const std::string &modelSource, const std::string &tokenizerSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, std::move(callInvoker)), schema::ASR(), + tokenizer_(std::make_unique(tokenizerSource, + this->callInvoker)), + startOfTranscriptionToken_( + tokenizer_->tokenToId(constants::tokens::kStartOfTranscript)), + endOfTranscriptionToken_( + tokenizer_->tokenToId(constants::tokens::kEndOfTranscript)), + timestampBeginToken_( + tokenizer_->tokenToId(constants::tokens::kBeginTimestamp)) {} + +/** + * Whisper inference - full transcription + */ +std::vector ASR::transcribe(std::span waveform, + const DecodingOptions &options) const { + int32_t seek = 0; + std::vector results; + + // We loop through the input audio waveform and process it in 30s chunks. + // This is determined by Whisper models strict 30s audio length requirement. + while (std::cmp_less(seek * constants::kSamplingRate, waveform.size())) { + // Calculate chunk bounds and extract the chunk. + int32_t start = seek * constants::kSamplingRate; + const auto end = + std::min(static_cast((seek + constants::kChunkSize) * + constants::kSamplingRate), + static_cast(waveform.size())); + auto chunk = waveform.subspan(start, end - start); + + if (std::cmp_less(chunk.size(), constants::kMinChunkSamples)) { + break; + } + + // Enter the processing logic. 
+ std::vector segments = this->generate(chunk, options); + + if (segments.empty()) { + seek += constants::kChunkSize; + continue; + } + + for (auto &seg : segments) { + for (auto &w : seg.words) { + w.start += seek; + w.end += seek; + } + + seg.start += seek; + seg.end += seek; + } + + while (!segments.empty() && segments.back().words.empty()) { + segments.pop_back(); + } + + if (!segments.empty() && !segments.back().words.empty()) { + seek = static_cast(segments.back().words.back().end); + } + results.insert(results.end(), std::make_move_iterator(segments.begin()), + std::make_move_iterator(segments.end())); + } + + return results; +} + +/** + * Whisper inference - encoding phase + * + * The input is a standard audio waveform, altough it is implicitly converted + * to a log mel format inside the encoder call. + */ +std::vector ASR::encode(std::span waveform) const { + auto inputShape = {static_cast(waveform.size())}; + + const auto modelInputTensor = executorch::extension::make_tensor_ptr( + std::move(inputShape), waveform.data(), ScalarType::Float); + + const auto encoderResult = this->execute("encode", {modelInputTensor}); + + if (!encoderResult.ok()) { + throw RnExecutorchError(encoderResult.error(), + "[Whisper] The 'encode' method did not succeed. " + "Ensure the model input is correct."); + } + + const auto encoderOutputTensor = encoderResult.get().at(0).toTensor(); + const auto outputNumel = encoderOutputTensor.numel(); + + const float *const dataPtr = encoderOutputTensor.const_data_ptr(); + return {dataPtr, dataPtr + outputNumel}; +} + +/** + * Whisper inference - decoding phase + * + * An autoregressive decoder, called with increasing amount of input tokens. 
+ */ +std::vector ASR::decode(std::span tokens, + std::span encoderOutput, + uint64_t startPos) const { + std::vector tokenShape = {1, static_cast(tokens.size())}; + std::vector positionShape = {static_cast(tokens.size())}; + + auto tokenTensor = executorch::extension::make_tensor_ptr( + tokenShape, tokens.data(), ScalarType::Long); + + // Populate cache position vector + std::vector cachePositions(tokens.size()); + std::iota(cachePositions.begin(), cachePositions.end(), startPos); + auto positionTensor = executorch::extension::make_tensor_ptr( + positionShape, cachePositions.data(), ScalarType::Long); + + const auto encoderOutputSize = static_cast(encoderOutput.size()); + std::vector encShape = {1, constants::kNumFrames, + encoderOutputSize / constants::kNumFrames}; + auto encoderTensor = executorch::extension::make_tensor_ptr( + std::move(encShape), encoderOutput.data(), ScalarType::Float); + + const auto decoderResult = + this->execute("decode", {tokenTensor, positionTensor, encoderTensor}); -namespace rnexecutorch::models::speech_to_text::asr { + if (!decoderResult.ok()) { + throw RnExecutorchError(decoderResult.error(), + "[Whisper] The 'decode' method did not succeed. 
" + "Ensure the model inputs are correct."); + } -using namespace types; + const auto logitsTensor = decoderResult.get().at(0).toTensor(); + const int32_t outputNumel = static_cast(logitsTensor.numel()); -ASR::ASR(const models::BaseModel *encoder, const models::BaseModel *decoder, - const TokenizerModule *tokenizer) - : encoder(encoder), decoder(decoder), tokenizer(tokenizer), - startOfTranscriptionToken( - this->tokenizer->tokenToId("<|startoftranscript|>")), - endOfTranscriptionToken(this->tokenizer->tokenToId("<|endoftext|>")), - timestampBeginToken(this->tokenizer->tokenToId("<|0.00|>")) {} + const size_t innerDim = logitsTensor.size(1); + const size_t dictSize = logitsTensor.size(2); + const float *const dataPtr = + logitsTensor.const_data_ptr() + (innerDim - 1) * dictSize; + + return {dataPtr, dataPtr + outputNumel / innerDim}; +} + +void ASR::unload() noexcept { BaseModel::unload(); } + +std::size_t ASR::getMemoryLowerBound() const noexcept { + return BaseModel::getMemoryLowerBound(); +} + +/** + * Helper functions - creating initial token IDs sequence + */ std::vector -ASR::getInitialSequence(const DecodingOptions &options) const { +ASR::createInitialSequence(const DecodingOptions &options) const { std::vector seq; - seq.push_back(this->startOfTranscriptionToken); + seq.push_back(startOfTranscriptionToken_); if (options.language.has_value()) { uint64_t langToken = - this->tokenizer->tokenToId("<|" + options.language.value() + "|>"); - uint64_t taskToken = this->tokenizer->tokenToId("<|transcribe|>"); + tokenizer_->tokenToId("<|" + options.language.value() + "|>"); + uint64_t taskToken = tokenizer_->tokenToId("<|transcribe|>"); seq.push_back(langToken); seq.push_back(taskToken); } - seq.push_back(this->timestampBeginToken); + seq.push_back(timestampBeginToken_); return seq; } -GenerationResult ASR::generate(std::span waveform, float temperature, - const DecodingOptions &options) const { +/** + * Helper functions - generation wrapper, with fallback + */ 
+std::vector ASR::generate(std::span waveform, + const DecodingOptions &options) const { + // A fixed pool of available temperatures + constexpr std::array temperatures = {0.0f, 0.2f, 0.4f, + 0.6f, 0.8f, 1.0f}; + + // Calculate audio features just once to save time. std::vector encoderOutput = this->encode(waveform); - std::vector sequenceIds = this->getInitialSequence(options); + std::vector bestTokens; + float bestAvgLogProb = -std::numeric_limits::infinity(); + float bestCompressionRatio = 0.0f; + float bestTemperature = 0.0f; + + for (auto t : temperatures) { + auto [tokens, scores] = + this->generate(waveform, options, t, {encoderOutput}); + + const float cumLogProb = std::transform_reduce( + scores.begin(), scores.end(), 0.0f, std::plus<>(), + [](float s) { return std::log(std::max(s, 1e-9f)); }); + + const float avgLogProb = cumLogProb / static_cast(tokens.size() + 1); + const std::string text = tokenizer_->decode(tokens, true); + const float compressionRatio = this->calculateCompressionRatio(text); + + if (avgLogProb >= -1.0f && compressionRatio < 2.4f) { + bestTokens = std::move(tokens); + bestAvgLogProb = avgLogProb; + bestCompressionRatio = compressionRatio; + bestTemperature = t; + break; + } + + if (t == temperatures.back() && bestTokens.empty()) { + bestTokens = std::move(tokens); + bestAvgLogProb = avgLogProb; + bestCompressionRatio = compressionRatio; + bestTemperature = t; + } + } + + return this->calculateWordLevelTimestamps(bestTokens, waveform, + bestAvgLogProb, bestTemperature, + bestCompressionRatio); +} + +/** + * Helper functions - generation wrapper, single-temperature inference + */ +GenerationResult +ASR::generate(std::span waveform, const DecodingOptions &options, + float temperature, + std::optional> encoderOutput) const { + std::vector encoderOutputData = !encoderOutput.has_value() + ? this->encode(waveform) + : std::vector(); + std::span encodings = encoderOutput.has_value() + ? 
encoderOutput.value() + : std::span(encoderOutputData); + + std::vector sequenceIds = this->createInitialSequence(options); std::vector cachedTokens = sequenceIds; const size_t initialSequenceLenght = sequenceIds.size(); std::vector scores; uint64_t startPos = 0; while (std::cmp_less_equal(startPos + sequenceIds.size(), - ASR::kMaxDecodeLength)) { - std::vector logits = - this->decode(sequenceIds, startPos, encoderOutput); + constants::kMaxDecodeLength)) { + std::vector logits = this->decode(sequenceIds, encodings, startPos); // intentionally comparing float to float // temperatures are predefined, so this is safe @@ -85,7 +287,7 @@ GenerationResult ASR::generate(std::span waveform, float temperature, cachedTokens.push_back(nextId); scores.push_back(nextProb); - if (nextId == this->endOfTranscriptionToken) { + if (nextId == endOfTranscriptionToken_) { break; } } @@ -96,74 +298,25 @@ GenerationResult ASR::generate(std::span waveform, float temperature, .scores = scores}; } -float ASR::getCompressionRatio(const std::string &text) const { - size_t compressedSize = gzip::deflateSize(text); - return static_cast(text.size()) / static_cast(compressedSize); -} - -std::vector -ASR::generateWithFallback(std::span waveform, - const DecodingOptions &options) const { - std::vector temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f}; - std::vector bestTokens; - float bestAvgLogProb = -std::numeric_limits::infinity(); - float bestCompressionRatio = 0.0f; - float bestTemperature = 0.0f; - - for (auto t : temperatures) { - auto [tokens, scores] = this->generate(waveform, t, options); - - const float cumLogProb = std::transform_reduce( - scores.begin(), scores.end(), 0.0f, std::plus<>(), - [](float s) { return std::log(std::max(s, 1e-9f)); }); - - const float avgLogProb = cumLogProb / static_cast(tokens.size() + 1); - const std::string text = this->tokenizer->decode(tokens, true); - const float compressionRatio = this->getCompressionRatio(text); - - if (avgLogProb >= -1.0f && 
compressionRatio < 2.4f) { - bestTokens = std::move(tokens); - bestAvgLogProb = avgLogProb; - bestCompressionRatio = compressionRatio; - bestTemperature = t; - break; - } - - if (t == temperatures.back() && bestTokens.empty()) { - bestTokens = std::move(tokens); - bestAvgLogProb = avgLogProb; - bestCompressionRatio = compressionRatio; - bestTemperature = t; - } - } - - return this->calculateWordLevelTimestamps(bestTokens, waveform, - bestAvgLogProb, bestTemperature, - bestCompressionRatio); -} - -std::vector -ASR::calculateWordLevelTimestamps(std::span generatedTokens, - const std::span waveform, - float avgLogProb, float temperature, - float compressionRatio) const { +std::vector ASR::calculateWordLevelTimestamps( + std::span generatedTokens, const std::span waveform, + float avgLogProb, float temperature, float compressionRatio) const { const size_t generatedTokensSize = generatedTokens.size(); if (generatedTokensSize < 2 || - generatedTokens[generatedTokensSize - 1] != - this->endOfTranscriptionToken || - generatedTokens[generatedTokensSize - 2] < this->timestampBeginToken) { + generatedTokens[generatedTokensSize - 1] != endOfTranscriptionToken_ || + generatedTokens[generatedTokensSize - 2] < timestampBeginToken_) { return {}; } std::vector segments; std::vector tokens; - uint64_t prevTimestamp = this->timestampBeginToken; + uint64_t prevTimestamp = timestampBeginToken_; for (size_t i = 0; i < generatedTokensSize; i++) { - if (generatedTokens[i] < this->timestampBeginToken) { + if (generatedTokens[i] < timestampBeginToken_) { tokens.push_back(generatedTokens[i]); } - if (i > 0 && generatedTokens[i - 1] >= this->timestampBeginToken && - generatedTokens[i] >= this->timestampBeginToken) { + if (i > 0 && generatedTokens[i - 1] >= timestampBeginToken_ && + generatedTokens[i] >= timestampBeginToken_) { const uint64_t start = prevTimestamp; const uint64_t end = generatedTokens[i - 1]; auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end); @@ -210,8 
+363,8 @@ ASR::calculateWordLevelTimestamps(std::span generatedTokens, float scalingFactor = static_cast(waveform.size()) / - (ASR::kSamplingRate * (end - this->timestampBeginToken) * - ASR::kTimePrecision); + (constants::kSamplingRate * (end - timestampBeginToken_) * + constants::kTimePrecision); if (scalingFactor < 1.0f) { for (auto &seg : segments) { for (auto &w : seg.words) { @@ -228,7 +381,7 @@ std::vector ASR::estimateWordLevelTimestampsLinear(std::span tokens, uint64_t start, uint64_t end) const { const std::vector tokensVec(tokens.begin(), tokens.end()); - const std::string segmentText = this->tokenizer->decode(tokensVec, true); + const std::string segmentText = tokenizer_->decode(tokensVec, true); std::istringstream iss(segmentText); std::vector wordsStr; std::string word; @@ -241,9 +394,10 @@ ASR::estimateWordLevelTimestampsLinear(std::span tokens, for (const auto &w : wordsStr) { numChars += w.size(); } - const float duration = (end - start) * ASR::kTimePrecision; + const float duration = (end - start) * constants::kTimePrecision; const float timePerChar = duration / std::max(1, numChars); - const float startOffset = (start - timestampBeginToken) * ASR::kTimePrecision; + const float startOffset = + (start - timestampBeginToken_) * constants::kTimePrecision; std::vector wordObjs; wordObjs.reserve(wordsStr.size()); @@ -259,113 +413,9 @@ ASR::estimateWordLevelTimestampsLinear(std::span tokens, return wordObjs; } -std::vector ASR::transcribe(std::span waveform, - const DecodingOptions &options) const { - int32_t seek = 0; - std::vector results; - - while (std::cmp_less(seek * ASR::kSamplingRate, waveform.size())) { - int32_t start = seek * ASR::kSamplingRate; - const auto end = std::min( - static_cast((seek + ASR::kChunkSize) * ASR::kSamplingRate), - static_cast(waveform.size())); - auto chunk = waveform.subspan(start, end - start); - - if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) { - break; - } - - std::vector segments = 
this->generateWithFallback(chunk, options); - - if (segments.empty()) { - seek += ASR::kChunkSize; - continue; - } - - for (auto &seg : segments) { - for (auto &w : seg.words) { - w.start += seek; - w.end += seek; - } - - seg.start += seek; - seg.end += seek; - } - - while (!segments.empty() && segments.back().words.empty()) { - segments.pop_back(); - } - - if (!segments.empty() && !segments.back().words.empty()) { - seek = static_cast(segments.back().words.back().end); - } - results.insert(results.end(), std::make_move_iterator(segments.begin()), - std::make_move_iterator(segments.end())); - } - - return results; -} - -std::vector ASR::encode(std::span waveform) const { - auto inputShape = {static_cast(waveform.size())}; - - const auto modelInputTensor = executorch::extension::make_tensor_ptr( - std::move(inputShape), waveform.data(), - executorch::runtime::etensor::ScalarType::Float); - const auto encoderResult = this->encoder->forward(modelInputTensor); - - if (!encoderResult.ok()) { - throw RnExecutorchError(encoderResult.error(), - "The model's forward function did not succeed. 
" - "Ensure the model input is correct."); - } - - const auto decoderOutputTensor = encoderResult.get().at(0).toTensor(); - const auto outputNumel = decoderOutputTensor.numel(); - - const float *const dataPtr = decoderOutputTensor.const_data_ptr(); - return {dataPtr, dataPtr + outputNumel}; -} - -std::vector ASR::decode(std::span tokens, uint64_t startPos, - std::span encoderOutput) const { - std::vector tokenShape = {1, static_cast(tokens.size())}; - std::vector positionShape = {static_cast(tokens.size())}; - - auto tokenTensor = executorch::extension::make_tensor_ptr( - tokenShape, tokens.data(), ScalarType::Long); - - // Populate cache position vector - std::vector cachePositions(tokens.size()); - std::iota(cachePositions.begin(), cachePositions.end(), startPos); - auto positionTensor = executorch::extension::make_tensor_ptr( - positionShape, cachePositions.data(), ScalarType::Long); - - const auto encoderOutputSize = static_cast(encoderOutput.size()); - std::vector encShape = {1, ASR::kNumFrames, - encoderOutputSize / ASR::kNumFrames}; - auto encoderTensor = executorch::extension::make_tensor_ptr( - std::move(encShape), encoderOutput.data(), ScalarType::Float); - - const auto decoderResult = - this->decoder->forward({tokenTensor, positionTensor, encoderTensor}); - - if (!decoderResult.ok()) { - throw RnExecutorchError(decoderResult.error(), - "The model's forward function did not succeed. 
" - "Ensure the model input is correct."); - } - - const auto logitsTensor = decoderResult.get().at(0).toTensor(); - const int32_t outputNumel = static_cast(logitsTensor.numel()); - - const size_t innerDim = logitsTensor.size(1); - const size_t dictSize = logitsTensor.size(2); - - const float *const dataPtr = - logitsTensor.const_data_ptr() + (innerDim - 1) * dictSize; - - return {dataPtr, dataPtr + outputNumel / innerDim}; +float ASR::calculateCompressionRatio(const std::string &text) const { + size_t compressedSize = gzip::deflateSize(text); + return static_cast(text.size()) / static_cast(compressedSize); } -} // namespace rnexecutorch::models::speech_to_text::asr +} // namespace rnexecutorch::models::speech_to_text::whisper diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h new file mode 100644 index 000000000..54d2d699a --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h @@ -0,0 +1,167 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "../common/schema/ASR.h" +#include "../common/types/GenerationResult.h" +#include "../common/types/Token.h" +#include +#include +#include + +namespace rnexecutorch::models::speech_to_text::whisper { + +using executorch::aten::Tensor; + +/** + * Automatic Speech Recognition (ASR) class for Whisper-based models. + * This class handles both encoding and decoding steps for Whisper family + * models, loading a single model with named entry points for "encode" and + * "decode". + */ +class ASR : public models::BaseModel, public schema::ASR { +public: + ASR(const std::string &modelSource, const std::string &tokenizerSource, + std::shared_ptr callInvoker); + + /** + * @brief The main Whisper transcription API point. + * Wrapps the entire transciption process into a single method. 
+ * + * @param waveform Input audio waveform sampled at 16kHz, similarly to + * encode's input. + * @param options Control variables for decoding process. + */ + std::vector virtual transcribe( + std::span waveform, const DecodingOptions &options) const override; + + /** + * Encodes the input audio waveform into mel spectrogram embeddings. + * + * @param waveform Input audio waveform sampled at 16kHz. + * @return Flat vector containing the encoder's output features. + * The output tensor shape: [1, 1500, 384] for Whisper + * models. + */ + std::vector encode(std::span waveform) const override; + + /** + * Decodes a sequence of tokens into logits given the encoded audio features. + * + * @param tokens A span of token IDs from previous iteration + * (see autoregressive nature of Whisper decoding). + * @param encoderOutput A span of floats containing the precomputed encoder + * embeddings. + * @param startPos The starting position in the sequence (used for KV + * caching). + * @return A vector of floats representing the output logits for + * the next token. + */ + std::vector decode(std::span tokens, + std::span encoderOutput, + uint64_t startPos = 0) const override; + + // Standard ExecuTorch model methods for compatibility with the rest of the + // API. + void unload() noexcept override; + std::size_t getMemoryLowerBound() const noexcept override; + +private: + /** + * A helper factory for creating initial token sequences. + * + * The initial sequence consists of special tokens, such as + * language mark token or timestamp token. It's always a part + * of decoder's input. + * + * @param options Determine a specific properties of the initial sequence, + * such as whether + * to add a language mark token or not. + */ + std::vector + createInitialSequence(const DecodingOptions &options) const; + + /** + * Generation wrapper - wrapps encoding & decoding with + * temperature fallback mechanism. 
+ * It could, in theory, run up to 5 inferences for increasing + * temperature values. + * + * @param waveform Input audio waveform sampled at 16kHz, similarly to + * encode's input. + * @param options Control variables for decoding process. + */ + std::vector generate(std::span waveform, + const DecodingOptions &options) const; + + /** + * Generation wrapper - wrapps encoding & decoding for a single, + * specific temperature value. + * Results in a single inference. + * Allows to skip the encoding phase if encoder results are already provided. + * + * @param waveform Input audio waveform sampled at 16kHz, similarly to + * encode's input. + * @param options Control variables for decoding process. + * @param temperature Controls the scale of randomization during the logits + * resolving process. + * @param encoderOutput An optional parameter. If provided, the encoding phase + * is skipped and the provided value is used instead. + */ + GenerationResult + generate(std::span waveform, const DecodingOptions &options, + float temperature, + std::optional> encoderOutput = std::nullopt) const; + + /** + * Calculates word-level timestamps for a sequence of generated tokens. + * + * This method parses the generated tokens, splits them into segments based on + * timestamp tokens, and applies a linear estimation for individual words. + * It also adjusts timestamps based on the actual waveform length. + * + * @param generatedTokens The sequence of tokens produced by the model. + * @param waveform The original audio signal used for scaling. + * @param avgLogProb Average log probability of the generated sequence. + * @param temperature Temperature used during generation. + * @param compressionRatio Text compression ratio for the generated sequence. + * @return A vector of transcribed segments with word-level + * timing. 
+ */ + std::vector + calculateWordLevelTimestamps(std::span generatedTokens, + const std::span waveform, + float avgLogProb, float temperature, + float compressionRatio) const; + + /** + * Estimates word-level timestamps linearly within a token sequence. + * + * Decodes the tokens into words and distributes the time interval [start, + * end] across words based on their character count. + * + * @param tokens The slice of tokens representing a single segment. + * @param start The timestamp token ID marking the beginning of the segment. + * @param end The timestamp token ID marking the end of the segment. + * @return A vector of Word objects with estimated start and end times. + */ + std::vector + estimateWordLevelTimestampsLinear(std::span tokens, + uint64_t start, uint64_t end) const; + + float calculateCompressionRatio(const std::string &text) const; + + // Submodules - a tokenizer module for decoding process. + std::unique_ptr tokenizer_; + + // Tokenization helper definitions + const Token startOfTranscriptionToken_; + const Token endOfTranscriptionToken_; + const Token timestampBeginToken_; +}; + +} // namespace rnexecutorch::models::speech_to_text::whisper diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h new file mode 100644 index 000000000..1a31f4b62 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace rnexecutorch::models::speech_to_text::whisper::constants { + +// Maximum duration of each audio chunk to process (in seconds) +// It is intentionally set to 29 since otherwise only the last chunk would be +// correctly transcribe due to the model's positional encoding limit +constexpr static int32_t kChunkSize = 29; + +// The maximum number of tokens the decoder can generate per 
chunk +constexpr static int32_t kMaxDecodeLength = 128; + +// Minimum allowed chunk length before processing (in audio samples) +constexpr static int32_t kMinChunkSamples = 1 * 16000; + +// Number of mel frames output by the encoder (derived from input spectrogram) +constexpr static int32_t kNumFrames = 1500; + +// Sampling rate expected by Whisper and the model's audio pipeline (16 kHz) +constexpr static int32_t kSamplingRate = 16000; + +// Time precision used by Whisper timestamps: each token spans 0.02 seconds +constexpr static float kTimePrecision = 0.02f; + +// Special token constants +namespace tokens { +inline const std::string kStartOfTranscript = "<|startoftranscript|>"; +inline const std::string kEndOfTranscript = "<|endoftext|>"; +inline const std::string kBeginTimestamp = "<|0.00|>"; +} // namespace tokens + +} // namespace rnexecutorch::models::speech_to_text::whisper::constants \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp new file mode 100644 index 000000000..f077cc713 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp @@ -0,0 +1,86 @@ +#include "HypothesisBuffer.h" + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +void HypothesisBuffer::insert(std::span newWords, float offset) { + fresh_.clear(); + for (const auto &word : newWords) { + const float newStart = word.start + offset; + if (newStart > lastCommittedTime_ - 0.5f) { + fresh_.emplace_back(word.content, newStart, word.end + offset); + } + } + + if (!fresh_.empty() && !committedInBuffer_.empty()) { + const float a = fresh_.front().start; + if (std::fabs(a - lastCommittedTime_) < 1.0f) { + const size_t cn = committedInBuffer_.size(); + const size_t nn = fresh_.size(); + const std::size_t maxCheck = 
std::min({cn, nn, 5}); + for (size_t i = 1; i <= maxCheck; i++) { + std::string c; + for (auto it = committedInBuffer_.cend() - i; + it != committedInBuffer_.cend(); ++it) { + if (!c.empty()) { + c += ' '; + } + c += it->content; + } + + std::string tail; + auto it = fresh_.cbegin(); + for (size_t k = 0; k < i; k++, it++) { + if (!tail.empty()) { + tail += ' '; + } + tail += it->content; + } + + if (c == tail) { + fresh_.erase(fresh_.begin(), fresh_.begin() + i); + break; + } + } + } + } +} + +std::deque HypothesisBuffer::flush() { + std::deque commit; + + while (!fresh_.empty() && !buffer_.empty()) { + if (fresh_.front().content != buffer_.front().content) { + break; + } + commit.push_back(fresh_.front()); + buffer_.pop_front(); + fresh_.pop_front(); + } + + if (!commit.empty()) { + lastCommittedTime_ = commit.back().end; + } + + buffer_ = std::move(fresh_); + fresh_.clear(); + committedInBuffer_.insert(committedInBuffer_.end(), commit.begin(), + commit.end()); + return commit; +} + +void HypothesisBuffer::popCommitted(float time) { + while (!committedInBuffer_.empty() && + committedInBuffer_.front().end <= time) { + committedInBuffer_.pop_front(); + } +} + +std::deque HypothesisBuffer::complete() const { return buffer_; } + +void HypothesisBuffer::reset() { + buffer_.clear(); + fresh_.clear(); + committedInBuffer_.clear(); +} + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h new file mode 100644 index 000000000..c5be66dcb --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include "../common/types/Word.h" + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +/** + * A buffer for 
managing streaming transcription hypotheses. + * This class handles stabilization of the transcription result by tracking + * "fresh" hypotheses and "committing" them once they are stable across updates. + */ +class HypothesisBuffer { +public: + /** + * Inserts new words into the hypothesis buffer. + * Words are filtered based on the last committed time and checked for + * overlaps with existing committed words to prevent duplicates. + * + * @param newWords A span of newly generated words. + * @param offset Time offset to adjust the word timestamps. + */ + void insert(std::span newWords, float offset); + + /** + * Moves stable words from the hypothesis into the committed buffer. + * It compares the new hypothesis (fresh) with the previous one (buffer) + * and returns the common prefix as committed words. + * + * @return A deque of words that have been newly committed. + */ + std::deque flush(); + + /** + * Cleans up the history of committed words up to a certain timestamp. + * + * @param time The timestamp limit; words ending before this time are removed. + */ + void popCommitted(float time); + + /** + * Retrieves the current uncommitted hypothesis. + * + * @return A deque containing the words currently in the buffer. 
+ */ + std::deque complete() const; + + /** + * Resets all the stored buffers to the initial state + */ + void reset(); + +private: + float lastCommittedTime_ = 0.0f; + + // Stored buffers + std::deque buffer_; + std::deque fresh_; + std::deque committedInBuffer_; +}; + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp new file mode 100644 index 000000000..b741567c1 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp @@ -0,0 +1,113 @@ +#include "OnlineASR.h" +#include "Constants.h" +#include +#include +#include + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) {} + +void OnlineASR::insertAudioChunk(std::span audio) { + audioBuffer_.insert(audioBuffer_.end(), audio.begin(), audio.end()); +} + +ProcessResult OnlineASR::process(const DecodingOptions &options) { + // Transcribe the current audio buffer. + std::vector res = asr_->transcribe(audioBuffer_, options); + + // Flatten segments into a single word sequence. + std::vector tsw; + size_t totalWords = 0; + for (const auto &segment : res) { + totalWords += segment.words.size(); + } + tsw.reserve(totalWords); + + for (const auto &segment : res) { + tsw.insert(tsw.end(), segment.words.begin(), segment.words.end()); + } + + // Update hypothesis buffer and commit stable words. + hypothesisBuffer_.insert(tsw, bufferTimeOffset_); + std::deque flushed = hypothesisBuffer_.flush(); + committed_.insert(committed_.end(), flushed.begin(), flushed.end()); + + // Prune processed audio if buffer exceeds threshold (15 seconds). 
+ constexpr int32_t chunkThresholdSec = 15; + if (static_cast(audioBuffer_.size()) / constants::kSamplingRate > + chunkThresholdSec) { + chunkCompletedSegment(res); + } + + auto move_to_vector = [](std::deque &container) { + return std::vector(std::make_move_iterator(container.begin()), + std::make_move_iterator(container.end())); + }; + + std::deque nonCommittedWords = hypothesisBuffer_.complete(); + + return {move_to_vector(flushed), move_to_vector(nonCommittedWords)}; +} + +bool OnlineASR::ready() const { + return audioBuffer_.size() >= constants::kMinChunkSamples; +} + +void OnlineASR::chunkCompletedSegment(std::span res) { + if (committed_.empty() || res.empty()) { + return; + } + + const float lastCommittedTimestamp = committed_.back().end; + + // Search backwards for the last segment that finished before the last + // committed word. We skip the very last segment to maintain context for + // future iterations. + for (int i = static_cast(res.size()) - 2; i >= 0; --i) { + float segmentEndAbsolute = res[i].end + bufferTimeOffset_; + if (segmentEndAbsolute <= lastCommittedTimestamp) { + chunkAt(segmentEndAbsolute); + break; + } + } +} + +void OnlineASR::chunkAt(float time) { + hypothesisBuffer_.popCommitted(time); + + const float cutSeconds = time - bufferTimeOffset_; + auto startIndex = static_cast(cutSeconds * constants::kSamplingRate); + + if (startIndex < audioBuffer_.size()) { + audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + startIndex); + } else { + audioBuffer_.clear(); + } + + bufferTimeOffset_ = time; +} + +std::vector OnlineASR::finish() { + std::deque bufferDeq = hypothesisBuffer_.complete(); + std::vector buffer(std::make_move_iterator(bufferDeq.begin()), + std::make_move_iterator(bufferDeq.end())); + + bufferTimeOffset_ += + static_cast(audioBuffer_.size()) / constants::kSamplingRate; + + // Final cleanup - usually involves clearing local state for next session + reset(); + + return buffer; +} + +void OnlineASR::reset() { + 
hypothesisBuffer_.reset(); + bufferTimeOffset_ = 0.f; + + audioBuffer_.clear(); + committed_.clear(); +} + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h new file mode 100644 index 000000000..9999d8d14 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h @@ -0,0 +1,77 @@ +#pragma once + +#include "../common/schema/OnlineASR.h" +#include "../common/types/ProcessResult.h" +#include "../common/types/Segment.h" +#include "../common/types/Word.h" +#include "ASR.h" +#include "HypothesisBuffer.h" + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +/** + * Online Automatic Speech Recognition (OnlineASR) for Whisper. + * It manages continuous processing of audio stream by maintaining a local + * audio buffer and using ASR to transcribe it in increments. + */ +class OnlineASR : public schema::OnlineASR { +public: + OnlineASR(const ASR *asr); + + /** + * Appends new audio samples to the internal processing buffer. + * + * @param audio A span of PCM float samples (expected 16kHz). + */ + void insertAudioChunk(std::span audio) override; + + /** + * Determines whether the model is ready to process the next iteration. + * + * @return True if audioBuffer has enough samples, False otherwise + */ + bool ready() const override; + + /** + * Processes the current audio buffer and returns new transcription results. + * Stability is managed by an internal HypothesisBuffer to ensure that + * only confirmed (stable) text is returned as "committed". + * + * @param options Decoding configuration (language, etc.). + * @return A ProcessResult containing newly committed and uncommitted + * words. 
+ */ + ProcessResult process(const DecodingOptions &options) override; + + /** + * Finalizes the current streaming session. + * Flushes any remaining words from the hypothesis buffer. + * + * @return A vector of remaining transcribed words. + */ + std::vector finish() override; + + /** + * Reset the streaming state by resetting the buffers + */ + void reset() override; + +private: + // Helper functions + void chunkCompletedSegment(std::span res); + void chunkAt(float time); + + // ASR module connection for transcribing the audio + const ASR *asr_; + + // Helper buffers - audio buffer + std::vector audioBuffer_; + + // Helper buffers - hypothesis buffer + HypothesisBuffer hypothesisBuffer_; + float bufferTimeOffset_ = 0.f; + + std::vector committed_; +}; + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream \ No newline at end of file diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 3a9eb0ce6..305286726 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -1,5 +1,5 @@ import { Platform } from 'react-native'; -import { URL_PREFIX, VERSION_TAG } from './versions'; +import { NEXT_VERSION_TAG, URL_PREFIX, VERSION_TAG } from './versions'; // LLMs @@ -417,104 +417,101 @@ export const STYLE_TRANSFER_UDNIE = { }; // S2T +const WHISPER_TINY_EN_MODEL = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_en_xnnpack.pte`; const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`; -// const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`; -// const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`; -const WHISPER_TINY_EN_ENCODER = 
`${URL_PREFIX}-whisper-tiny.en/resolve/kv-cache/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`; -const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/resolve/kv-cache/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`; -const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`; -const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`; +// const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`; +// const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`; -const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/tokenizer.json`; -const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`; -const WHISPER_BASE_EN_DECODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`; +// const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`; +// const WHISPER_BASE_EN_DECODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`; -const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/tokenizer.json`; -const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`; -const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`; +// const WHISPER_SMALL_EN_TOKENIZER = 
`${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`; +// const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`; -const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`; -const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`; -const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`; +// const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`; +// const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`; -const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/tokenizer.json`; -const WHISPER_BASE_ENCODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`; -const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`; +// const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_BASE_ENCODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`; +// const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`; -const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/tokenizer.json`; -const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`; -const WHISPER_SMALL_DECODER_MODEL = 
`${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`; +// const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`; +// const WHISPER_SMALL_DECODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`; /** * @category Models - Speech To Text */ export const WHISPER_TINY_EN = { + type: 'whisper' as const, isMultilingual: false, - encoderSource: WHISPER_TINY_EN_ENCODER, - decoderSource: WHISPER_TINY_EN_DECODER, + modelSource: WHISPER_TINY_EN_MODEL, tokenizerSource: WHISPER_TINY_EN_TOKENIZER, }; -/** - * @category Models - Speech To Text - */ -export const WHISPER_TINY_EN_QUANTIZED = { - isMultilingual: false, - encoderSource: WHISPER_TINY_EN_ENCODER_QUANTIZED, - decoderSource: WHISPER_TINY_EN_DECODER_QUANTIZED, - tokenizerSource: WHISPER_TINY_EN_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_BASE_EN = { - isMultilingual: false, - encoderSource: WHISPER_BASE_EN_ENCODER, - decoderSource: WHISPER_BASE_EN_DECODER, - tokenizerSource: WHISPER_BASE_EN_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_SMALL_EN = { - isMultilingual: false, - encoderSource: WHISPER_SMALL_EN_ENCODER, - decoderSource: WHISPER_SMALL_EN_DECODER, - tokenizerSource: WHISPER_SMALL_EN_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_TINY = { - isMultilingual: true, - encoderSource: WHISPER_TINY_ENCODER_MODEL, - decoderSource: WHISPER_TINY_DECODER_MODEL, - tokenizerSource: WHISPER_TINY_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_BASE = { - isMultilingual: true, - encoderSource: WHISPER_BASE_ENCODER_MODEL, - decoderSource: WHISPER_BASE_DECODER_MODEL, - tokenizerSource: WHISPER_BASE_TOKENIZER, -}; - -/** - * 
@category Models - Speech To Text - */ -export const WHISPER_SMALL = { - isMultilingual: true, - encoderSource: WHISPER_SMALL_ENCODER_MODEL, - decoderSource: WHISPER_SMALL_DECODER_MODEL, - tokenizerSource: WHISPER_SMALL_TOKENIZER, -}; +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_TINY_EN_QUANTIZED = { +// isMultilingual: false, +// encoderSource: WHISPER_TINY_EN_ENCODER_QUANTIZED, +// decoderSource: WHISPER_TINY_EN_DECODER_QUANTIZED, +// tokenizerSource: WHISPER_TINY_EN_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_BASE_EN = { +// isMultilingual: false, +// encoderSource: WHISPER_BASE_EN_ENCODER, +// decoderSource: WHISPER_BASE_EN_DECODER, +// tokenizerSource: WHISPER_BASE_EN_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_SMALL_EN = { +// isMultilingual: false, +// encoderSource: WHISPER_SMALL_EN_ENCODER, +// decoderSource: WHISPER_SMALL_EN_DECODER, +// tokenizerSource: WHISPER_SMALL_EN_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_TINY = { +// isMultilingual: true, +// encoderSource: WHISPER_TINY_ENCODER_MODEL, +// decoderSource: WHISPER_TINY_DECODER_MODEL, +// tokenizerSource: WHISPER_TINY_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_BASE = { +// isMultilingual: true, +// encoderSource: WHISPER_BASE_ENCODER_MODEL, +// decoderSource: WHISPER_BASE_DECODER_MODEL, +// tokenizerSource: WHISPER_BASE_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_SMALL = { +// isMultilingual: true, +// encoderSource: WHISPER_SMALL_ENCODER_MODEL, +// decoderSource: WHISPER_SMALL_DECODER_MODEL, +// tokenizerSource: WHISPER_SMALL_TOKENIZER, +// }; // Image segmentation const DEEPLAB_V3_RESNET50_MODEL = `${URL_PREFIX}-deeplab-v3/${VERSION_TAG}/xnnpack/deeplabV3_xnnpack_fp32.pte`; diff --git 
a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts index fa3d8c685..e26b6e3c7 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts @@ -38,9 +38,9 @@ export const useSpeechToText = ({ setIsReady(false); await moduleInstance.load( { + type: model.type, isMultilingual: model.isMultilingual, - encoderSource: model.encoderSource, - decoderSource: model.decoderSource, + modelSource: model.modelSource, tokenizerSource: model.tokenizerSource, }, (progress) => { @@ -59,9 +59,9 @@ export const useSpeechToText = ({ }; }, [ moduleInstance, + model.type, model.isMultilingual, - model.encoderSource, - model.decoderSource, + model.modelSource, model.tokenizerSource, preventLoad, ]); diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index a42881f45..ef89146c5 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -23,9 +23,9 @@ declare global { schedulerStepsOffset: number ) => any; var loadSpeechToText: ( - encoderSource: string, - decoderSource: string, - modelName: string + modelName: string, + modelSource: string, + tokenizerSource: string ) => any; var loadTextToSpeechKokoro: ( lang: string, diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts index 64f4e953f..22f4eb77d 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts @@ -33,26 +33,23 @@ export class SpeechToTextModule { undefined, 
model.tokenizerSource ); - const encoderDecoderPromise = ResourceFetcher.fetch( + const modelPromise = ResourceFetcher.fetch( onDownloadProgressCallback, - model.encoderSource, - model.decoderSource + model.modelSource ); - const [tokenizerSources, encoderDecoderResults] = await Promise.all([ + const [tokenizerSources, modelSources] = await Promise.all([ tokenizerLoadPromise, - encoderDecoderPromise, + modelPromise, ]); - const encoderSource = encoderDecoderResults?.[0]; - const decoderSource = encoderDecoderResults?.[1]; - if (!encoderSource || !decoderSource || !tokenizerSources) { + if (!modelSources || !tokenizerSources) { throw new RnExecutorchError( RnExecutorchErrorCode.DownloadInterrupted, 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' ); } this.nativeModule = await global.loadSpeechToText( - encoderSource, - decoderSource, + model.type, + modelSources[0]!, tokenizerSources[0]! ); } diff --git a/packages/react-native-executorch/src/types/stt.ts b/packages/react-native-executorch/src/types/stt.ts index bf7fc6436..df0ab063f 100644 --- a/packages/react-native-executorch/src/types/stt.ts +++ b/packages/react-native-executorch/src/types/stt.ts @@ -10,7 +10,7 @@ export interface SpeechToTextProps { /** * Configuration object containing model sources. */ - model: SpeechToTextModelConfig; + model: SpeechToTextModelConfig; // | ... /** * Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ @@ -261,20 +261,19 @@ export interface TranscriptionResult { * @category Types */ export interface SpeechToTextModelConfig { + type: 'whisper'; // | ... (add more in the future) + /** * A boolean flag indicating whether the model supports multiple languages. */ isMultilingual: boolean; /** - * A string that specifies the location of a `.pte` file for the encoder. 
- */ - encoderSource: ResourceSource; - - /** - * A string that specifies the location of a `.pte` file for the decoder. + * A string that specifies the location of a `.pte` file for the model. + * + * We expect the model to have 2 bundled methods: 'decode' and 'encode'. */ - decoderSource: ResourceSource; + modelSource: ResourceSource; /** * A string that specifies the location to the tokenizer for the model. From b54e4692d2a0bb6a1059befe65057e79d6d9e12b Mon Sep 17 00:00:00 2001 From: IgorSwat Date: Tue, 17 Feb 2026 16:38:51 +0100 Subject: [PATCH 3/6] Fix infinite streaming in demo app --- apps/speech/screens/SpeechToTextScreen.tsx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/apps/speech/screens/SpeechToTextScreen.tsx b/apps/speech/screens/SpeechToTextScreen.tsx index 06813dfcd..f0cf57e4d 100644 --- a/apps/speech/screens/SpeechToTextScreen.tsx +++ b/apps/speech/screens/SpeechToTextScreen.tsx @@ -50,7 +50,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { const [liveTranscribing, setLiveTranscribing] = useState(false); const scrollViewRef = useRef(null); - const recorder = new AudioRecorder(); + const recorder = useRef(new AudioRecorder()); useEffect(() => { AudioManager.setAudioSessionOptions({ @@ -115,7 +115,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { const sampleRate = 16000; - recorder.onAudioReady( + recorder.current.onAudioReady( { sampleRate, bufferLength: 0.1 * sampleRate, @@ -131,7 +131,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { if (!success) { console.warn('Cannot start audio session correctly'); } - const result = recorder.start(); + const result = recorder.current.start(); if (result.status === 'error') { console.warn('Recording problems: ', result.message); } @@ -177,7 +177,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { const handleStopTranscribeFromMicrophone = () => { 
isRecordingRef.current = false; - recorder.stop(); + recorder.current.stop(); model.streamStop(); console.log('Live transcription stopped'); setLiveTranscribing(false); From ce5a39a38d14716b8e80a92ce09e4ae7f49bbc71 Mon Sep 17 00:00:00 2001 From: IgorSwat Date: Thu, 19 Feb 2026 12:05:21 +0100 Subject: [PATCH 4/6] Various STT streaming fixes --- .../models/speech_to_text/whisper/ASR.cpp | 38 ++++++-- .../models/speech_to_text/whisper/Constants.h | 8 +- .../whisper/HypothesisBuffer.cpp | 84 +++++++++++----- .../speech_to_text/whisper/OnlineASR.cpp | 96 ++++++++++++++++++- .../models/speech_to_text/whisper/Params.h | 46 +++++++++ .../models/speech_to_text/whisper/Utils.h | 71 ++++++++++++++ 6 files changed, 304 insertions(+), 39 deletions(-) create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp index c9ec1d322..aaf94c8cb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp @@ -4,11 +4,14 @@ #include "ASR.h" #include "Constants.h" +#include "Params.h" #include #include #include #include +#include + namespace rnexecutorch::models::speech_to_text::whisper { using executorch::runtime::etensor::ScalarType; @@ -30,18 +33,24 @@ ASR::ASR(const std::string &modelSource, const std::string &tokenizerSource, */ std::vector ASR::transcribe(std::span waveform, const DecodingOptions &options) const { - int32_t seek = 0; + // Use floats to prevent downcasting and timestamp mismatches + float seek = 0.f; std::vector results; + const float waveformSize = static_cast(waveform.size()); 
+ const float waveformSkipBoundary = + static_cast((constants::kChunkSize - params::kChunkBreakBuffer) * + constants::kSamplingRate); + // We loop through the input audio waveform and process it in 30s chunks. // This is determined by Whisper models strict 30s audio length requirement. - while (std::cmp_less(seek * constants::kSamplingRate, waveform.size())) { + while (seek * constants::kSamplingRate < waveformSize) { // Calculate chunk bounds and extract the chunk. - int32_t start = seek * constants::kSamplingRate; + float start = seek * constants::kSamplingRate; const auto end = - std::min(static_cast((seek + constants::kChunkSize) * - constants::kSamplingRate), - static_cast(waveform.size())); + std::min(static_cast((seek + constants::kChunkSize) * + constants::kSamplingRate), + waveformSize); auto chunk = waveform.subspan(start, end - start); if (std::cmp_less(chunk.size(), constants::kMinChunkSamples)) { @@ -71,7 +80,12 @@ std::vector ASR::transcribe(std::span waveform, } if (!segments.empty() && !segments.back().words.empty()) { - seek = static_cast(segments.back().words.back().end); + // This prevents additional segments to appear, unless the audio length is + // very close to the max chunk size, that is there could be some words + // spoken near the breakpoint. + seek = waveformSize < waveformSkipBoundary + ? 
seek + constants::kChunkSize + : segments.back().words.back().end; } results.insert(results.end(), std::make_move_iterator(segments.begin()), std::make_move_iterator(segments.end())); @@ -226,6 +240,12 @@ std::vector ASR::generate(std::span waveform, } } + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[ASR] Raw transcription results (tokens): ", bestTokens); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[ASR] Raw transcription results (text): ", + tokenizer_->decode(bestTokens, true)); + return this->calculateWordLevelTimestamps(bestTokens, waveform, bestAvgLogProb, bestTemperature, bestCompressionRatio); @@ -323,7 +343,8 @@ std::vector ASR::calculateWordLevelTimestamps( if (words.size()) { Segment seg; seg.words = std::move(words); - seg.tokens = {}; + // seg.tokens = {}; // WTF ? + seg.tokens = tokens; seg.avgLogprob = avgLogProb; seg.temperature = temperature; seg.compressionRatio = compressionRatio; @@ -382,6 +403,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span tokens, uint64_t start, uint64_t end) const { const std::vector tokensVec(tokens.begin(), tokens.end()); const std::string segmentText = tokenizer_->decode(tokensVec, true); + std::istringstream iss(segmentText); std::vector wordsStr; std::string word; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h index 1a31f4b62..383e72769 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h @@ -21,15 +21,17 @@ constexpr static int32_t kNumFrames = 1500; // Sampling rate expected by Whisper and the model's audio pipeline (16 kHz) constexpr static int32_t kSamplingRate = 16000; +constexpr static int32_t kSamplesPerMilisecond = kSamplingRate / 1000; // Time precision used by Whisper 
timestamps: each token spans 0.02 seconds constexpr static float kTimePrecision = 0.02f; // Special token constants namespace tokens { -inline const std::string kStartOfTranscript = "<|startoftranscript|>"; -inline const std::string kEndOfTranscript = "<|endoftext|>"; -inline const std::string kBeginTimestamp = "<|0.00|>"; +static const std::string kStartOfTranscript = "<|startoftranscript|>"; +static const std::string kEndOfTranscript = "<|endoftext|>"; +static const std::string kBeginTimestamp = "<|0.00|>"; +static const std::string kBlankAudio = "[BLANK_AUDIO]"; } // namespace tokens } // namespace rnexecutorch::models::speech_to_text::whisper::constants \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp index f077cc713..ad10aa8a6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp @@ -1,45 +1,56 @@ #include "HypothesisBuffer.h" +#include "Params.h" +#include "Utils.h" +#include +#include namespace rnexecutorch::models::speech_to_text::whisper::stream { void HypothesisBuffer::insert(std::span newWords, float offset) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[HypothesisBuffer] Inserting " + + std::to_string(newWords.size()) + + " words with offset " + std::to_string(offset) + "s."); + fresh_.clear(); for (const auto &word : newWords) { const float newStart = word.start + offset; - if (newStart > lastCommittedTime_ - 0.5f) { + // Only accept words that start after or near the last committed time to + // avoid stale data + if (newStart > lastCommittedTime_ - params::kStreamFreshThreshold) { fresh_.emplace_back(word.content, newStart, word.end + offset); } } + 
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[HypothesisBuffer] Filtered " + + std::to_string(fresh_.size()) + + " words into 'fresh' buffer."); if (!fresh_.empty() && !committedInBuffer_.empty()) { const float a = fresh_.front().start; - if (std::fabs(a - lastCommittedTime_) < 1.0f) { + // Check for overlap with already committed history to avoid duplicates in + // the stream + if (std::fabs(a - lastCommittedTime_) < 2.0f) { const size_t cn = committedInBuffer_.size(); const size_t nn = fresh_.size(); - const std::size_t maxCheck = std::min({cn, nn, 5}); - for (size_t i = 1; i <= maxCheck; i++) { - std::string c; - for (auto it = committedInBuffer_.cend() - i; - it != committedInBuffer_.cend(); ++it) { - if (!c.empty()) { - c += ' '; - } - c += it->content; - } - - std::string tail; - auto it = fresh_.cbegin(); - for (size_t k = 0; k < i; k++, it++) { - if (!tail.empty()) { - tail += ' '; - } - tail += it->content; - } - - if (c == tail) { - fresh_.erase(fresh_.begin(), fresh_.begin() + i); - break; - } + + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, + "[HypothesisBuffer] Checking for overlap. cn=" + std::to_string(cn) + + ", nn=" + std::to_string(nn) + + ", maxCheck=" + std::to_string(params::kStreamMaxOverlapSize)); + + size_t overlapSize = utils::findLargestOverlapingFragment( + committedInBuffer_, fresh_, params::kStreamMaxOverlapSize, + params::kStreamMaxOverlapTimestampDiff); + + if (overlapSize > 0) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[HypothesisBuffer] Detected overlap of " + + std::to_string(overlapSize) + + " words with committed history. 
Erasing " + "duplicates from 'fresh'."); + fresh_.erase(fresh_.begin(), fresh_.begin() + overlapSize); } } } @@ -48,6 +59,8 @@ void HypothesisBuffer::insert(std::span newWords, float offset) { std::deque HypothesisBuffer::flush() { std::deque commit; + // Find stable prefix: words that haven't changed between last and current + // iteration while (!fresh_.empty() && !buffer_.empty()) { if (fresh_.front().content != buffer_.front().content) { break; @@ -59,19 +72,36 @@ std::deque HypothesisBuffer::flush() { if (!commit.empty()) { lastCommittedTime_ = commit.back().end; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[HypothesisBuffer] Found stable prefix. Committing " + + std::to_string(commit.size()) + + " words. New lastCommittedTime: " + + std::to_string(lastCommittedTime_) + "s."); } + // Current 'fresh' (remaining) becomes the new 'buffer' for next iteration + // comparison buffer_ = std::move(fresh_); fresh_.clear(); + committedInBuffer_.insert(committedInBuffer_.end(), commit.begin(), commit.end()); + return commit; } void HypothesisBuffer::popCommitted(float time) { + size_t count = 0; while (!committedInBuffer_.empty() && committedInBuffer_.front().end <= time) { committedInBuffer_.pop_front(); + count++; + } + if (count > 0) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[HypothesisBuffer] Popped " + std::to_string(count) + + " old words from committed history up to " + + std::to_string(time) + "s."); } } @@ -81,6 +111,8 @@ void HypothesisBuffer::reset() { buffer_.clear(); fresh_.clear(); committedInBuffer_.clear(); + + lastCommittedTime_ = 0.f; } } // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp index b741567c1..a11e6e9ad 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp @@ -3,19 +3,54 @@ #include #include #include +#include + +#include namespace rnexecutorch::models::speech_to_text::whisper::stream { +namespace { +std::string wordsToString(const auto &words) { + std::stringstream ss; + ss << "["; + for (size_t i = 0; i < words.size(); ++i) { + ss << "'" << words[i].content << "' (" << words[i].start << "s - " + << words[i].end << "s)"; + if (i < words.size() - 1) + ss << ", "; + } + ss << "]"; + return ss.str(); +} +} // namespace + OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) {} void OnlineASR::insertAudioChunk(std::span audio) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Inserting audio chunk of size: " + + std::to_string(audio.size())); audioBuffer_.insert(audioBuffer_.end(), audio.begin(), audio.end()); } ProcessResult OnlineASR::process(const DecodingOptions &options) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Starting process iteration..."); + // Transcribe the current audio buffer. + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Running transcription on audio buffer..."); std::vector res = asr_->transcribe(audioBuffer_, options); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Transcription returned " + + std::to_string(res.size()) + " segments."); + for (size_t i = 0; i < res.size(); ++i) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + " Segment " + std::to_string(i) + ": " + + wordsToString(res[i].words)); + } + // Flatten segments into a single word sequence. 
std::vector tsw; size_t totalWords = 0; @@ -27,16 +62,42 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) { for (const auto &segment : res) { tsw.insert(tsw.end(), segment.words.begin(), segment.words.end()); } + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Flattened transcription into " + + std::to_string(tsw.size()) + " words."); // Update hypothesis buffer and commit stable words. + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Inserting " + std::to_string(tsw.size()) + + " words into hypothesis buffer with offset " + + std::to_string(bufferTimeOffset_) + "s."); hypothesisBuffer_.insert(tsw, bufferTimeOffset_); + std::deque flushed = hypothesisBuffer_.flush(); + if (!flushed.empty()) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Flushed " + std::to_string(flushed.size()) + + " stable words: " + wordsToString(flushed)); + } else { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] No stable words flushed this iteration."); + } + committed_.insert(committed_.end(), flushed.begin(), flushed.end()); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Total committed words history size: " + + std::to_string(committed_.size())); // Prune processed audio if buffer exceeds threshold (15 seconds). + const float audioDuration = + static_cast(audioBuffer_.size()) / constants::kSamplingRate; constexpr int32_t chunkThresholdSec = 15; - if (static_cast(audioBuffer_.size()) / constants::kSamplingRate > - chunkThresholdSec) { + if (audioDuration > chunkThresholdSec) { + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Audio buffer duration (" + std::to_string(audioDuration) + + "s) exceeds threshold (" + std::to_string(chunkThresholdSec) + + "s). 
Triggering pruning check."); chunkCompletedSegment(res); } @@ -46,7 +107,12 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) { }; std::deque nonCommittedWords = hypothesisBuffer_.complete(); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Current hypothesis (non-committed): " + + wordsToString(nonCommittedWords)); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Finished process iteration."); return {move_to_vector(flushed), move_to_vector(nonCommittedWords)}; } @@ -56,6 +122,8 @@ bool OnlineASR::ready() const { void OnlineASR::chunkCompletedSegment(std::span res) { if (committed_.empty() || res.empty()) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Pruning skipped: insufficient data."); return; } @@ -67,6 +135,10 @@ void OnlineASR::chunkCompletedSegment(std::span res) { for (int i = static_cast(res.size()) - 2; i >= 0; --i) { float segmentEndAbsolute = res[i].end + bufferTimeOffset_; if (segmentEndAbsolute <= lastCommittedTimestamp) { + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Found stable pruning point at absolute time: " + + std::to_string(segmentEndAbsolute) + "s."); chunkAt(segmentEndAbsolute); break; } @@ -74,21 +146,35 @@ void OnlineASR::chunkCompletedSegment(std::span res) { } void OnlineASR::chunkAt(float time) { + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Pruning buffers at time: " + std::to_string(time) + "s."); + hypothesisBuffer_.popCommitted(time); const float cutSeconds = time - bufferTimeOffset_; auto startIndex = static_cast(cutSeconds * constants::kSamplingRate); if (startIndex < audioBuffer_.size()) { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Erasing " + std::to_string(startIndex) + + " audio samples."); audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + startIndex); } else { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Clearing entire audio buffer."); 
audioBuffer_.clear(); } bufferTimeOffset_ = time; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] New buffer time offset: " + + std::to_string(bufferTimeOffset_) + "s."); } std::vector OnlineASR::finish() { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Finalizing streaming session."); std::deque bufferDeq = hypothesisBuffer_.complete(); std::vector buffer(std::make_move_iterator(bufferDeq.begin()), std::make_move_iterator(bufferDeq.end())); @@ -96,6 +182,10 @@ std::vector OnlineASR::finish() { bufferTimeOffset_ += static_cast(audioBuffer_.size()) / constants::kSamplingRate; + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Stream finished. Final hypothesis words: " + + wordsToString(buffer)); + // Final cleanup - usually involves clearing local state for next session reset(); @@ -103,6 +193,8 @@ std::vector OnlineASR::finish() { } void OnlineASR::reset() { + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[OnlineASR] Resetting streaming state."); hypothesisBuffer_.reset(); bufferTimeOffset_ = 0.f; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h new file mode 100644 index 000000000..1c8e5ec3a --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h @@ -0,0 +1,46 @@ +#pragma once + +#include + +/** + * Hyperparameters + * + * Those are adjustable values, which when changed, affect the behavior + * of the underlying model and/or algorithms. + */ +namespace rnexecutorch::models::speech_to_text::whisper::params { + +/** + * Determines the range of buffer left when skipping an audio chunk + * of size lower than maximum allowed chunk size. 
+ *
+ * If the audio length does not exceed [kChunkSize * kSamplingRate] - [buffer],
+ * then instead of moving to the last returned timestamp, we jump across the
+ * entire 30 seconds chunk. This resolves the issue of multiple redundant
+ * segments being produced by the transcription algorithm.
+ */
+constexpr static int32_t kChunkBreakBuffer = 2; // [s]
+
+/**
+ * Determines the maximum timestamp difference available for a word to be
+ * considered as fresh in streaming algorithm.
+ */
+constexpr static float kStreamFreshThreshold = 1.F; // [s], originally 0.5
+
+/**
+ * Determines the maximum expected size of overlapping fragments between
+ * fresh words buffer and committed words buffer in streaming mode.
+ *
+ * It is a limit on the maximum number of repeated words erased from fresh
+ * buffer. The bigger it gets, the less probable it is to commit the same
+ * phrase twice.
+ */
+constexpr static size_t kStreamMaxOverlapSize =
+    10; // Number of overlapping words
+
+/**
+ * Similar to kStreamMaxOverlapSize, but this one determines
+ * the maximum allowed timestamp difference between the overlapping fragments.
+ */
+constexpr static float kStreamMaxOverlapTimestampDiff = 5.F; // [s]
+
+} // namespace rnexecutorch::models::speech_to_text::whisper::params
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h
new file mode 100644
index 000000000..aec1aa1f5
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include "../common/types/Word.h"
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <deque>
+
+namespace rnexecutorch::models::speech_to_text::whisper::utils {
+
+/**
+ * Finds the largest (in number of words) overlapping fragment between word
+ * vectors A (suffix) and B (prefix).
+ *
+ * An overlapping fragment is any fragment C, which can be simultaneously a
+ * suffix of A and a prefix of B. Example: A = 'Jane likes food and playing
+ * games', B = 'playing games and sleeping', the overlap fragment C = 'playing
+ * games'.
+ *
+ * @param suffixVec An input vector, where only suffixes can overlap.
+ * Typically the 'committed' buffer in streaming algorithm.
+ * @param prefixVec An input vector, where only prefixes can overlap.
+ * Typically the 'fresh' buffer in streaming algorithm.
+ * @param maxCheckRange The maximum size of overlapping fragment. Determines the
+ * range of search.
+ * @param maxTimestampDiff The maximum allowed timestamp difference between
+ * overlapping fragments. If exceeded, the fragments are not considered as
+ * overlapping.
+ * @return The size of the largest found overlapping fragment.
+ */
+template <typename Container>
+inline size_t findLargestOverlapingFragment(const Container &suffixVec,
+                                            const Container &prefixVec,
+                                            size_t maxCheckRange = 10,
+                                            float maxTimestampDiff = 100.f) {
+  size_t range = std::min({suffixVec.size(), prefixVec.size(), maxCheckRange});
+
+  if (range == 0) {
+    return 0;
+  }
+
+  // Iterate backwards from the largest possible overlap size down to 1.
+  // i starts at the index where the suffix of length 'range' begins.
+ for (size_t i = suffixVec.size() - range; i < suffixVec.size(); ++i) { + // We search for overlaps by searching for the first word of prefixVec + if (suffixVec[i].content == prefixVec[0].content) { + size_t calculatedSize = suffixVec.size() - i; + + // Optimization: Check if the last elements match before full comparison + if (prefixVec[calculatedSize - 1].content != suffixVec.back().content) { + continue; + } + + bool isEqual = std::equal( + suffixVec.begin() + i, suffixVec.end(), prefixVec.begin(), + [maxTimestampDiff](const Word &sWord, const Word &pWord) { + return sWord.content == pWord.content && + std::fabs(sWord.start - pWord.start) <= maxTimestampDiff && + std::fabs(sWord.end - pWord.end) <= maxTimestampDiff; + }); + + if (isEqual) { + return calculatedSize; + } + } + } + + return 0; +} + +} // namespace rnexecutorch::models::speech_to_text::whisper::utils \ No newline at end of file From 278985c915f7befabbbea3ded3f377394f153648 Mon Sep 17 00:00:00 2001 From: IgorSwat Date: Fri, 20 Feb 2026 13:37:35 +0100 Subject: [PATCH 5/6] Add timestamp fix algorithm & other --- .../models/speech_to_text/SpeechToText.cpp | 2 +- .../speech_to_text/common/schema/OnlineASR.h | 2 +- .../whisper/HypothesisBuffer.cpp | 143 +++++----- .../speech_to_text/whisper/HypothesisBuffer.h | 55 ++-- .../speech_to_text/whisper/OnlineASR.cpp | 255 +++++++----------- .../models/speech_to_text/whisper/OnlineASR.h | 15 +- .../models/speech_to_text/whisper/Params.h | 12 + 7 files changed, 207 insertions(+), 277 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp index 931bfb2b9..82955b299 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp @@ -130,7 +130,7 @@ void 
SpeechToText::stream(std::shared_ptr callback, DecodingOptions options(languageOption, verbose); while (isStreaming_) { - if (!readyToProcess_ || !streamer_->ready()) { + if (!readyToProcess_ || !streamer_->isReady()) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); continue; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h index b746cf2d9..357309391 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h @@ -32,7 +32,7 @@ class OnlineASR { virtual void insertAudioChunk(std::span audio) = 0; - virtual bool ready() const = 0; + virtual bool isReady() const = 0; virtual ProcessResult process(const DecodingOptions &options) = 0; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp index ad10aa8a6..f497019e4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp @@ -2,115 +2,92 @@ #include "Params.h" #include "Utils.h" #include -#include namespace rnexecutorch::models::speech_to_text::whisper::stream { -void HypothesisBuffer::insert(std::span newWords, float offset) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[HypothesisBuffer] Inserting " + - std::to_string(newWords.size()) + - " words with offset " + std::to_string(offset) + "s."); - +void HypothesisBuffer::insert(std::span words, float offset) { + // Step 1 - decide which words should be considered as fresh. 
+ // Using less amount of fresh words saves a little bit of time, but + // could backfire in terms of quality of the final committed transcript. fresh_.clear(); - for (const auto &word : newWords) { - const float newStart = word.start + offset; - // Only accept words that start after or near the last committed time to - // avoid stale data - if (newStart > lastCommittedTime_ - params::kStreamFreshThreshold) { - fresh_.emplace_back(word.content, newStart, word.end + offset); + for (const Word &word : words) { + // Global start is a beginning timestamp relative only to the beginning of + // the current streaming process. + const float startGlobal = word.start + offset; + const float endGlobal = word.end + offset; + + // To optimize the process, we discard the words which are too old + // according to the calculated timestamp. + if (startGlobal > lastCommittedTime_ - params::kStreamFreshThreshold) { + fresh_.emplace_back(word.content, startGlobal, endGlobal); } } - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[HypothesisBuffer] Filtered " + - std::to_string(fresh_.size()) + - " words into 'fresh' buffer."); - - if (!fresh_.empty() && !committedInBuffer_.empty()) { - const float a = fresh_.front().start; - // Check for overlap with already committed history to avoid duplicates in - // the stream - if (std::fabs(a - lastCommittedTime_) < 2.0f) { - const size_t cn = committedInBuffer_.size(); - const size_t nn = fresh_.size(); - - rnexecutorch::log( - rnexecutorch::LOG_LEVEL::Info, - "[HypothesisBuffer] Checking for overlap. 
cn=" + std::to_string(cn) + - ", nn=" + std::to_string(nn) + - ", maxCheck=" + std::to_string(params::kStreamMaxOverlapSize)); - - size_t overlapSize = utils::findLargestOverlapingFragment( - committedInBuffer_, fresh_, params::kStreamMaxOverlapSize, - params::kStreamMaxOverlapTimestampDiff); - - if (overlapSize > 0) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[HypothesisBuffer] Detected overlap of " + - std::to_string(overlapSize) + - " words with committed history. Erasing " - "duplicates from 'fresh'."); - fresh_.erase(fresh_.begin(), fresh_.begin() + overlapSize); - } + + // Step 2 - we have already selected the fresh words. Now it's time to + // correct any mistakes and remove the words which overlap with already + // commited segments - to avoid duplicates. + if (!fresh_.empty() && !committed_.empty()) { + const float freshSequenceStart = fresh_.front().start; + const float freshSequenceEnd = fresh_.back().end; + + // Calculate the largest overlapping fragment size. + // Note that we use size limit (kStreamMaxOverlapSize) for efficiency of the + // algorithm, and timestamp difference limit + // (kStreamMaxOverlapTimestampDiff) to avoid removing correct fragments + // which were just repeated after some time. + size_t overlapSize = utils::findLargestOverlapingFragment( + committed_, fresh_, params::kStreamMaxOverlapSize, + params::kStreamMaxOverlapTimestampDiff); + + // Remove all the overlapping words. 
+ if (overlapSize > 0) { + fresh_.erase(fresh_.begin(), fresh_.begin() + overlapSize); } } } -std::deque HypothesisBuffer::flush() { - std::deque commit; +std::deque HypothesisBuffer::commit() { + std::deque toCommit = {}; - // Find stable prefix: words that haven't changed between last and current - // iteration - while (!fresh_.empty() && !buffer_.empty()) { - if (fresh_.front().content != buffer_.front().content) { - break; - } - commit.push_back(fresh_.front()); - buffer_.pop_front(); + // Find a stable prefix: words that haven't changed between last and current + // iteration. + while (!fresh_.empty() && !hypothesis_.empty() && + fresh_.front().content == hypothesis_.front().content) { + toCommit.emplace_back( + std::move(hypothesis_.front())); // Timestamps from the previous + // iteration tends to be more reliable fresh_.pop_front(); + hypothesis_.pop_front(); } - if (!commit.empty()) { - lastCommittedTime_ = commit.back().end; - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[HypothesisBuffer] Found stable prefix. Committing " + - std::to_string(commit.size()) + - " words. New lastCommittedTime: " + - std::to_string(lastCommittedTime_) + "s."); + // Save the last committed word timestamp. + // This will mark the end of the entire committed sequence. + if (!toCommit.empty()) { + lastCommittedTime_ = toCommit.back().end; } - // Current 'fresh' (remaining) becomes the new 'buffer' for next iteration - // comparison - buffer_ = std::move(fresh_); + // The remaining words from the fresh buffer (uncommitted phrase) + // become a hypothesis for the next iteration. + hypothesis_ = std::move(fresh_); fresh_.clear(); - committedInBuffer_.insert(committedInBuffer_.end(), commit.begin(), - commit.end()); + // The last step is to commit the selected words. 
+ committed_.insert(committed_.end(), toCommit.cbegin(), toCommit.cend()); - return commit; + return toCommit; } -void HypothesisBuffer::popCommitted(float time) { - size_t count = 0; - while (!committedInBuffer_.empty() && - committedInBuffer_.front().end <= time) { - committedInBuffer_.pop_front(); - count++; - } - if (count > 0) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[HypothesisBuffer] Popped " + std::to_string(count) + - " old words from committed history up to " + - std::to_string(time) + "s."); +void HypothesisBuffer::releaseCommits(size_t wordsToKeep) { + if (committed_.size() > wordsToKeep) { + size_t nWordsToErase = committed_.size() - wordsToKeep; + committed_.erase(committed_.begin(), committed_.begin() + nWordsToErase); } } -std::deque HypothesisBuffer::complete() const { return buffer_; } - void HypothesisBuffer::reset() { - buffer_.clear(); fresh_.clear(); - committedInBuffer_.clear(); + hypothesis_.clear(); + committed_.clear(); lastCommittedTime_ = 0.f; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h index c5be66dcb..226c037b7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h @@ -15,50 +15,59 @@ namespace rnexecutorch::models::speech_to_text::whisper::stream { class HypothesisBuffer { public: /** - * Inserts new words into the hypothesis buffer. + * Inserts new words into the fresh_ buffer. * Words are filtered based on the last committed time and checked for * overlaps with existing committed words to prevent duplicates. * - * @param newWords A span of newly generated words. + * @param newWords A span of recently generated words. * @param offset Time offset to adjust the word timestamps. 
*/ - void insert(std::span newWords, float offset); + void insert(std::span words, float offset); /** - * Moves stable words from the hypothesis into the committed buffer. - * It compares the new hypothesis (fresh) with the previous one (buffer) - * and returns the common prefix as committed words. + * Attempts to commit words present in the fresh_ buffer. + * A phrase from fresh_ buffer can only be committed if it also appears + * in the hypothesis_ buffer (uncommitted words from previous iteration). * - * @return A deque of words that have been newly committed. - */ - std::deque flush(); - - /** - * Cleans up the history of committed words up to a certain timestamp. + * Uncommitted words become a 'hypothesis' and are moved into the hypothesis_ + * buffer. * - * @param time The timestamp limit; words ending before this time are removed. + * @return A sequence of words committed in the current iteration. */ - void popCommitted(float time); + std::deque commit(); /** - * Retrieves the current uncommitted hypothesis. + * Shrinks the committed_ buffer by erasing all words except N latest ones. * - * @return A deque containing the words currently in the buffer. + * Used primarily to relieve increasing memory usage during very + * long streaming sessions. + * + * @param wordsToKeep - number of trailing words to be kept in. */ - std::deque complete() const; + void releaseCommits(size_t wordsToKeep); /** - * Resets all the stored buffers to the initial state + * Resets all the stored buffers and state variables to the initial state */ void reset(); -private: - float lastCommittedTime_ = 0.0f; + // Declare a friendship with OnlineASR to allow it to access the internal + // state of stored buffers. 
+ friend class OnlineASR; +private: // Stored buffers - std::deque buffer_; - std::deque fresh_; - std::deque committedInBuffer_; + // The lifecycle of a correct result word looks as following: + // fresh buffer -> hypothesis buffer -> commited + std::deque + fresh_; // 'New' words from current iterations, which require some checks + // before they go into hypothesis_ buffer. + std::deque + hypothesis_; // Words potentially to be commited, stored between + // iterations (obtained from fresh_ buffer). + std::deque committed_; // A history of already commited words. + + float lastCommittedTime_ = 0.0f; }; } // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp index a11e6e9ad..926f30051 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp @@ -1,205 +1,140 @@ -#include "OnlineASR.h" -#include "Constants.h" #include #include #include #include -#include +#include "Constants.h" +#include "OnlineASR.h" +#include "Params.h" namespace rnexecutorch::models::speech_to_text::whisper::stream { namespace { -std::string wordsToString(const auto &words) { - std::stringstream ss; - ss << "["; - for (size_t i = 0; i < words.size(); ++i) { - ss << "'" << words[i].content << "' (" << words[i].start << "s - " - << words[i].end << "s)"; - if (i < words.size() - 1) - ss << ", "; - } - ss << "]"; - return ss.str(); -} +// A helper function to avoid code duplication. 
+std::vector move_to_vector(std::deque &container) { + return std::vector(std::make_move_iterator(container.begin()), + std::make_move_iterator(container.end())); +}; } // namespace -OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) {} +OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) { + // Reserve a minimal expected amount of memory for audio buffer. + audioBuffer_.reserve(static_cast(2 * params::kStreamChunkThreshold * + constants::kSamplingRate)); +} void OnlineASR::insertAudioChunk(std::span audio) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Inserting audio chunk of size: " + - std::to_string(audio.size())); audioBuffer_.insert(audioBuffer_.end(), audio.begin(), audio.end()); } +bool OnlineASR::isReady() const { + return audioBuffer_.size() >= constants::kMinChunkSamples; +} + ProcessResult OnlineASR::process(const DecodingOptions &options) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Starting process iteration..."); - - // Transcribe the current audio buffer. - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Running transcription on audio buffer..."); - std::vector res = asr_->transcribe(audioBuffer_, options); - - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Transcription returned " + - std::to_string(res.size()) + " segments."); - for (size_t i = 0; i < res.size(); ++i) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - " Segment " + std::to_string(i) + ": " + - wordsToString(res[i].words)); + // Perform a transcription process to obtain results for + // the current state of the audio buffer. + std::vector transcriptions = asr_->transcribe(audioBuffer_, options); + + if (transcriptions.empty()) { + return {.committed = {}, .nonCommitted = {}}; } // Flatten segments into a single word sequence. - std::vector tsw; - size_t totalWords = 0; - for (const auto &segment : res) { - totalWords += segment.words.size(); + // In this case, Word consists of text and timestamps. 
+ std::vector words; + words.reserve(transcriptions.front().words.size()); + + // Note that we transfer the ownership of moves, so words should not be + // accessed by transcriptions.segment.words afterwards. + for (auto &segment : transcriptions) { + words.insert(words.end(), std::make_move_iterator(segment.words.begin()), + std::make_move_iterator(segment.words.end())); } - tsw.reserve(totalWords); - for (const auto &segment : res) { - tsw.insert(tsw.end(), segment.words.begin(), segment.words.end()); - } - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Flattened transcription into " + - std::to_string(tsw.size()) + " words."); - - // Update hypothesis buffer and commit stable words. - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Inserting " + std::to_string(tsw.size()) + - " words into hypothesis buffer with offset " + - std::to_string(bufferTimeOffset_) + "s."); - hypothesisBuffer_.insert(tsw, bufferTimeOffset_); - - std::deque flushed = hypothesisBuffer_.flush(); - if (!flushed.empty()) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Flushed " + std::to_string(flushed.size()) + - " stable words: " + wordsToString(flushed)); - } else { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] No stable words flushed this iteration."); + hypothesisBuffer_.insert(words, bufferTimeOffset_); + + // Apply fix for timestamps. + // After the insert() call on hypothesis buffer, the inner fresh_ buffer + // contains either a completely new words or words which overlap only + // with the inner hypothesis_ buffer. + if (!hypothesisBuffer_.fresh_.empty()) { + const float establishedEnd = !hypothesisBuffer_.hypothesis_.empty() + ? hypothesisBuffer_.hypothesis_.back().end + : !hypothesisBuffer_.committed_.empty() + ? 
hypothesisBuffer_.committed_.back().end + : 0.F; + const float newEnd = + std::max(establishedEnd, hypothesisBuffer_.fresh_.back().end); + float newBegin = hypothesisBuffer_.fresh_.front().start; + for (size_t i = 0; i < hypothesisBuffer_.fresh_.size(); i++) { + // If the word overlaps with the hypothesis, we can simply copy the + // timestamps from the previous iteration (that is, from the hypothesis + // inner buffer). + if (i < hypothesisBuffer_.hypothesis_.size() && + hypothesisBuffer_.fresh_[i].content == + hypothesisBuffer_.hypothesis_[i].content) { + hypothesisBuffer_.fresh_[i].start = + hypothesisBuffer_.hypothesis_[i].start; + hypothesisBuffer_.fresh_[i].end = hypothesisBuffer_.hypothesis_[i].end; + + newBegin = hypothesisBuffer_.fresh_[i].end; + } + + // In case of a new word, we apply timestamp range scaling + // based on timestamps established in previous iterations. + // The idea is, that both ranges of established and completely new words + // should sum up to the final timestamp produced by the model. + // TODO: estimate the epsilon value + float scale = (newEnd - establishedEnd - 0.2f) / (newEnd - newBegin); + hypothesisBuffer_.fresh_[i].start = + (hypothesisBuffer_.fresh_[i].start - newEnd) * scale + newEnd; + hypothesisBuffer_.fresh_[i].end = + (hypothesisBuffer_.fresh_[i].end - newEnd) * scale + newEnd; + } } - committed_.insert(committed_.end(), flushed.begin(), flushed.end()); - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Total committed words history size: " + - std::to_string(committed_.size())); + // Commit matching words. + auto committed = hypothesisBuffer_.commit(); + auto nonCommitted = hypothesisBuffer_.hypothesis_; - // Prune processed audio if buffer exceeds threshold (15 seconds). + // Cut the audio buffer to not exceed the size threshold. + // Since Whisper does not accept waveforms longer than 30 seconds, we need + // to cut the audio at some safe point. 
const float audioDuration = static_cast(audioBuffer_.size()) / constants::kSamplingRate; - constexpr int32_t chunkThresholdSec = 15; - if (audioDuration > chunkThresholdSec) { - rnexecutorch::log( - rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Audio buffer duration (" + std::to_string(audioDuration) + - "s) exceeds threshold (" + std::to_string(chunkThresholdSec) + - "s). Triggering pruning check."); - chunkCompletedSegment(res); + if (audioDuration > params::kStreamChunkThreshold) { + // Leave some portion of audio in, to improve model behavior + // in future iterations. + const float erasedDuration = + audioDuration - params::kStreamAudioBufferReserve; + const size_t nSamplesToErase = + static_cast(erasedDuration * constants::kSamplingRate); + + audioBuffer_.erase(audioBuffer_.begin(), + audioBuffer_.begin() + nSamplesToErase); + bufferTimeOffset_ += erasedDuration; } - auto move_to_vector = [](std::deque &container) { - return std::vector(std::make_move_iterator(container.begin()), - std::make_move_iterator(container.end())); - }; - - std::deque nonCommittedWords = hypothesisBuffer_.complete(); - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Current hypothesis (non-committed): " + - wordsToString(nonCommittedWords)); - - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Finished process iteration."); - return {move_to_vector(flushed), move_to_vector(nonCommittedWords)}; -} - -bool OnlineASR::ready() const { - return audioBuffer_.size() >= constants::kMinChunkSamples; -} - -void OnlineASR::chunkCompletedSegment(std::span res) { - if (committed_.empty() || res.empty()) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Pruning skipped: insufficient data."); - return; - } - - const float lastCommittedTimestamp = committed_.back().end; - - // Search backwards for the last segment that finished before the last - // committed word. We skip the very last segment to maintain context for - // future iterations. 
- for (int i = static_cast(res.size()) - 2; i >= 0; --i) { - float segmentEndAbsolute = res[i].end + bufferTimeOffset_; - if (segmentEndAbsolute <= lastCommittedTimestamp) { - rnexecutorch::log( - rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Found stable pruning point at absolute time: " + - std::to_string(segmentEndAbsolute) + "s."); - chunkAt(segmentEndAbsolute); - break; - } - } -} - -void OnlineASR::chunkAt(float time) { - rnexecutorch::log( - rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Pruning buffers at time: " + std::to_string(time) + "s."); - - hypothesisBuffer_.popCommitted(time); - - const float cutSeconds = time - bufferTimeOffset_; - auto startIndex = static_cast(cutSeconds * constants::kSamplingRate); - - if (startIndex < audioBuffer_.size()) { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Erasing " + std::to_string(startIndex) + - " audio samples."); - audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + startIndex); - } else { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Clearing entire audio buffer."); - audioBuffer_.clear(); - } - - bufferTimeOffset_ = time; - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] New buffer time offset: " + - std::to_string(bufferTimeOffset_) + "s."); + return {.committed = move_to_vector(committed), + .nonCommitted = move_to_vector(nonCommitted)}; } std::vector OnlineASR::finish() { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Finalizing streaming session."); - std::deque bufferDeq = hypothesisBuffer_.complete(); - std::vector buffer(std::make_move_iterator(bufferDeq.begin()), - std::make_move_iterator(bufferDeq.end())); - - bufferTimeOffset_ += - static_cast(audioBuffer_.size()) / constants::kSamplingRate; - - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Stream finished. 
Final hypothesis words: " + - wordsToString(buffer)); + // We always push the last remaining hypothesis, even if it's not + // confirmed in second iteration. + auto remaining = hypothesisBuffer_.hypothesis_; - // Final cleanup - usually involves clearing local state for next session reset(); - return buffer; + return move_to_vector(remaining); } void OnlineASR::reset() { - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, - "[OnlineASR] Resetting streaming state."); hypothesisBuffer_.reset(); bufferTimeOffset_ = 0.f; audioBuffer_.clear(); - committed_.clear(); } } // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h index 9999d8d14..ea1389ec7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h @@ -30,7 +30,7 @@ class OnlineASR : public schema::OnlineASR { * * @return True if audioBuffer has enough samples, False otherwise */ - bool ready() const override; + bool isReady() const override; /** * Processes the current audio buffer and returns new transcription results. @@ -57,21 +57,18 @@ class OnlineASR : public schema::OnlineASR { void reset() override; private: - // Helper functions - void chunkCompletedSegment(std::span res); - void chunkAt(float time); - // ASR module connection for transcribing the audio const ASR *asr_; // Helper buffers - audio buffer - std::vector audioBuffer_; + // Stores the increasing amounts of streamed audio. + // Cleared from time to time after reaching a threshold size. + std::vector audioBuffer_ = {}; + float bufferTimeOffset_ = 0.f; // Audio buffer offset // Helper buffers - hypothesis buffer + // Manages the whisper streaming hypothesis mechanism. 
HypothesisBuffer hypothesisBuffer_; - float bufferTimeOffset_ = 0.f; - - std::vector committed_; }; } // namespace rnexecutorch::models::speech_to_text::whisper::stream \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h index 1c8e5ec3a..b8ffa0896 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h @@ -43,4 +43,16 @@ constexpr static size_t kStreamMaxOverlapSize = */ constexpr static float kStreamMaxOverlapTimestampDiff = 5.F; // [s] +/** + * A threshold which exceeded causes the main streaming audio buffer to be + * cleared. + */ +constexpr static float kStreamChunkThreshold = 20.F; // [s] + +/** + * Decides how much of recent audio waveform should be kept in when + * clearing the audio buffer in streaming algorithm. 
+ */ +constexpr static float kStreamAudioBufferReserve = 5.F; // [s] + } // namespace rnexecutorch::models::speech_to_text::whisper::params \ No newline at end of file From 2a3786783319c9d080b4f9a54efb3515c548a305 Mon Sep 17 00:00:00 2001 From: IgorSwat Date: Fri, 20 Feb 2026 18:26:58 +0100 Subject: [PATCH 6/6] Fix punctation comparision issue --- .../host_objects/JsiConversions.h | 3 +- .../models/speech_to_text/SpeechToText.cpp | 2 +- .../models/speech_to_text/common/types/Word.h | 3 ++ .../models/speech_to_text/whisper/ASR.cpp | 12 +++++++- .../models/speech_to_text/whisper/Constants.h | 5 ++++ .../whisper/HypothesisBuffer.cpp | 22 +++++++++----- .../speech_to_text/whisper/OnlineASR.cpp | 29 ++++++++++++++----- .../models/speech_to_text/whisper/OnlineASR.h | 5 ++++ .../models/speech_to_text/whisper/Params.h | 4 +-- 9 files changed, 65 insertions(+), 20 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 635c0bb96..8226db71b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -507,7 +507,8 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) { jsi::Object wordObj(runtime); wordObj.setProperty( runtime, "word", - jsi::String::createFromUtf8(runtime, seg.words[i].content)); + jsi::String::createFromUtf8(runtime, seg.words[i].content + + seg.words[i].punctations)); wordObj.setProperty(runtime, "start", static_cast(seg.words[i].start)); wordObj.setProperty(runtime, "end", static_cast(seg.words[i].end)); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp index 82955b299..acd17cb0b 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp @@ -82,7 +82,7 @@ TranscriptionResult wordsToResult(const std::vector &words, std::string fullText; for (const auto &w : words) { - fullText += w.content; + fullText += w.content + w.punctations; } res.text = fullText; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h index c3e0b5d4c..9de04a9c5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h @@ -8,6 +8,9 @@ struct Word { std::string content; float start; float end; + + std::string punctations = + ""; // Trailing punctations which appear after the main content }; } // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp index aaf94c8cb..09562088e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -429,7 +430,16 @@ ASR::estimateWordLevelTimestampsLinear(std::span tokens, const float wStart = startOffset + prevCharCount * timePerChar; const float wEnd = wStart + timePerChar * wSize; prevCharCount += wSize; - wordObjs.emplace_back(std::move(w), wStart, wEnd); + + // We store punctations separately to other characters. 
+ std::string puncts = ""; + while (!w.empty() && constants::kPunctations.contains(w.back())) { + puncts += w.back(); + w.pop_back(); + } + std::reverse(puncts.begin(), puncts.end()); + + wordObjs.emplace_back(std::move(w), wStart, wEnd, std::move(puncts)); } return wordObjs; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h index 383e72769..1d2c40e93 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h @@ -2,6 +2,7 @@ #include #include +#include namespace rnexecutorch::models::speech_to_text::whisper::constants { @@ -26,6 +27,10 @@ constexpr static int32_t kSamplesPerMilisecond = kSamplingRate / 1000; // Time precision used by Whisper timestamps: each token spans 0.02 seconds constexpr static float kTimePrecision = 0.02f; +// Special characters serving as pause / end of sentence +static const std::unordered_set kPunctations = {',', '.', '?', + '!', ':', ';'}; + // Special token constants namespace tokens { static const std::string kStartOfTranscript = "<|startoftranscript|>"; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp index f497019e4..45b24f64c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp @@ -19,7 +19,8 @@ void HypothesisBuffer::insert(std::span words, float offset) { // To optimize the process, we discard the words which are too old // according to the calculated timestamp. 
if (startGlobal > lastCommittedTime_ - params::kStreamFreshThreshold) { - fresh_.emplace_back(word.content, startGlobal, endGlobal); + fresh_.emplace_back(word.content, startGlobal, endGlobal, + word.punctations); } } @@ -27,9 +28,6 @@ void HypothesisBuffer::insert(std::span words, float offset) { // correct any mistakes and remove the words which overlap with already // commited segments - to avoid duplicates. if (!fresh_.empty() && !committed_.empty()) { - const float freshSequenceStart = fresh_.front().start; - const float freshSequenceEnd = fresh_.back().end; - // Calculate the largest overlapping fragment size. // Note that we use size limit (kStreamMaxOverlapSize) for efficiency of the // algorithm, and timestamp difference limit @@ -53,9 +51,19 @@ std::deque HypothesisBuffer::commit() { // iteration. while (!fresh_.empty() && !hypothesis_.empty() && fresh_.front().content == hypothesis_.front().content) { - toCommit.emplace_back( - std::move(hypothesis_.front())); // Timestamps from the previous - // iteration tends to be more reliable + // The last word from the fresh_ buffer must also match punctations with the + // hypothesis. This is done in order to ensure correct punctation marks in + // the resulting transcription. + if (fresh_.size() == 1 && + fresh_.front().punctations != hypothesis_.front().punctations) { + break; + } + + // Take timestamps from the hypothesis, but actual content from the fresh + // buffer. 
+ toCommit.emplace_back(std::move(fresh_.front().content), + hypothesis_.front().start, hypothesis_.front().end, + std::move(fresh_.front().punctations)); fresh_.pop_front(); hypothesis_.pop_front(); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp index 926f30051..755c958f9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp @@ -25,6 +25,9 @@ OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) { void OnlineASR::insertAudioChunk(std::span audio) { audioBuffer_.insert(audioBuffer_.end(), audio.begin(), audio.end()); + + // Update the epsilon accordingly to the amount of added audio. + epsilon_ += static_cast(audio.size()) / constants::kSamplingRate; } bool OnlineASR::isReady() const { @@ -59,14 +62,12 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) { // contains either a completely new words or words which overlap only // with the inner hypothesis_ buffer. if (!hypothesisBuffer_.fresh_.empty()) { - const float establishedEnd = !hypothesisBuffer_.hypothesis_.empty() - ? hypothesisBuffer_.hypothesis_.back().end - : !hypothesisBuffer_.committed_.empty() - ? hypothesisBuffer_.committed_.back().end - : 0.F; - const float newEnd = - std::max(establishedEnd, hypothesisBuffer_.fresh_.back().end); + float establishedEnd = !hypothesisBuffer_.committed_.empty() + ? 
hypothesisBuffer_.committed_.back().end + : 0.F; + const float newEnd = hypothesisBuffer_.fresh_.back().end; float newBegin = hypothesisBuffer_.fresh_.front().start; + for (size_t i = 0; i < hypothesisBuffer_.fresh_.size(); i++) { // If the word overlaps with the hypothesis, we can simply copy the // timestamps from the previous iteration (that is, from the hypothesis @@ -78,7 +79,10 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) { hypothesisBuffer_.hypothesis_[i].start; hypothesisBuffer_.fresh_[i].end = hypothesisBuffer_.hypothesis_[i].end; + establishedEnd = hypothesisBuffer_.hypothesis_[i].end; newBegin = hypothesisBuffer_.fresh_[i].end; + + continue; } // In case of a new word, we apply timestamp range scaling @@ -86,12 +90,21 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) { // The idea is, that both ranges of established and completely new words // should sum up to the final timestamp produced by the model. // TODO: estimate the epsilon value - float scale = (newEnd - establishedEnd - 0.2f) / (newEnd - newBegin); + const float beforeScaleStart = hypothesisBuffer_.fresh_[i].start; + const float beforeScaleEnd = hypothesisBuffer_.fresh_[i].end; + float scale = + (newEnd - establishedEnd) / (newEnd - newBegin); // missing epsilon + // float scale = (newEnd - establishedEnd - epsilon) / (newEnd - + // newBegin); // correct hypothesisBuffer_.fresh_[i].start = (hypothesisBuffer_.fresh_[i].start - newEnd) * scale + newEnd; hypothesisBuffer_.fresh_[i].end = (hypothesisBuffer_.fresh_[i].end - newEnd) * scale + newEnd; } + + // Before committing, save the last timestamp produced by the model. + // That is, the ending timestamp of the last fresh word. + lastNonSilentMoment_ = newEnd; } // Commit matching words. 
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h index ea1389ec7..085fcc140 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h @@ -69,6 +69,11 @@ class OnlineASR : public schema::OnlineASR { // Helper buffers - hypothesis buffer // Manages the whisper streaming hypothesis mechanism. HypothesisBuffer hypothesisBuffer_; + + // Silence estimation components + // Used primarily in timestamp range scaling algorithm. + float epsilon_ = 0.F; + float lastNonSilentMoment_ = 0.F; }; } // namespace rnexecutorch::models::speech_to_text::whisper::stream \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h index b8ffa0896..17607669f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h @@ -35,13 +35,13 @@ constexpr static float kStreamFreshThreshold = 1.F; // [s], originally 0.5 * The bigger it gets, the less probable it is to commit the same phrase twice. */ constexpr static size_t kStreamMaxOverlapSize = - 10; // Number of overlaping words + 12; // Number of overlaping words /** * Similar to kMaxStreamOverlapSize, but this one determines * the maximum allowed timestamp difference between the overlaping fragments. */ -constexpr static float kStreamMaxOverlapTimestampDiff = 5.F; // [s] +constexpr static float kStreamMaxOverlapTimestampDiff = 15.F; // [s] /** * A threshold which exceeded causes the main streaming audio buffer to be