diff --git a/apps/speech/screens/SpeechToTextScreen.tsx b/apps/speech/screens/SpeechToTextScreen.tsx index 06813dfcd..f0cf57e4d 100644 --- a/apps/speech/screens/SpeechToTextScreen.tsx +++ b/apps/speech/screens/SpeechToTextScreen.tsx @@ -50,7 +50,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { const [liveTranscribing, setLiveTranscribing] = useState(false); const scrollViewRef = useRef(null); - const recorder = new AudioRecorder(); + const recorder = useRef(new AudioRecorder()); useEffect(() => { AudioManager.setAudioSessionOptions({ @@ -115,7 +115,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { const sampleRate = 16000; - recorder.onAudioReady( + recorder.current.onAudioReady( { sampleRate, bufferLength: 0.1 * sampleRate, @@ -131,7 +131,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { if (!success) { console.warn('Cannot start audio session correctly'); } - const result = recorder.start(); + const result = recorder.current.start(); if (result.status === 'error') { console.warn('Recording problems: ', result.message); } @@ -177,7 +177,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => { const handleStopTranscribeFromMicrophone = () => { isRecordingRef.current = false; - recorder.stop(); + recorder.current.stop(); model.streamStop(); console.log('Live transcription stopped'); setLiveTranscribing(false); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index df9abbdef..8226db71b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -18,11 +18,11 @@ #include #include #include -#include -#include +#include +#include #include -using namespace rnexecutorch::models::speech_to_text::types; 
+using namespace rnexecutorch::models::speech_to_text; namespace rnexecutorch::jsi_conversion { @@ -507,7 +507,8 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) { jsi::Object wordObj(runtime); wordObj.setProperty( runtime, "word", - jsi::String::createFromUtf8(runtime, seg.words[i].content)); + jsi::String::createFromUtf8(runtime, seg.words[i].content + + seg.words[i].punctations)); wordObj.setProperty(runtime, "start", static_cast(seg.words[i].start)); wordObj.setProperty(runtime, "end", static_cast(seg.words[i].end)); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index c40fa2569..1db62c466 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -3,12 +3,12 @@ #include #include -#include "rnexecutorch/metaprogramming/ConstructorHelpers.h" #include #include #include #include #include +#include namespace rnexecutorch { namespace models { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp index eef6d562c..acd17cb0b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp @@ -1,46 +1,46 @@ #include #include "SpeechToText.h" +#include "common/types/TranscriptionResult.h" +#include "whisper/ASR.h" +#include "whisper/OnlineASR.h" #include #include -#include namespace rnexecutorch::models::speech_to_text { -using namespace ::executorch::extension; -using namespace asr; -using namespace types; -using namespace stream; - -SpeechToText::SpeechToText(const std::string &encoderSource, - const std::string &decoderSource, 
+SpeechToText::SpeechToText(const std::string &modelName, + const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker) - : callInvoker(std::move(callInvoker)), - encoder(std::make_unique(encoderSource, this->callInvoker)), - decoder(std::make_unique(decoderSource, this->callInvoker)), - tokenizer(std::make_unique(tokenizerSource, - this->callInvoker)), - asr(std::make_unique(this->encoder.get(), this->decoder.get(), - this->tokenizer.get())), - processor(std::make_unique(this->asr.get())), - isStreaming(false), readyToProcess(false) {} - -void SpeechToText::unload() noexcept { - this->encoder->unload(); - this->decoder->unload(); + : callInvoker_(std::move(callInvoker)), isStreaming_(false), + readyToProcess_(false) { + // Switch between the ASR implementations based on model name + if (modelName == "whisper") { + transcriber_ = std::make_unique(modelSource, tokenizerSource, + callInvoker_); + streamer_ = std::make_unique( + static_cast(transcriber_.get())); + } else { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::InvalidConfig, + "[SpeechToText]: Invalid model name: " + modelName); + } } +void SpeechToText::unload() noexcept { transcriber_->unload(); } + std::shared_ptr SpeechToText::encode(std::span waveform) const { - std::vector encoderOutput = this->asr->encode(waveform); + std::vector encoderOutput = transcriber_->encode(waveform); return std::make_shared(encoderOutput); } std::shared_ptr SpeechToText::decode(std::span tokens, std::span encoderOutput) const { - std::vector decoderOutput = this->asr->decode(tokens, encoderOutput); + std::vector decoderOutput = + transcriber_->decode(tokens, encoderOutput); return std::make_shared(decoderOutput); } @@ -48,7 +48,7 @@ TranscriptionResult SpeechToText::transcribe(std::span waveform, std::string languageOption, bool verbose) const { DecodingOptions options(languageOption, verbose); - std::vector segments = this->asr->transcribe(waveform, 
options); + std::vector segments = transcriber_->transcribe(waveform, options); std::string fullText; for (const auto &segment : segments) { @@ -70,8 +70,7 @@ TranscriptionResult SpeechToText::transcribe(std::span waveform, } size_t SpeechToText::getMemoryLowerBound() const noexcept { - return this->encoder->getMemoryLowerBound() + - this->decoder->getMemoryLowerBound(); + return transcriber_->getMemoryLowerBound(); } namespace { @@ -83,7 +82,7 @@ TranscriptionResult wordsToResult(const std::vector &words, std::string fullText; for (const auto &w : words) { - fullText += w.content; + fullText += w.content + w.punctations; } res.text = fullText; @@ -105,7 +104,7 @@ TranscriptionResult wordsToResult(const std::vector &words, void SpeechToText::stream(std::shared_ptr callback, std::string languageOption, bool verbose) { - if (this->isStreaming) { + if (isStreaming_) { throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress, "Streaming is already in progress!"); } @@ -115,7 +114,7 @@ void SpeechToText::stream(std::shared_ptr callback, const TranscriptionResult &nonCommitted, bool isDone) { // This moves execution to the JS thread - this->callInvoker->invokeAsync( + callInvoker_->invokeAsync( [callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) { jsi::Value jsiCommitted = rnexecutorch::jsi_conversion::getJsiValue(committed, rt); @@ -127,17 +126,16 @@ void SpeechToText::stream(std::shared_ptr callback, }); }; - this->isStreaming = true; + isStreaming_ = true; DecodingOptions options(languageOption, verbose); - while (this->isStreaming) { - if (!this->readyToProcess || - this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) { + while (isStreaming_) { + if (!readyToProcess_ || !streamer_->isReady()) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); continue; } - ProcessResult res = this->processor->processIter(options); + ProcessResult res = streamer_->process(options); TranscriptionResult cRes = 
wordsToResult(res.committed, languageOption, verbose); @@ -145,28 +143,28 @@ void SpeechToText::stream(std::shared_ptr callback, wordsToResult(res.nonCommitted, languageOption, verbose); nativeCallback(cRes, ncRes, false); - this->readyToProcess = false; + readyToProcess_ = false; } - std::vector finalWords = this->processor->finish(); + std::vector finalWords = streamer_->finish(); TranscriptionResult finalRes = wordsToResult(finalWords, languageOption, verbose); nativeCallback(finalRes, {}, true); - this->resetStreamState(); + resetStreamState(); } -void SpeechToText::streamStop() { this->isStreaming = false; } +void SpeechToText::streamStop() { isStreaming_ = false; } void SpeechToText::streamInsert(std::span waveform) { - this->processor->insertAudioChunk(waveform); - this->readyToProcess = true; + streamer_->insertAudioChunk(waveform); + readyToProcess_ = true; } void SpeechToText::resetStreamState() { - this->isStreaming = false; - this->readyToProcess = false; - this->processor = std::make_unique(this->asr.get()); + isStreaming_ = false; + readyToProcess_ = false; + streamer_->reset(); } } // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h index f9156040d..fbd53b6db 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h @@ -1,19 +1,21 @@ #pragma once -#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h" -#include #include #include #include +#include "common/schema/ASR.h" +#include "common/schema/OnlineASR.h" +#include "common/types/TranscriptionResult.h" + namespace rnexecutorch { namespace models::speech_to_text { class SpeechToText { public: - explicit SpeechToText(const std::string &encoderSource, - 
const std::string &decoderSource, + explicit SpeechToText(const std::string &modelName, + const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker); @@ -25,9 +27,9 @@ class SpeechToText { "Registered non-void function")]] std::shared_ptr decode(std::span tokens, std::span encoderOutput) const; [[nodiscard("Registered non-void function")]] - types::TranscriptionResult transcribe(std::span waveform, - std::string languageOption, - bool verbose) const; + TranscriptionResult transcribe(std::span waveform, + std::string languageOption, + bool verbose) const; [[nodiscard("Registered non-void function")]] std::vector transcribeStringOnly(std::span waveform, @@ -42,20 +44,18 @@ class SpeechToText { void streamInsert(std::span waveform); private: - std::shared_ptr callInvoker; - std::unique_ptr encoder; - std::unique_ptr decoder; - std::unique_ptr tokenizer; - std::unique_ptr asr; + // Helper functions + void resetStreamState(); - // Stream - std::unique_ptr processor; - bool isStreaming; - bool readyToProcess; + std::shared_ptr callInvoker_; - constexpr static int32_t kMinAudioSamples = 16000; // 1 second + // ASR-like module (both static transcription & streaming) + std::unique_ptr transcriber_ = nullptr; - void resetStreamState(); + // Online ASR-like module (streaming only) + std::unique_ptr streamer_ = nullptr; + bool isStreaming_ = false; + bool readyToProcess_ = true; }; } // namespace models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp deleted file mode 100644 index 2ed41ff22..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +++ /dev/null @@ -1,356 +0,0 @@ -#include -#include - -#include "ASR.h" -#include "executorch/extension/tensor/tensor_ptr.h" -#include "rnexecutorch/data_processing/Numerical.h" -#include 
"rnexecutorch/data_processing/gzip.h" -#include - -namespace rnexecutorch::models::speech_to_text::asr { - -using namespace types; - -ASR::ASR(const models::BaseModel *encoder, const models::BaseModel *decoder, - const TokenizerModule *tokenizer) - : encoder(encoder), decoder(decoder), tokenizer(tokenizer), - startOfTranscriptionToken( - this->tokenizer->tokenToId("<|startoftranscript|>")), - endOfTranscriptionToken(this->tokenizer->tokenToId("<|endoftext|>")), - timestampBeginToken(this->tokenizer->tokenToId("<|0.00|>")) {} - -std::vector -ASR::getInitialSequence(const DecodingOptions &options) const { - std::vector seq; - seq.push_back(this->startOfTranscriptionToken); - - if (options.language.has_value()) { - uint64_t langToken = - this->tokenizer->tokenToId("<|" + options.language.value() + "|>"); - uint64_t taskToken = this->tokenizer->tokenToId("<|transcribe|>"); - seq.push_back(langToken); - seq.push_back(taskToken); - } - - seq.push_back(this->timestampBeginToken); - - return seq; -} - -GenerationResult ASR::generate(std::span waveform, float temperature, - const DecodingOptions &options) const { - std::vector encoderOutput = this->encode(waveform); - - std::vector sequenceIds = this->getInitialSequence(options); - const size_t initialSequenceLenght = sequenceIds.size(); - std::vector scores; - - while (std::cmp_less_equal(sequenceIds.size(), ASR::kMaxDecodeLength)) { - std::vector logits = this->decode(sequenceIds, encoderOutput); - - // intentionally comparing float to float - // temperatures are predefined, so this is safe - if (temperature == 0.0f) { - numerical::softmax(logits); - } else { - numerical::softmaxWithTemperature(logits, temperature); - } - - const std::vector &probs = logits; - - uint64_t nextId; - float nextProb; - - // intentionally comparing float to float - // temperatures are predefined, so this is safe - if (temperature == 0.0f) { - auto maxIt = std::ranges::max_element(probs); - nextId = static_cast(std::distance(probs.begin(), 
maxIt)); - nextProb = *maxIt; - } else { - std::discrete_distribution<> dist(probs.begin(), probs.end()); - std::mt19937 gen((std::random_device{}())); - nextId = dist(gen); - nextProb = probs[nextId]; - } - - sequenceIds.push_back(nextId); - scores.push_back(nextProb); - - if (nextId == this->endOfTranscriptionToken) { - break; - } - } - - return {.tokens = std::vector( - sequenceIds.cbegin() + initialSequenceLenght, sequenceIds.cend()), - .scores = scores}; -} - -float ASR::getCompressionRatio(const std::string &text) const { - size_t compressedSize = gzip::deflateSize(text); - return static_cast(text.size()) / static_cast(compressedSize); -} - -std::vector -ASR::generateWithFallback(std::span waveform, - const DecodingOptions &options) const { - std::vector temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f}; - std::vector bestTokens; - float bestAvgLogProb = -std::numeric_limits::infinity(); - float bestCompressionRatio = 0.0f; - float bestTemperature = 0.0f; - - for (auto t : temperatures) { - auto [tokens, scores] = this->generate(waveform, t, options); - - const float cumLogProb = std::transform_reduce( - scores.begin(), scores.end(), 0.0f, std::plus<>(), - [](float s) { return std::log(std::max(s, 1e-9f)); }); - - const float avgLogProb = cumLogProb / static_cast(tokens.size() + 1); - const std::string text = this->tokenizer->decode(tokens, true); - const float compressionRatio = this->getCompressionRatio(text); - - if (avgLogProb >= -1.0f && compressionRatio < 2.4f) { - bestTokens = std::move(tokens); - bestAvgLogProb = avgLogProb; - bestCompressionRatio = compressionRatio; - bestTemperature = t; - break; - } - - if (t == temperatures.back() && bestTokens.empty()) { - bestTokens = std::move(tokens); - bestAvgLogProb = avgLogProb; - bestCompressionRatio = compressionRatio; - bestTemperature = t; - } - } - - return this->calculateWordLevelTimestamps(bestTokens, waveform, - bestAvgLogProb, bestTemperature, - bestCompressionRatio); -} - -std::vector 
-ASR::calculateWordLevelTimestamps(std::span generatedTokens, - const std::span waveform, - float avgLogProb, float temperature, - float compressionRatio) const { - const size_t generatedTokensSize = generatedTokens.size(); - if (generatedTokensSize < 2 || - generatedTokens[generatedTokensSize - 1] != - this->endOfTranscriptionToken || - generatedTokens[generatedTokensSize - 2] < this->timestampBeginToken) { - return {}; - } - std::vector segments; - std::vector tokens; - uint64_t prevTimestamp = this->timestampBeginToken; - - for (size_t i = 0; i < generatedTokensSize; i++) { - if (generatedTokens[i] < this->timestampBeginToken) { - tokens.push_back(generatedTokens[i]); - } - if (i > 0 && generatedTokens[i - 1] >= this->timestampBeginToken && - generatedTokens[i] >= this->timestampBeginToken) { - const uint64_t start = prevTimestamp; - const uint64_t end = generatedTokens[i - 1]; - auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end); - if (words.size()) { - Segment seg; - seg.words = std::move(words); - seg.tokens = {}; - seg.avgLogprob = avgLogProb; - seg.temperature = temperature; - seg.compressionRatio = compressionRatio; - - if (!seg.words.empty()) { - seg.start = seg.words.front().start; - seg.end = seg.words.back().end; - } else { - seg.start = 0.0; - seg.end = 0.0; - } - - segments.push_back(std::move(seg)); - } - tokens.clear(); - prevTimestamp = generatedTokens[i]; - } - } - - const uint64_t start = prevTimestamp; - const uint64_t end = generatedTokens[generatedTokensSize - 2]; - auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end); - - Segment seg; - seg.words = std::move(words); - seg.tokens = tokens; - seg.avgLogprob = avgLogProb; - seg.temperature = temperature; - seg.compressionRatio = compressionRatio; - - if (!seg.words.empty()) { - seg.start = seg.words.front().start; - seg.end = seg.words.back().end; - } - - segments.push_back(std::move(seg)); - - float scalingFactor = - static_cast(waveform.size()) / - 
(ASR::kSamplingRate * (end - this->timestampBeginToken) * - ASR::kTimePrecision); - if (scalingFactor < 1.0f) { - for (auto &seg : segments) { - for (auto &w : seg.words) { - w.start *= scalingFactor; - w.end *= scalingFactor; - } - } - } - - return segments; -} - -std::vector -ASR::estimateWordLevelTimestampsLinear(std::span tokens, - uint64_t start, uint64_t end) const { - const std::vector tokensVec(tokens.begin(), tokens.end()); - const std::string segmentText = this->tokenizer->decode(tokensVec, true); - std::istringstream iss(segmentText); - std::vector wordsStr; - std::string word; - while (iss >> word) { - wordsStr.emplace_back(" "); - wordsStr.back().append(word); - } - - size_t numChars = 0; - for (const auto &w : wordsStr) { - numChars += w.size(); - } - const float duration = (end - start) * ASR::kTimePrecision; - const float timePerChar = duration / std::max(1, numChars); - const float startOffset = (start - timestampBeginToken) * ASR::kTimePrecision; - - std::vector wordObjs; - wordObjs.reserve(wordsStr.size()); - int32_t prevCharCount = 0; - for (auto &w : wordsStr) { - const auto wSize = static_cast(w.size()); - const float wStart = startOffset + prevCharCount * timePerChar; - const float wEnd = wStart + timePerChar * wSize; - prevCharCount += wSize; - wordObjs.emplace_back(std::move(w), wStart, wEnd); - } - - return wordObjs; -} - -std::vector ASR::transcribe(std::span waveform, - const DecodingOptions &options) const { - int32_t seek = 0; - std::vector results; - - while (std::cmp_less(seek * ASR::kSamplingRate, waveform.size())) { - int32_t start = seek * ASR::kSamplingRate; - const auto end = std::min( - static_cast((seek + ASR::kChunkSize) * ASR::kSamplingRate), - static_cast(waveform.size())); - auto chunk = waveform.subspan(start, end - start); - - if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) { - break; - } - - std::vector segments = this->generateWithFallback(chunk, options); - - if (segments.empty()) { - seek += ASR::kChunkSize; 
- continue; - } - - for (auto &seg : segments) { - for (auto &w : seg.words) { - w.start += seek; - w.end += seek; - } - - seg.start += seek; - seg.end += seek; - } - - while (!segments.empty() && segments.back().words.empty()) { - segments.pop_back(); - } - - if (!segments.empty() && !segments.back().words.empty()) { - seek = static_cast(segments.back().words.back().end); - } - results.insert(results.end(), std::make_move_iterator(segments.begin()), - std::make_move_iterator(segments.end())); - } - - return results; -} - -std::vector ASR::encode(std::span waveform) const { - auto inputShape = {static_cast(waveform.size())}; - - const auto modelInputTensor = executorch::extension::make_tensor_ptr( - std::move(inputShape), waveform.data(), - executorch::runtime::etensor::ScalarType::Float); - const auto encoderResult = this->encoder->forward(modelInputTensor); - - if (!encoderResult.ok()) { - throw RnExecutorchError(encoderResult.error(), - "The model's forward function did not succeed. " - "Ensure the model input is correct."); - } - - const auto decoderOutputTensor = encoderResult.get().at(0).toTensor(); - const auto outputNumel = decoderOutputTensor.numel(); - - const float *const dataPtr = decoderOutputTensor.const_data_ptr(); - return {dataPtr, dataPtr + outputNumel}; -} - -std::vector ASR::decode(std::span tokens, - std::span encoderOutput) const { - std::vector tokenShape = {1, static_cast(tokens.size())}; - auto tokensLong = std::vector(tokens.begin(), tokens.end()); - - auto tokenTensor = executorch::extension::make_tensor_ptr( - tokenShape, tokensLong.data(), ScalarType::Long); - - const auto encoderOutputSize = static_cast(encoderOutput.size()); - std::vector encShape = {1, ASR::kNumFrames, - encoderOutputSize / ASR::kNumFrames}; - auto encoderTensor = executorch::extension::make_tensor_ptr( - std::move(encShape), encoderOutput.data(), ScalarType::Float); - - const auto decoderResult = - this->decoder->forward({tokenTensor, encoderTensor}); - - if 
(!decoderResult.ok()) { - throw RnExecutorchError(decoderResult.error(), - "The model's forward function did not succeed. " - "Ensure the model input is correct."); - } - - const auto logitsTensor = decoderResult.get().at(0).toTensor(); - const int32_t outputNumel = static_cast(logitsTensor.numel()); - - const size_t innerDim = logitsTensor.size(1); - const size_t dictSize = logitsTensor.size(2); - - const float *const dataPtr = - logitsTensor.const_data_ptr() + (innerDim - 1) * dictSize; - - return {dataPtr, dataPtr + outputNumel / innerDim}; -} - -} // namespace rnexecutorch::models::speech_to_text::asr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h deleted file mode 100644 index 16a2f45e6..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include "rnexecutorch/TokenizerModule.h" -#include "rnexecutorch/models/BaseModel.h" -#include "rnexecutorch/models/speech_to_text/types/DecodingOptions.h" -#include "rnexecutorch/models/speech_to_text/types/GenerationResult.h" -#include "rnexecutorch/models/speech_to_text/types/Segment.h" - -namespace rnexecutorch::models::speech_to_text::asr { - -class ASR { -public: - explicit ASR(const models::BaseModel *encoder, - const models::BaseModel *decoder, - const TokenizerModule *tokenizer); - std::vector - transcribe(std::span waveform, - const types::DecodingOptions &options) const; - std::vector encode(std::span waveform) const; - std::vector decode(std::span tokens, - std::span encoderOutput) const; - -private: - const models::BaseModel *encoder; - const models::BaseModel *decoder; - const TokenizerModule *tokenizer; - - uint64_t startOfTranscriptionToken; - uint64_t endOfTranscriptionToken; - uint64_t timestampBeginToken; - - // Time precision used by Whisper timestamps: each token spans 0.02 
seconds - constexpr static float kTimePrecision = 0.02f; - // The maximum number of tokens the decoder can generate per chunk - constexpr static int32_t kMaxDecodeLength = 128; - // Maximum duration of each audio chunk to process (in seconds) - // It is intentionally set to 29 since otherwise only the last chunk would be - // correctly transcribe due to the model's positional encoding limit - constexpr static int32_t kChunkSize = 29; - // Sampling rate expected by Whisper and the model's audio pipeline (16 kHz) - constexpr static int32_t kSamplingRate = 16000; - // Minimum allowed chunk length before processing (in audio samples) - constexpr static int32_t kMinChunkSamples = 1 * 16000; - // Number of mel frames output by the encoder (derived from input spectrogram) - constexpr static int32_t kNumFrames = 1500; - - std::vector - getInitialSequence(const types::DecodingOptions &options) const; - types::GenerationResult generate(std::span waveform, float temperature, - const types::DecodingOptions &options) const; - std::vector - generateWithFallback(std::span waveform, - const types::DecodingOptions &options) const; - std::vector - calculateWordLevelTimestamps(std::span generatedTokens, - const std::span waveform, - float avgLogProb, float temperature, - float compressionRatio) const; - std::vector - estimateWordLevelTimestampsLinear(std::span tokens, - uint64_t start, uint64_t end) const; - float getCompressionRatio(const std::string &text) const; -}; - -} // namespace rnexecutorch::models::speech_to_text::asr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h new file mode 100644 index 000000000..0390a7dfc --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include + +#include 
"../types/DecodingOptions.h" +#include "../types/Segment.h" +#include + +namespace rnexecutorch::models::speech_to_text::schema { + +/** + * @brief Abstract base class for Automatic Speech Recognition (ASR) models. + * + * Provides a unified interface for speech-to-text models like Whisper, allowing + * for transcription of raw audio waveforms into text segments, as well as + * access to lower-level model components like encoding and decoding. + */ +class ASR { +public: + virtual ~ASR() = default; + + std::vector virtual transcribe( + std::span waveform, const DecodingOptions &options) const = 0; + + virtual std::vector encode(std::span waveform) const = 0; + + virtual std::vector decode(std::span tokens, + std::span encoderOutput, + uint64_t startPos = 0) const = 0; + + // Standard ExecuTorch model methods for compatibility with the rest of the + // API. + virtual void unload() noexcept = 0; + virtual std::size_t getMemoryLowerBound() const noexcept = 0; +}; + +} // namespace rnexecutorch::models::speech_to_text::schema \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h new file mode 100644 index 000000000..357309391 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/OnlineASR.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include "../types/DecodingOptions.h" +#include "../types/ProcessResult.h" +#include "../types/Word.h" + +namespace rnexecutorch::models::speech_to_text::schema { + +/** + * @brief Abstract base class for Online (streaming) Automatic Speech + * Recognition. + * + * Provides an interface for processing audio in chunks, allowing for real-time + * transcription results. 
Implementations of this interface typically maintain + * an internal audio buffer and a hypothesis buffer for incremental decoding. + * + * Requires 5 main methods to be implemented: + * - insertAudioChunk(): for expanding the collected audio (typically adding to + * a buffer) + * - isReady(): returns a boolean flag indicating whether the module is ready to + * process the next iteration + * - process(): for processing the next chunk of audio (next iteration) + * - finish(): called to finish the live transcription mode + * - reset(): resets the streaming state + */ +class OnlineASR { +public: + virtual ~OnlineASR() = default; + + virtual void insertAudioChunk(std::span audio) = 0; + + virtual bool isReady() const = 0; + + virtual ProcessResult process(const DecodingOptions &options) = 0; + + virtual std::vector finish() = 0; + + virtual void reset() = 0; +}; + +} // namespace rnexecutorch::models::speech_to_text::schema diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h similarity index 73% rename from packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h rename to packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h index 99774cf52..de0830826 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/DecodingOptions.h @@ -3,7 +3,7 @@ #include #include -namespace rnexecutorch::models::speech_to_text::types { +namespace rnexecutorch::models::speech_to_text { struct DecodingOptions { explicit DecodingOptions(const std::string &language, bool verbose = false) @@ -14,4 +14,4 @@ struct DecodingOptions { bool verbose; }; -} // namespace
rnexecutorch::models::speech_to_text::types +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/GenerationResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/GenerationResult.h new file mode 100644 index 000000000..a68566f0e --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/GenerationResult.h @@ -0,0 +1,14 @@ +#pragma once + +#include "Token.h" +#include +#include + +namespace rnexecutorch::models::speech_to_text { + +struct GenerationResult { + std::vector tokens; + std::vector scores; +}; + +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/ProcessResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/ProcessResult.h new file mode 100644 index 000000000..493b5b05d --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/ProcessResult.h @@ -0,0 +1,14 @@ +#pragma once + +#include "Word.h" +#include +#include + +namespace rnexecutorch::models::speech_to_text { + +struct ProcessResult { + std::vector committed; + std::vector nonCommitted; +}; + +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Segment.h similarity index 51% rename from packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h rename to packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Segment.h index b673cbe6e..878589dce 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Segment.h +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Segment.h @@ -1,13 +1,15 @@ #pragma once +#include "Token.h" #include "Word.h" +#include #include -namespace rnexecutorch::models::speech_to_text::types { +namespace rnexecutorch::models::speech_to_text { struct Segment { std::vector words; - std::vector tokens; // Raw token IDs + std::vector tokens; // Raw token IDs float start; float end; float avgLogprob; @@ -15,4 +17,4 @@ struct Segment { float compressionRatio; }; -} // namespace rnexecutorch::models::speech_to_text::types +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h new file mode 100644 index 000000000..17c4a4091 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace rnexecutorch::models::speech_to_text { + +using Token = uint64_t; + +} // namespace rnexecutorch::models::speech_to_text \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/TranscriptionResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h similarity index 69% rename from packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/TranscriptionResult.h rename to packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h index 5fe3868da..994cdb15e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/TranscriptionResult.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h @@ -3,7 +3,7 @@ #include #include -namespace 
rnexecutorch::models::speech_to_text::types { +namespace rnexecutorch::models::speech_to_text { struct TranscriptionResult { std::string text; @@ -13,4 +13,4 @@ struct TranscriptionResult { std::vector segments; // Populated only if verbose=true }; -} // namespace rnexecutorch::models::speech_to_text::types \ No newline at end of file +} // namespace rnexecutorch::models::speech_to_text \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h new file mode 100644 index 000000000..9de04a9c5 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace rnexecutorch::models::speech_to_text { + +struct Word { + std::string content; + float start; + float end; + + std::string punctations = + ""; // Trailing punctations which appear after the main content +}; + +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp deleted file mode 100644 index 31806c126..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp +++ /dev/null @@ -1,82 +0,0 @@ -#include "HypothesisBuffer.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -using namespace types; - -void HypothesisBuffer::insert(std::span newWords, float offset) { - this->fresh.clear(); - for (const auto &word : newWords) { - const float newStart = word.start + offset; - if (newStart > lastCommittedTime - 0.5f) { - this->fresh.emplace_back(word.content, newStart, word.end + offset); - } - } - - if (!this->fresh.empty() && !this->committedInBuffer.empty()) 
{ - const float a = this->fresh.front().start; - if (std::fabs(a - lastCommittedTime) < 1.0f) { - const size_t cn = this->committedInBuffer.size(); - const size_t nn = this->fresh.size(); - const std::size_t maxCheck = std::min({cn, nn, 5}); - for (size_t i = 1; i <= maxCheck; i++) { - std::string c; - for (auto it = this->committedInBuffer.cend() - i; - it != this->committedInBuffer.cend(); ++it) { - if (!c.empty()) { - c += ' '; - } - c += it->content; - } - - std::string tail; - auto it = this->fresh.cbegin(); - for (size_t k = 0; k < i; k++, it++) { - if (!tail.empty()) { - tail += ' '; - } - tail += it->content; - } - - if (c == tail) { - this->fresh.erase(this->fresh.begin(), this->fresh.begin() + i); - break; - } - } - } - } -} - -std::deque HypothesisBuffer::flush() { - std::deque commit; - - while (!this->fresh.empty() && !this->buffer.empty()) { - if (this->fresh.front().content != this->buffer.front().content) { - break; - } - commit.push_back(this->fresh.front()); - this->buffer.pop_front(); - this->fresh.pop_front(); - } - - if (!commit.empty()) { - lastCommittedTime = commit.back().end; - } - - this->buffer = std::move(this->fresh); - this->fresh.clear(); - this->committedInBuffer.insert(this->committedInBuffer.end(), commit.begin(), - commit.end()); - return commit; -} - -void HypothesisBuffer::popCommitted(float time) { - while (!this->committedInBuffer.empty() && - this->committedInBuffer.front().end <= time) { - this->committedInBuffer.pop_front(); - } -} - -std::deque HypothesisBuffer::complete() const { return this->buffer; } - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h deleted file mode 100644 index cfa11fd66..000000000 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include -#include - -#include "rnexecutorch/models/speech_to_text/types/Word.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -class HypothesisBuffer { -public: - void insert(std::span newWords, float offset); - std::deque flush(); - void popCommitted(float time); - std::deque complete() const; - -private: - float lastCommittedTime = 0.0f; - - std::deque committedInBuffer; - std::deque buffer; - std::deque fresh; -}; - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp deleted file mode 100644 index 3137d274b..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include - -#include "OnlineASRProcessor.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -using namespace asr; -using namespace types; - -OnlineASRProcessor::OnlineASRProcessor(const ASR *asr) : asr(asr) {} - -void OnlineASRProcessor::insertAudioChunk(std::span audio) { - audioBuffer.insert(audioBuffer.end(), audio.begin(), audio.end()); -} - -ProcessResult OnlineASRProcessor::processIter(const DecodingOptions &options) { - std::vector res = asr->transcribe(audioBuffer, options); - - std::vector tsw; - for (const auto &segment : res) { - for (const auto &word : segment.words) { - tsw.push_back(word); - } - } - - this->hypothesisBuffer.insert(tsw, this->bufferTimeOffset); - std::deque flushed = this->hypothesisBuffer.flush(); - this->committed.insert(this->committed.end(), flushed.begin(), flushed.end()); - - constexpr int32_t chunkThresholdSec = 15; - if (static_cast(audioBuffer.size()) / - 
OnlineASRProcessor::kSamplingRate > - chunkThresholdSec) { - chunkCompletedSegment(res); - } - - auto move_to_vector = [](auto& container) { - return std::vector(std::make_move_iterator(container.begin()), - std::make_move_iterator(container.end())); - }; - - std::deque nonCommittedWords = this->hypothesisBuffer.complete(); - - return { move_to_vector(flushed), move_to_vector(nonCommittedWords) }; -} - -void OnlineASRProcessor::chunkCompletedSegment(std::span res) { - if (this->committed.empty()) - return; - - std::vector ends(res.size()); - std::ranges::transform(res, ends.begin(), [](const Segment &seg) { - return seg.words.back().end; - }); - - const float t = this->committed.back().end; - - if (ends.size() > 1) { - float e = ends[ends.size() - 2] + this->bufferTimeOffset; - while (ends.size() > 2 && e > t) { - ends.pop_back(); - e = ends[ends.size() - 2] + this->bufferTimeOffset; - } - if (e <= t) { - chunkAt(e); - } - } -} - -void OnlineASRProcessor::chunkAt(float time) { - this->hypothesisBuffer.popCommitted(time); - - const float cutSeconds = time - this->bufferTimeOffset; - auto startIndex = - static_cast(cutSeconds * OnlineASRProcessor::kSamplingRate); - - if (startIndex < audioBuffer.size()) { - audioBuffer.erase(audioBuffer.begin(), audioBuffer.begin() + startIndex); - } else { - audioBuffer.clear(); - } - - this->bufferTimeOffset = time; -} - -std::vector OnlineASRProcessor::finish() { - std::deque bufferDeq = this->hypothesisBuffer.complete(); - std::vector buffer(std::make_move_iterator(bufferDeq.begin()), - std::make_move_iterator(bufferDeq.end())); - - this->bufferTimeOffset += static_cast(audioBuffer.size()) / - OnlineASRProcessor::kSamplingRate; - return buffer; -} - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h 
deleted file mode 100644 index 98944bdbe..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include "rnexecutorch/models/speech_to_text/asr/ASR.h" -#include "rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h" -#include "rnexecutorch/models/speech_to_text/types/ProcessResult.h" -#include "rnexecutorch/models/speech_to_text/types/Word.h" - -namespace rnexecutorch::models::speech_to_text::stream { - -class OnlineASRProcessor { -public: - explicit OnlineASRProcessor(const asr::ASR *asr); - - void insertAudioChunk(std::span audio); - types::ProcessResult processIter(const types::DecodingOptions &options); - std::vector finish(); - - std::vector audioBuffer; - -private: - const asr::ASR *asr; - constexpr static int32_t kSamplingRate = 16000; - - HypothesisBuffer hypothesisBuffer; - float bufferTimeOffset = 0.0f; - std::vector committed; - - void chunkCompletedSegment(std::span res); - void chunkAt(float time); -}; - -} // namespace rnexecutorch::models::speech_to_text::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h deleted file mode 100644 index 83bc80dd7..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -namespace rnexecutorch::models::speech_to_text::types { - -struct GenerationResult { - std::vector tokens; - std::vector scores; -}; - -} // namespace rnexecutorch::models::speech_to_text::types diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h deleted file mode 100644 index 
681495e2a..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -namespace rnexecutorch::models::speech_to_text::types { - -struct ProcessResult { - std::vector committed; - std::vector nonCommitted; -}; - -} // namespace rnexecutorch::models::speech_to_text::types diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h deleted file mode 100644 index 98c72f273..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/types/Word.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace rnexecutorch::models::speech_to_text::types { - -struct Word { - std::string content; - float start; - float end; -}; - -} // namespace rnexecutorch::models::speech_to_text::types diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp new file mode 100644 index 000000000..09562088e --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp @@ -0,0 +1,453 @@ +#include +#include +#include +#include + +#include "ASR.h" +#include "Constants.h" +#include "Params.h" +#include +#include +#include +#include + +#include + +namespace rnexecutorch::models::speech_to_text::whisper { + +using executorch::runtime::etensor::ScalarType; + +ASR::ASR(const std::string &modelSource, const std::string &tokenizerSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, std::move(callInvoker)), schema::ASR(), + tokenizer_(std::make_unique(tokenizerSource, + this->callInvoker)), + startOfTranscriptionToken_( + tokenizer_->tokenToId(constants::tokens::kStartOfTranscript)), + endOfTranscriptionToken_( + 
tokenizer_->tokenToId(constants::tokens::kEndOfTranscript)), + timestampBeginToken_( + tokenizer_->tokenToId(constants::tokens::kBeginTimestamp)) {} + +/** + * Whisper inference - full transcription + */ +std::vector ASR::transcribe(std::span waveform, + const DecodingOptions &options) const { + // Use floats to prevent downcasting and timestamp mismatches + float seek = 0.f; + std::vector results; + + const float waveformSize = static_cast(waveform.size()); + const float waveformSkipBoundary = + static_cast((constants::kChunkSize - params::kChunkBreakBuffer) * + constants::kSamplingRate); + + // We loop through the input audio waveform and process it in 30s chunks. + // This is determined by Whisper models strict 30s audio length requirement. + while (seek * constants::kSamplingRate < waveformSize) { + // Calculate chunk bounds and extract the chunk. + float start = seek * constants::kSamplingRate; + const auto end = + std::min(static_cast((seek + constants::kChunkSize) * + constants::kSamplingRate), + waveformSize); + auto chunk = waveform.subspan(start, end - start); + + if (std::cmp_less(chunk.size(), constants::kMinChunkSamples)) { + break; + } + + // Enter the processing logic. + std::vector segments = this->generate(chunk, options); + + if (segments.empty()) { + seek += constants::kChunkSize; + continue; + } + + for (auto &seg : segments) { + for (auto &w : seg.words) { + w.start += seek; + w.end += seek; + } + + seg.start += seek; + seg.end += seek; + } + + while (!segments.empty() && segments.back().words.empty()) { + segments.pop_back(); + } + + if (!segments.empty() && !segments.back().words.empty()) { + // This prevents additional segments to appear, unless the audio length is + // very close to the max chunk size, that is there could be some words + // spoken near the breakpoint. + seek = waveformSize < waveformSkipBoundary + ? 
seek + constants::kChunkSize + : segments.back().words.back().end; + } + results.insert(results.end(), std::make_move_iterator(segments.begin()), + std::make_move_iterator(segments.end())); + } + + return results; +} + +/** + * Whisper inference - encoding phase + * + * The input is a standard audio waveform, altough it is implicitly converted + * to a log mel format inside the encoder call. + */ +std::vector ASR::encode(std::span waveform) const { + auto inputShape = {static_cast(waveform.size())}; + + const auto modelInputTensor = executorch::extension::make_tensor_ptr( + std::move(inputShape), waveform.data(), ScalarType::Float); + + const auto encoderResult = this->execute("encode", {modelInputTensor}); + + if (!encoderResult.ok()) { + throw RnExecutorchError(encoderResult.error(), + "[Whisper] The 'encode' method did not succeed. " + "Ensure the model input is correct."); + } + + const auto encoderOutputTensor = encoderResult.get().at(0).toTensor(); + const auto outputNumel = encoderOutputTensor.numel(); + + const float *const dataPtr = encoderOutputTensor.const_data_ptr(); + return {dataPtr, dataPtr + outputNumel}; +} + +/** + * Whisper inference - decoding phase + * + * An autoregressive decoder, called with increasing amount of input tokens. 
+ */ +std::vector ASR::decode(std::span tokens, + std::span encoderOutput, + uint64_t startPos) const { + std::vector tokenShape = {1, static_cast(tokens.size())}; + std::vector positionShape = {static_cast(tokens.size())}; + + auto tokenTensor = executorch::extension::make_tensor_ptr( + tokenShape, tokens.data(), ScalarType::Long); + + // Populate cache position vector + std::vector cachePositions(tokens.size()); + std::iota(cachePositions.begin(), cachePositions.end(), startPos); + auto positionTensor = executorch::extension::make_tensor_ptr( + positionShape, cachePositions.data(), ScalarType::Long); + + const auto encoderOutputSize = static_cast(encoderOutput.size()); + std::vector encShape = {1, constants::kNumFrames, + encoderOutputSize / constants::kNumFrames}; + auto encoderTensor = executorch::extension::make_tensor_ptr( + std::move(encShape), encoderOutput.data(), ScalarType::Float); + + const auto decoderResult = + this->execute("decode", {tokenTensor, positionTensor, encoderTensor}); + + if (!decoderResult.ok()) { + throw RnExecutorchError(decoderResult.error(), + "[Whisper] The 'decode' method did not succeed. 
" + "Ensure the model inputs are correct."); + } + + const auto logitsTensor = decoderResult.get().at(0).toTensor(); + const int32_t outputNumel = static_cast(logitsTensor.numel()); + + const size_t innerDim = logitsTensor.size(1); + const size_t dictSize = logitsTensor.size(2); + + const float *const dataPtr = + logitsTensor.const_data_ptr() + (innerDim - 1) * dictSize; + + return {dataPtr, dataPtr + outputNumel / innerDim}; +} + +void ASR::unload() noexcept { BaseModel::unload(); } + +std::size_t ASR::getMemoryLowerBound() const noexcept { + return BaseModel::getMemoryLowerBound(); +} + +/** + * Helper functions - creating initial token IDs sequence + */ +std::vector +ASR::createInitialSequence(const DecodingOptions &options) const { + std::vector seq; + seq.push_back(startOfTranscriptionToken_); + + if (options.language.has_value()) { + uint64_t langToken = + tokenizer_->tokenToId("<|" + options.language.value() + "|>"); + uint64_t taskToken = tokenizer_->tokenToId("<|transcribe|>"); + seq.push_back(langToken); + seq.push_back(taskToken); + } + + seq.push_back(timestampBeginToken_); + + return seq; +} + +/** + * Helper functions - generation wrapper, with fallback + */ +std::vector ASR::generate(std::span waveform, + const DecodingOptions &options) const { + // A fixed pool of available temperatures + constexpr std::array temperatures = {0.0f, 0.2f, 0.4f, + 0.6f, 0.8f, 1.0f}; + + // Calculate audio features just once to save time. 
+ std::vector encoderOutput = this->encode(waveform); + + std::vector bestTokens; + float bestAvgLogProb = -std::numeric_limits::infinity(); + float bestCompressionRatio = 0.0f; + float bestTemperature = 0.0f; + + for (auto t : temperatures) { + auto [tokens, scores] = + this->generate(waveform, options, t, {encoderOutput}); + + const float cumLogProb = std::transform_reduce( + scores.begin(), scores.end(), 0.0f, std::plus<>(), + [](float s) { return std::log(std::max(s, 1e-9f)); }); + + const float avgLogProb = cumLogProb / static_cast(tokens.size() + 1); + const std::string text = tokenizer_->decode(tokens, true); + const float compressionRatio = this->calculateCompressionRatio(text); + + if (avgLogProb >= -1.0f && compressionRatio < 2.4f) { + bestTokens = std::move(tokens); + bestAvgLogProb = avgLogProb; + bestCompressionRatio = compressionRatio; + bestTemperature = t; + break; + } + + if (t == temperatures.back() && bestTokens.empty()) { + bestTokens = std::move(tokens); + bestAvgLogProb = avgLogProb; + bestCompressionRatio = compressionRatio; + bestTemperature = t; + } + } + + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[ASR] Raw transcription results (tokens): ", bestTokens); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "[ASR] Raw transcription results (text): ", + tokenizer_->decode(bestTokens, true)); + + return this->calculateWordLevelTimestamps(bestTokens, waveform, + bestAvgLogProb, bestTemperature, + bestCompressionRatio); +} + +/** + * Helper functions - generation wrapper, single-temperature inference + */ +GenerationResult +ASR::generate(std::span waveform, const DecodingOptions &options, + float temperature, + std::optional> encoderOutput) const { + std::vector encoderOutputData = !encoderOutput.has_value() + ? this->encode(waveform) + : std::vector(); + std::span encodings = encoderOutput.has_value() + ? 
encoderOutput.value() + : std::span(encoderOutputData); + + std::vector sequenceIds = this->createInitialSequence(options); + std::vector cachedTokens = sequenceIds; + const size_t initialSequenceLenght = sequenceIds.size(); + std::vector scores; + + uint64_t startPos = 0; + while (std::cmp_less_equal(startPos + sequenceIds.size(), + constants::kMaxDecodeLength)) { + std::vector logits = this->decode(sequenceIds, encodings, startPos); + + // intentionally comparing float to float + // temperatures are predefined, so this is safe + if (temperature == 0.0f) { + numerical::softmax(logits); + } else { + numerical::softmaxWithTemperature(logits, temperature); + } + + const std::vector &probs = logits; + + uint64_t nextId; + float nextProb; + + // intentionally comparing float to float + // temperatures are predefined, so this is safe + if (temperature == 0.0f) { + auto maxIt = std::ranges::max_element(probs); + nextId = static_cast(std::distance(probs.begin(), maxIt)); + nextProb = *maxIt; + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + std::mt19937 gen((std::random_device{}())); + nextId = dist(gen); + nextProb = probs[nextId]; + } + + // Move the startPos pointer by the amount of tokens we processed + startPos += sequenceIds.size(); + sequenceIds = {nextId}; + cachedTokens.push_back(nextId); + scores.push_back(nextProb); + + if (nextId == endOfTranscriptionToken_) { + break; + } + } + + return {.tokens = std::vector(cachedTokens.cbegin() + + initialSequenceLenght, + cachedTokens.cend()), + .scores = scores}; +} + +std::vector ASR::calculateWordLevelTimestamps( + std::span generatedTokens, const std::span waveform, + float avgLogProb, float temperature, float compressionRatio) const { + const size_t generatedTokensSize = generatedTokens.size(); + if (generatedTokensSize < 2 || + generatedTokens[generatedTokensSize - 1] != endOfTranscriptionToken_ || + generatedTokens[generatedTokensSize - 2] < timestampBeginToken_) { + return {}; + } + 
std::vector segments; + std::vector tokens; + uint64_t prevTimestamp = timestampBeginToken_; + + for (size_t i = 0; i < generatedTokensSize; i++) { + if (generatedTokens[i] < timestampBeginToken_) { + tokens.push_back(generatedTokens[i]); + } + if (i > 0 && generatedTokens[i - 1] >= timestampBeginToken_ && + generatedTokens[i] >= timestampBeginToken_) { + const uint64_t start = prevTimestamp; + const uint64_t end = generatedTokens[i - 1]; + auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end); + if (words.size()) { + Segment seg; + seg.words = std::move(words); + // seg.tokens = {}; // WTF ? + seg.tokens = tokens; + seg.avgLogprob = avgLogProb; + seg.temperature = temperature; + seg.compressionRatio = compressionRatio; + + if (!seg.words.empty()) { + seg.start = seg.words.front().start; + seg.end = seg.words.back().end; + } else { + seg.start = 0.0; + seg.end = 0.0; + } + + segments.push_back(std::move(seg)); + } + tokens.clear(); + prevTimestamp = generatedTokens[i]; + } + } + + const uint64_t start = prevTimestamp; + const uint64_t end = generatedTokens[generatedTokensSize - 2]; + auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end); + + Segment seg; + seg.words = std::move(words); + seg.tokens = tokens; + seg.avgLogprob = avgLogProb; + seg.temperature = temperature; + seg.compressionRatio = compressionRatio; + + if (!seg.words.empty()) { + seg.start = seg.words.front().start; + seg.end = seg.words.back().end; + } + + segments.push_back(std::move(seg)); + + float scalingFactor = + static_cast(waveform.size()) / + (constants::kSamplingRate * (end - timestampBeginToken_) * + constants::kTimePrecision); + if (scalingFactor < 1.0f) { + for (auto &seg : segments) { + for (auto &w : seg.words) { + w.start *= scalingFactor; + w.end *= scalingFactor; + } + } + } + + return segments; +} + +std::vector +ASR::estimateWordLevelTimestampsLinear(std::span tokens, + uint64_t start, uint64_t end) const { + const std::vector 
tokensVec(tokens.begin(), tokens.end()); + const std::string segmentText = tokenizer_->decode(tokensVec, true); + + std::istringstream iss(segmentText); + std::vector wordsStr; + std::string word; + while (iss >> word) { + wordsStr.emplace_back(" "); + wordsStr.back().append(word); + } + + size_t numChars = 0; + for (const auto &w : wordsStr) { + numChars += w.size(); + } + const float duration = (end - start) * constants::kTimePrecision; + const float timePerChar = duration / std::max(1, numChars); + const float startOffset = + (start - timestampBeginToken_) * constants::kTimePrecision; + + std::vector wordObjs; + wordObjs.reserve(wordsStr.size()); + int32_t prevCharCount = 0; + for (auto &w : wordsStr) { + const auto wSize = static_cast(w.size()); + const float wStart = startOffset + prevCharCount * timePerChar; + const float wEnd = wStart + timePerChar * wSize; + prevCharCount += wSize; + + // We store punctations separately to other characters. + std::string puncts = ""; + while (!w.empty() && constants::kPunctations.contains(w.back())) { + puncts += w.back(); + w.pop_back(); + } + std::reverse(puncts.begin(), puncts.end()); + + wordObjs.emplace_back(std::move(w), wStart, wEnd, std::move(puncts)); + } + + return wordObjs; +} + +float ASR::calculateCompressionRatio(const std::string &text) const { + size_t compressedSize = gzip::deflateSize(text); + return static_cast(text.size()) / static_cast(compressedSize); +} + +} // namespace rnexecutorch::models::speech_to_text::whisper diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h new file mode 100644 index 000000000..54d2d699a --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h @@ -0,0 +1,167 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "../common/schema/ASR.h" +#include 
"../common/types/GenerationResult.h" +#include "../common/types/Token.h" +#include +#include +#include + +namespace rnexecutorch::models::speech_to_text::whisper { + +using executorch::aten::Tensor; + +/** + * Automatic Speech Recognition (ASR) class for Whisper-based models. + * This class handles both encoding and decoding steps for Whisper family + * models, loading a single model with named entry points for "encode" and + * "decode". + */ +class ASR : public models::BaseModel, public schema::ASR { +public: + ASR(const std::string &modelSource, const std::string &tokenizerSource, + std::shared_ptr callInvoker); + + /** + * @brief The main Whisper transcription API point. + * Wrapps the entire transciption process into a single method. + * + * @param waveform Input audio waveform sampled at 16kHz, similarly to + * encode's input. + * @param options Control variables for decoding process. + */ + std::vector virtual transcribe( + std::span waveform, const DecodingOptions &options) const override; + + /** + * Encodes the input audio waveform into mel spectrogram embeddings. + * + * @param waveform Input audio waveform sampled at 16kHz. + * @return Flat vector containing the encoder's output features. + * The output tensor shape: [1, 1500, 384] for Whisper + * models. + */ + std::vector encode(std::span waveform) const override; + + /** + * Decodes a sequence of tokens into logits given the encoded audio features. + * + * @param tokens A span of token IDs from previous iteration + * (see autoregressive nature of Whisper decoding). + * @param encoderOutput A span of floats containing the precomputed encoder + * embeddings. + * @param startPos The starting position in the sequence (used for KV + * caching). + * @return A vector of floats representing the output logits for + * the next token. 
+ */ + std::vector decode(std::span tokens, + std::span encoderOutput, + uint64_t startPos = 0) const override; + + // Standard ExecuTorch model methods for compatibility with the rest of the + // API. + void unload() noexcept override; + std::size_t getMemoryLowerBound() const noexcept override; + +private: + /** + * A helper factory for creating initial token sequences. + * + * The initial sequence consists of special tokens, such as + * language mark token or timestamp token. It's always a part + * of decoder's input. + * + * @param options Determine a specific properties of the initial sequence, + * such as whether + * to add a language mark token or not. + */ + std::vector + createInitialSequence(const DecodingOptions &options) const; + + /** + * Generation wrapper - wrapps encoding & decoding with + * temperature fallback mechanism. + * It could, in theory, run up to 5 inferences for increasing + * temperature values. + * + * @param waveform Input audio waveform sampled at 16kHz, similarly to + * encode's input. + * @param options Control variables for decoding process. + */ + std::vector generate(std::span waveform, + const DecodingOptions &options) const; + + /** + * Generation wrapper - wrapps encoding & decoding for a single, + * specific temperature value. + * Results in a single inference. + * Allows to skip the encoding phase if encoder results are already provided. + * + * @param waveform Input audio waveform sampled at 16kHz, similarly to + * encode's input. + * @param options Control variables for decoding process. + * @param temperature Controls the scale of randomization during the logits + * resolving process. + * @param encoderOutput An optional parameter. If provided, the encoding phase + * is skipped and the provided value is used instead. 
+ */ + GenerationResult + generate(std::span waveform, const DecodingOptions &options, + float temperature, + std::optional> encoderOutput = std::nullopt) const; + + /** + * Calculates word-level timestamps for a sequence of generated tokens. + * + * This method parses the generated tokens, splits them into segments based on + * timestamp tokens, and applies a linear estimation for individual words. + * It also adjusts timestamps based on the actual waveform length. + * + * @param generatedTokens The sequence of tokens produced by the model. + * @param waveform The original audio signal used for scaling. + * @param avgLogProb Average log probability of the generated sequence. + * @param temperature Temperature used during generation. + * @param compressionRatio Text compression ratio for the generated sequence. + * @return A vector of transcribed segments with word-level + * timing. + */ + std::vector + calculateWordLevelTimestamps(std::span generatedTokens, + const std::span waveform, + float avgLogProb, float temperature, + float compressionRatio) const; + + /** + * Estimates word-level timestamps linearly within a token sequence. + * + * Decodes the tokens into words and distributes the time interval [start, + * end] across words based on their character count. + * + * @param tokens The slice of tokens representing a single segment. + * @param start The timestamp token ID marking the beginning of the segment. + * @param end The timestamp token ID marking the end of the segment. + * @return A vector of Word objects with estimated start and end times. + */ + std::vector + estimateWordLevelTimestampsLinear(std::span tokens, + uint64_t start, uint64_t end) const; + + float calculateCompressionRatio(const std::string &text) const; + + // Submodules - a tokenizer module for decoding process. 
+ std::unique_ptr tokenizer_; + + // Tokenization helper definitions + const Token startOfTranscriptionToken_; + const Token endOfTranscriptionToken_; + const Token timestampBeginToken_; +}; + +} // namespace rnexecutorch::models::speech_to_text::whisper diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h new file mode 100644 index 000000000..1d2c40e93 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +namespace rnexecutorch::models::speech_to_text::whisper::constants { + +// Maximum duration of each audio chunk to process (in seconds) +// It is intentionally set to 29 since otherwise only the last chunk would be +// correctly transcribe due to the model's positional encoding limit +constexpr static int32_t kChunkSize = 29; + +// The maximum number of tokens the decoder can generate per chunk +constexpr static int32_t kMaxDecodeLength = 128; + +// Minimum allowed chunk length before processing (in audio samples) +constexpr static int32_t kMinChunkSamples = 1 * 16000; + +// Number of mel frames output by the encoder (derived from input spectrogram) +constexpr static int32_t kNumFrames = 1500; + +// Sampling rate expected by Whisper and the model's audio pipeline (16 kHz) +constexpr static int32_t kSamplingRate = 16000; +constexpr static int32_t kSamplesPerMilisecond = kSamplingRate / 1000; + +// Time precision used by Whisper timestamps: each token spans 0.02 seconds +constexpr static float kTimePrecision = 0.02f; + +// Special characters serving as pause / end of sentence +static const std::unordered_set kPunctations = {',', '.', '?', + '!', ':', ';'}; + +// Special token constants +namespace tokens { +static const std::string kStartOfTranscript = "<|startoftranscript|>"; +static const 
std::string kEndOfTranscript = "<|endoftext|>"; +static const std::string kBeginTimestamp = "<|0.00|>"; +static const std::string kBlankAudio = "[BLANK_AUDIO]"; +} // namespace tokens + +} // namespace rnexecutorch::models::speech_to_text::whisper::constants \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp new file mode 100644 index 000000000..45b24f64c --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp @@ -0,0 +1,103 @@ +#include "HypothesisBuffer.h" +#include "Params.h" +#include "Utils.h" +#include + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +void HypothesisBuffer::insert(std::span words, float offset) { + // Step 1 - decide which words should be considered as fresh. + // Using less amount of fresh words saves a little bit of time, but + // could backfire in terms of quality of the final committed transcript. + fresh_.clear(); + for (const Word &word : words) { + // Global start is a beginning timestamp relative only to the beginning of + // the current streaming process. + const float startGlobal = word.start + offset; + const float endGlobal = word.end + offset; + + // To optimize the process, we discard the words which are too old + // according to the calculated timestamp. + if (startGlobal > lastCommittedTime_ - params::kStreamFreshThreshold) { + fresh_.emplace_back(word.content, startGlobal, endGlobal, + word.punctations); + } + } + + // Step 2 - we have already selected the fresh words. Now it's time to + // correct any mistakes and remove the words which overlap with already + // commited segments - to avoid duplicates. + if (!fresh_.empty() && !committed_.empty()) { + // Calculate the largest overlapping fragment size. 
+ // Note that we use size limit (kStreamMaxOverlapSize) for efficiency of the + // algorithm, and timestamp difference limit + // (kStreamMaxOverlapTimestampDiff) to avoid removing correct fragments + // which were just repeated after some time. + size_t overlapSize = utils::findLargestOverlapingFragment( + committed_, fresh_, params::kStreamMaxOverlapSize, + params::kStreamMaxOverlapTimestampDiff); + + // Remove all the overlapping words. + if (overlapSize > 0) { + fresh_.erase(fresh_.begin(), fresh_.begin() + overlapSize); + } + } +} + +std::deque HypothesisBuffer::commit() { + std::deque toCommit = {}; + + // Find a stable prefix: words that haven't changed between last and current + // iteration. + while (!fresh_.empty() && !hypothesis_.empty() && + fresh_.front().content == hypothesis_.front().content) { + // The last word from the fresh_ buffer must also match punctations with the + // hypothesis. This is done in order to ensure correct punctation marks in + // the resulting transcription. + if (fresh_.size() == 1 && + fresh_.front().punctations != hypothesis_.front().punctations) { + break; + } + + // Take timestamps from the hypothesis, but actual content from the fresh + // buffer. + toCommit.emplace_back(std::move(fresh_.front().content), + hypothesis_.front().start, hypothesis_.front().end, + std::move(fresh_.front().punctations)); + fresh_.pop_front(); + hypothesis_.pop_front(); + } + + // Save the last committed word timestamp. + // This will mark the end of the entire committed sequence. + if (!toCommit.empty()) { + lastCommittedTime_ = toCommit.back().end; + } + + // The remaining words from the fresh buffer (uncommitted phrase) + // become a hypothesis for the next iteration. + hypothesis_ = std::move(fresh_); + fresh_.clear(); + + // The last step is to commit the selected words. 
+ committed_.insert(committed_.end(), toCommit.cbegin(), toCommit.cend()); + + return toCommit; +} + +void HypothesisBuffer::releaseCommits(size_t wordsToKeep) { + if (committed_.size() > wordsToKeep) { + size_t nWordsToErase = committed_.size() - wordsToKeep; + committed_.erase(committed_.begin(), committed_.begin() + nWordsToErase); + } +} + +void HypothesisBuffer::reset() { + fresh_.clear(); + hypothesis_.clear(); + committed_.clear(); + + lastCommittedTime_ = 0.f; +} + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h new file mode 100644 index 000000000..226c037b7 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include + +#include "../common/types/Word.h" + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +/** + * A buffer for managing streaming transcription hypotheses. + * This class handles stabilization of the transcription result by tracking + * "fresh" hypotheses and "committing" them once they are stable across updates. + */ +class HypothesisBuffer { +public: + /** + * Inserts new words into the fresh_ buffer. + * Words are filtered based on the last committed time and checked for + * overlaps with existing committed words to prevent duplicates. + * + * @param words A span of recently generated words. + * @param offset Time offset to adjust the word timestamps. + */ + void insert(std::span words, float offset); + + /** + * Attempts to commit words present in the fresh_ buffer. + * A phrase from fresh_ buffer can only be committed if it also appears + * in the hypothesis_ buffer (uncommitted words from previous iteration).
+ * + * Uncommitted words become a 'hypothesis' and are moved into the hypothesis_ + * buffer. + * + * @return A sequence of words committed in the current iteration. + */ + std::deque commit(); + + /** + * Shrinks the committed_ buffer by erasing all words except N latest ones. + * + * Used primarily to relieve increasing memory usage during very + * long streaming sessions. + * + * @param wordsToKeep - number of trailing words to be kept in. + */ + void releaseCommits(size_t wordsToKeep); + + /** + * Resets all the stored buffers and state variables to the initial state + */ + void reset(); + + // Declare a friendship with OnlineASR to allow it to access the internal + // state of stored buffers. + friend class OnlineASR; + +private: + // Stored buffers + // The lifecycle of a correct result word looks as follows: + // fresh buffer -> hypothesis buffer -> committed + std::deque + fresh_; // 'New' words from current iterations, which require some checks + // before they go into hypothesis_ buffer. + std::deque + hypothesis_; // Words potentially to be committed, stored between + // iterations (obtained from fresh_ buffer). + std::deque committed_; // A history of already committed words. + + float lastCommittedTime_ = 0.0f; +}; + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp new file mode 100644 index 000000000..755c958f9 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include + +#include "Constants.h" +#include "OnlineASR.h" +#include "Params.h" + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +namespace { +// A helper function to avoid code duplication.
+std::vector move_to_vector(std::deque &container) { + return std::vector(std::make_move_iterator(container.begin()), + std::make_move_iterator(container.end())); +}; +} // namespace + +OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) { + // Reserve a minimal expected amount of memory for audio buffer. + audioBuffer_.reserve(static_cast(2 * params::kStreamChunkThreshold * + constants::kSamplingRate)); +} + +void OnlineASR::insertAudioChunk(std::span audio) { + audioBuffer_.insert(audioBuffer_.end(), audio.begin(), audio.end()); + + // Update the epsilon accordingly to the amount of added audio. + epsilon_ += static_cast(audio.size()) / constants::kSamplingRate; +} + +bool OnlineASR::isReady() const { + return audioBuffer_.size() >= constants::kMinChunkSamples; +} + +ProcessResult OnlineASR::process(const DecodingOptions &options) { + // Perform a transcription process to obtain results for + // the current state of the audio buffer. + std::vector transcriptions = asr_->transcribe(audioBuffer_, options); + + if (transcriptions.empty()) { + return {.committed = {}, .nonCommitted = {}}; + } + + // Flatten segments into a single word sequence. + // In this case, Word consists of text and timestamps. + std::vector words; + words.reserve(transcriptions.front().words.size()); + + // Note that we transfer the ownership of moves, so words should not be + // accessed by transcriptions.segment.words afterwards. + for (auto &segment : transcriptions) { + words.insert(words.end(), std::make_move_iterator(segment.words.begin()), + std::make_move_iterator(segment.words.end())); + } + + hypothesisBuffer_.insert(words, bufferTimeOffset_); + + // Apply fix for timestamps. + // After the insert() call on hypothesis buffer, the inner fresh_ buffer + // contains either a completely new words or words which overlap only + // with the inner hypothesis_ buffer. + if (!hypothesisBuffer_.fresh_.empty()) { + float establishedEnd = !hypothesisBuffer_.committed_.empty() + ? 
hypothesisBuffer_.committed_.back().end + : 0.F; + const float newEnd = hypothesisBuffer_.fresh_.back().end; + float newBegin = hypothesisBuffer_.fresh_.front().start; + + for (size_t i = 0; i < hypothesisBuffer_.fresh_.size(); i++) { + // If the word overlaps with the hypothesis, we can simply copy the + // timestamps from the previous iteration (that is, from the hypothesis + // inner buffer). + if (i < hypothesisBuffer_.hypothesis_.size() && + hypothesisBuffer_.fresh_[i].content == + hypothesisBuffer_.hypothesis_[i].content) { + hypothesisBuffer_.fresh_[i].start = + hypothesisBuffer_.hypothesis_[i].start; + hypothesisBuffer_.fresh_[i].end = hypothesisBuffer_.hypothesis_[i].end; + + establishedEnd = hypothesisBuffer_.hypothesis_[i].end; + newBegin = hypothesisBuffer_.fresh_[i].end; + + continue; + } + + // In case of a new word, we apply timestamp range scaling + // based on timestamps established in previous iterations. + // The idea is, that both ranges of established and completely new words + // should sum up to the final timestamp produced by the model. + // TODO: estimate the epsilon value + const float beforeScaleStart = hypothesisBuffer_.fresh_[i].start; + const float beforeScaleEnd = hypothesisBuffer_.fresh_[i].end; + float scale = + (newEnd - establishedEnd) / (newEnd - newBegin); // missing epsilon + // float scale = (newEnd - establishedEnd - epsilon) / (newEnd - + // newBegin); // correct + hypothesisBuffer_.fresh_[i].start = + (hypothesisBuffer_.fresh_[i].start - newEnd) * scale + newEnd; + hypothesisBuffer_.fresh_[i].end = + (hypothesisBuffer_.fresh_[i].end - newEnd) * scale + newEnd; + } + + // Before committing, save the last timestamp produced by the model. + // That is, the ending timestamp of the last fresh word. + lastNonSilentMoment_ = newEnd; + } + + // Commit matching words. + auto committed = hypothesisBuffer_.commit(); + auto nonCommitted = hypothesisBuffer_.hypothesis_; + + // Cut the audio buffer to not exceed the size threshold. 
+ // Since Whisper does not accept waveforms longer than 30 seconds, we need + // to cut the audio at some safe point. + const float audioDuration = + static_cast(audioBuffer_.size()) / constants::kSamplingRate; + if (audioDuration > params::kStreamChunkThreshold) { + // Leave some portion of audio in, to improve model behavior + // in future iterations. + const float erasedDuration = + audioDuration - params::kStreamAudioBufferReserve; + const size_t nSamplesToErase = + static_cast(erasedDuration * constants::kSamplingRate); + + audioBuffer_.erase(audioBuffer_.begin(), + audioBuffer_.begin() + nSamplesToErase); + bufferTimeOffset_ += erasedDuration; + } + + return {.committed = move_to_vector(committed), + .nonCommitted = move_to_vector(nonCommitted)}; +} + +std::vector OnlineASR::finish() { + // We always push the last remaining hypothesis, even if it's not + // confirmed in second iteration. + auto remaining = hypothesisBuffer_.hypothesis_; + + reset(); + + return move_to_vector(remaining); +} + +void OnlineASR::reset() { + hypothesisBuffer_.reset(); + bufferTimeOffset_ = 0.f; + + audioBuffer_.clear(); +} + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h new file mode 100644 index 000000000..085fcc140 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h @@ -0,0 +1,79 @@ +#pragma once + +#include "../common/schema/OnlineASR.h" +#include "../common/types/ProcessResult.h" +#include "../common/types/Segment.h" +#include "../common/types/Word.h" +#include "ASR.h" +#include "HypothesisBuffer.h" + +namespace rnexecutorch::models::speech_to_text::whisper::stream { + +/** + * Online Automatic Speech Recognition (OnlineASR) for Whisper. 
+ * It manages continuous processing of audio stream by maintaining a local + * audio buffer and using ASR to transcribe it in increments. + */ +class OnlineASR : public schema::OnlineASR { +public: + OnlineASR(const ASR *asr); + + /** + * Appends new audio samples to the internal processing buffer. + * + * @param audio A span of PCM float samples (expected 16kHz). + */ + void insertAudioChunk(std::span audio) override; + + /** + * Determines whether the model is ready to process the next iteration. + * + * @return True if audioBuffer has enough samples, False otherwise + */ + bool isReady() const override; + + /** + * Processes the current audio buffer and returns new transcription results. + * Stability is managed by an internal HypothesisBuffer to ensure that + * only confirmed (stable) text is returned as "committed". + * + * @param options Decoding configuration (language, etc.). + * @return A ProcessResult containing newly committed and uncommitted + * words. + */ + ProcessResult process(const DecodingOptions &options) override; + + /** + * Finalizes the current streaming session. + * Flushes any remaining words from the hypothesis buffer. + * + * @return A vector of remaining transcribed words. + */ + std::vector finish() override; + + /** + * Reset the streaming state by resetting the buffers + */ + void reset() override; + +private: + // ASR module connection for transcribing the audio + const ASR *asr_; + + // Helper buffers - audio buffer + // Stores the increasing amounts of streamed audio. + // Cleared from time to time after reaching a threshold size. + std::vector audioBuffer_ = {}; + float bufferTimeOffset_ = 0.f; // Audio buffer offset + + // Helper buffers - hypothesis buffer + // Manages the whisper streaming hypothesis mechanism. + HypothesisBuffer hypothesisBuffer_; + + // Silence estimation components + // Used primarily in timestamp range scaling algorithm. 
+ float epsilon_ = 0.F; + float lastNonSilentMoment_ = 0.F; +}; + +} // namespace rnexecutorch::models::speech_to_text::whisper::stream \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h new file mode 100644 index 000000000..17607669f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +/** + * Hyperparameters + * + * These are adjustable values, which, when changed, affect the behavior + * of the underlying model and/or algorithms. + */ +namespace rnexecutorch::models::speech_to_text::whisper::params { + +/** + * Determines the range of buffer left when skipping an audio chunk + * of size lower than maximum allowed chunk size. + * + * If the audio length does not exceed [kChunkSize * kSamplingRate] - [buffer], + * then instead of moving to the last returned timestamp, we jump across the + * entire 30 seconds chunk. This resolves the issue of multiple redundant + * segments being produced by the transcription algorithm. + */ +constexpr static int32_t kChunkBreakBuffer = 2; // [s] + +/** + * Determines the maximum timestamp difference available for a word to be + * considered as fresh in streaming algorithm. + */ +constexpr static float kStreamFreshThreshold = 1.F; // [s], originally 0.5 + +/** + * Determines the maximum expected size of overlapping fragments between + * fresh words buffer and committed words buffer in streaming mode. + * + * It is a limit of maximum amount of erased repeated words from fresh buffer. + * The bigger it gets, the less probable it is to commit the same phrase twice.
+ */ +constexpr static size_t kStreamMaxOverlapSize = + 12; // Number of overlapping words + +/** + * Similar to kStreamMaxOverlapSize, but this one determines + * the maximum allowed timestamp difference between the overlapping fragments. + */ +constexpr static float kStreamMaxOverlapTimestampDiff = 15.F; // [s] + +/** + * A threshold which, when exceeded, causes the main streaming audio buffer to be + * cleared. + */ +constexpr static float kStreamChunkThreshold = 20.F; // [s] + +/** + * Decides how much of recent audio waveform should be kept in when + * clearing the audio buffer in streaming algorithm. + */ +constexpr static float kStreamAudioBufferReserve = 5.F; // [s] + +} // namespace rnexecutorch::models::speech_to_text::whisper::params \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h new file mode 100644 index 000000000..aec1aa1f5 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h @@ -0,0 +1,71 @@ +#pragma once + +#include "../common/types/Word.h" +#include +#include +#include +#include + +namespace rnexecutorch::models::speech_to_text::whisper::utils { + +/** + * Finds the largest (in number of words) overlapping fragment between word + * vectors A (suffix) and B (prefix). + * + * An overlapping fragment is any fragment C, which can be simultaneously a + * suffix of A and a prefix of B. Example: A = 'Jane likes food and playing + * games', B = 'playing games and sleeping', the overlap fragment C = 'playing + * games'. + * + * @param suffixVec An input vector, where only suffixes can overlap. + * Typically the 'committed' buffer in streaming algorithm. + * @param prefixVec An input vector, where only prefixes can overlap. + * Typically the 'fresh' buffer in streaming algorithm.
+ * @param maxCheckRange The maximum size of overlapping fragment. Determines the + * range of search. + * @param maxTimestampDiff The maximum allowed timestamp difference between + * overlapping fragments. If exceeded, the fragments are not considered as + * overlapping. + * @return The size of the largest found overlapping fragment. + */ +template +inline size_t findLargestOverlapingFragment(const Container &suffixVec, + const Container &prefixVec, + size_t maxCheckRange = 10, + float maxTimestampDiff = 100.f) { + size_t range = std::min({suffixVec.size(), prefixVec.size(), maxCheckRange}); + + if (range == 0) { + return 0; + } + + // Iterate backwards from the largest possible overlap size down to 1. + // i starts at the index where the suffix of length 'range' begins. + for (size_t i = suffixVec.size() - range; i < suffixVec.size(); ++i) { + // We search for overlaps by searching for the first word of prefixVec + if (suffixVec[i].content == prefixVec[0].content) { + size_t calculatedSize = suffixVec.size() - i; + + // Optimization: Check if the last elements match before full comparison + if (prefixVec[calculatedSize - 1].content != suffixVec.back().content) { + continue; + } + + bool isEqual = std::equal( + suffixVec.begin() + i, suffixVec.end(), prefixVec.begin(), + [maxTimestampDiff](const Word &sWord, const Word &pWord) { + return sWord.content == pWord.content && + std::fabs(sWord.start - pWord.start) <= maxTimestampDiff && + std::fabs(sWord.end - pWord.end) <= maxTimestampDiff; + }); + + if (isEqual) { + return calculatedSize; + } + } + } + + return 0; +} + +} // namespace rnexecutorch::models::speech_to_text::whisper::utils \ No newline at end of file diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 6e76e52b7..305286726 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -1,5
+1,5 @@ import { Platform } from 'react-native'; -import { URL_PREFIX, VERSION_TAG } from './versions'; +import { NEXT_VERSION_TAG, URL_PREFIX, VERSION_TAG } from './versions'; // LLMs @@ -417,102 +417,101 @@ export const STYLE_TRANSFER_UDNIE = { }; // S2T +const WHISPER_TINY_EN_MODEL = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_en_xnnpack.pte`; const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`; -const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`; -const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`; -const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`; -const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`; +// const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`; +// const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`; -const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/tokenizer.json`; -const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`; -const WHISPER_BASE_EN_DECODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`; +// const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`; +// const WHISPER_BASE_EN_DECODER = 
`${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`; -const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/tokenizer.json`; -const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`; -const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`; +// const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`; +// const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`; -const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`; -const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`; -const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`; +// const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`; +// const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`; -const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/tokenizer.json`; -const WHISPER_BASE_ENCODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`; -const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`; +// const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_BASE_ENCODER_MODEL = 
`${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`; +// const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`; -const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/tokenizer.json`; -const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`; -const WHISPER_SMALL_DECODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`; +// const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/tokenizer.json`; +// const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`; +// const WHISPER_SMALL_DECODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`; /** * @category Models - Speech To Text */ export const WHISPER_TINY_EN = { + type: 'whisper' as const, isMultilingual: false, - encoderSource: WHISPER_TINY_EN_ENCODER, - decoderSource: WHISPER_TINY_EN_DECODER, + modelSource: WHISPER_TINY_EN_MODEL, tokenizerSource: WHISPER_TINY_EN_TOKENIZER, }; -/** - * @category Models - Speech To Text - */ -export const WHISPER_TINY_EN_QUANTIZED = { - isMultilingual: false, - encoderSource: WHISPER_TINY_EN_ENCODER_QUANTIZED, - decoderSource: WHISPER_TINY_EN_DECODER_QUANTIZED, - tokenizerSource: WHISPER_TINY_EN_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_BASE_EN = { - isMultilingual: false, - encoderSource: WHISPER_BASE_EN_ENCODER, - decoderSource: WHISPER_BASE_EN_DECODER, - tokenizerSource: WHISPER_BASE_EN_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_SMALL_EN = { - isMultilingual: false, - encoderSource: WHISPER_SMALL_EN_ENCODER, - decoderSource: WHISPER_SMALL_EN_DECODER, - tokenizerSource: WHISPER_SMALL_EN_TOKENIZER, -}; - -/** - * @category Models - 
Speech To Text - */ -export const WHISPER_TINY = { - isMultilingual: true, - encoderSource: WHISPER_TINY_ENCODER_MODEL, - decoderSource: WHISPER_TINY_DECODER_MODEL, - tokenizerSource: WHISPER_TINY_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_BASE = { - isMultilingual: true, - encoderSource: WHISPER_BASE_ENCODER_MODEL, - decoderSource: WHISPER_BASE_DECODER_MODEL, - tokenizerSource: WHISPER_BASE_TOKENIZER, -}; - -/** - * @category Models - Speech To Text - */ -export const WHISPER_SMALL = { - isMultilingual: true, - encoderSource: WHISPER_SMALL_ENCODER_MODEL, - decoderSource: WHISPER_SMALL_DECODER_MODEL, - tokenizerSource: WHISPER_SMALL_TOKENIZER, -}; +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_TINY_EN_QUANTIZED = { +// isMultilingual: false, +// encoderSource: WHISPER_TINY_EN_ENCODER_QUANTIZED, +// decoderSource: WHISPER_TINY_EN_DECODER_QUANTIZED, +// tokenizerSource: WHISPER_TINY_EN_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_BASE_EN = { +// isMultilingual: false, +// encoderSource: WHISPER_BASE_EN_ENCODER, +// decoderSource: WHISPER_BASE_EN_DECODER, +// tokenizerSource: WHISPER_BASE_EN_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_SMALL_EN = { +// isMultilingual: false, +// encoderSource: WHISPER_SMALL_EN_ENCODER, +// decoderSource: WHISPER_SMALL_EN_DECODER, +// tokenizerSource: WHISPER_SMALL_EN_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_TINY = { +// isMultilingual: true, +// encoderSource: WHISPER_TINY_ENCODER_MODEL, +// decoderSource: WHISPER_TINY_DECODER_MODEL, +// tokenizerSource: WHISPER_TINY_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_BASE = { +// isMultilingual: true, +// encoderSource: WHISPER_BASE_ENCODER_MODEL, +// decoderSource: WHISPER_BASE_DECODER_MODEL, 
+// tokenizerSource: WHISPER_BASE_TOKENIZER, +// }; + +// /** +// * @category Models - Speech To Text +// */ +// export const WHISPER_SMALL = { +// isMultilingual: true, +// encoderSource: WHISPER_SMALL_ENCODER_MODEL, +// decoderSource: WHISPER_SMALL_DECODER_MODEL, +// tokenizerSource: WHISPER_SMALL_TOKENIZER, +// }; // Image segmentation const DEEPLAB_V3_RESNET50_MODEL = `${URL_PREFIX}-deeplab-v3/${VERSION_TAG}/xnnpack/deeplabV3_xnnpack_fp32.pte`; diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts index fa3d8c685..e26b6e3c7 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts @@ -38,9 +38,9 @@ export const useSpeechToText = ({ setIsReady(false); await moduleInstance.load( { + type: model.type, isMultilingual: model.isMultilingual, - encoderSource: model.encoderSource, - decoderSource: model.decoderSource, + modelSource: model.modelSource, tokenizerSource: model.tokenizerSource, }, (progress) => { @@ -59,9 +59,9 @@ export const useSpeechToText = ({ }; }, [ moduleInstance, + model.type, model.isMultilingual, - model.encoderSource, - model.decoderSource, + model.modelSource, model.tokenizerSource, preventLoad, ]); diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index a42881f45..ef89146c5 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -23,9 +23,9 @@ declare global { schedulerStepsOffset: number ) => any; var loadSpeechToText: ( - encoderSource: string, - decoderSource: string, - modelName: string + modelName: string, + modelSource: string, + tokenizerSource: string ) => any; var loadTextToSpeechKokoro: ( lang: string, diff --git 
a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts index 64f4e953f..22f4eb77d 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts @@ -33,26 +33,23 @@ export class SpeechToTextModule { undefined, model.tokenizerSource ); - const encoderDecoderPromise = ResourceFetcher.fetch( + const modelPromise = ResourceFetcher.fetch( onDownloadProgressCallback, - model.encoderSource, - model.decoderSource + model.modelSource ); - const [tokenizerSources, encoderDecoderResults] = await Promise.all([ + const [tokenizerSources, modelSources] = await Promise.all([ tokenizerLoadPromise, - encoderDecoderPromise, + modelPromise, ]); - const encoderSource = encoderDecoderResults?.[0]; - const decoderSource = encoderDecoderResults?.[1]; - if (!encoderSource || !decoderSource || !tokenizerSources) { + if (!modelSources || !tokenizerSources) { throw new RnExecutorchError( RnExecutorchErrorCode.DownloadInterrupted, 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' ); } this.nativeModule = await global.loadSpeechToText( - encoderSource, - decoderSource, + model.type, + modelSources[0]!, tokenizerSources[0]! ); } diff --git a/packages/react-native-executorch/src/types/stt.ts b/packages/react-native-executorch/src/types/stt.ts index bf7fc6436..df0ab063f 100644 --- a/packages/react-native-executorch/src/types/stt.ts +++ b/packages/react-native-executorch/src/types/stt.ts @@ -10,7 +10,7 @@ export interface SpeechToTextProps { /** * Configuration object containing model sources. */ - model: SpeechToTextModelConfig; + model: SpeechToTextModelConfig; // | ... 
/** * Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ @@ -261,20 +261,19 @@ export interface TranscriptionResult { * @category Types */ export interface SpeechToTextModelConfig { + type: 'whisper'; // | ... (add more in the future) + /** * A boolean flag indicating whether the model supports multiple languages. */ isMultilingual: boolean; /** - * A string that specifies the location of a `.pte` file for the encoder. - */ - encoderSource: ResourceSource; - - /** - * A string that specifies the location of a `.pte` file for the decoder. + * A string that specifies the location of a `.pte` file for the model. + * + * We expect the model to have 2 bundled methods: 'decode' and 'encode'. */ - decoderSource: ResourceSource; + modelSource: ResourceSource; /** * A string that specifies the location to the tokenizer for the model.