diff --git a/public/models/ggml-model-whisper-base.en-q5_1.bin b/public/models/ggml-model-whisper-base.en-q5_1.bin deleted file mode 100644 index 1f4c0515..00000000 Binary files a/public/models/ggml-model-whisper-base.en-q5_1.bin and /dev/null differ diff --git a/public/models/ggml-model-whisper-tiny-q5_1.bin b/public/models/ggml-model-whisper-tiny-q5_1.bin new file mode 100644 index 00000000..a05a475a --- /dev/null +++ b/public/models/ggml-model-whisper-tiny-q5_1.bin @@ -0,0 +1 @@ +Entry not found \ No newline at end of file diff --git a/src/components/api/returnAPI.tsx b/src/components/api/returnAPI.tsx index 3524d5c6..da7a9973 100644 --- a/src/components/api/returnAPI.tsx +++ b/src/components/api/returnAPI.tsx @@ -1,15 +1,13 @@ -// import * as sdk from 'microsoft-cognitiveservices-speech-sdk' +// src/components/api/returnAPI.tsx + +// import * * as sdk from 'microsoft-cognitiveservices-speech-sdk' import installCOIServiceWorker from './coi-serviceworker' import { API, PlaybackStatus } from '../../react-redux&middleware/redux/typesImports'; import { - ApiStatus, - AzureStatus, - ControlStatus, - SRecognition, - StreamTextStatus, - ScribearServerStatus + ApiStatus, AzureStatus, ControlStatus, SRecognition, + StreamTextStatus, ScribearServerStatus } from '../../react-redux&middleware/redux/typesImports'; -import { useEffect, useState } from 'react'; +import { useEffect, useRef, useState } from 'react'; import { batch, useDispatch, useSelector } from 'react-redux'; import { AzureRecognizer } from './azure/azureRecognizer'; @@ -24,182 +22,210 @@ import { PlaybackRecognizer } from './playback/playbackRecognizer'; import { ScribearRecognizer } from './scribearServer/scribearRecognizer'; import type { SelectedOption } from '../../react-redux&middleware/redux/types/modelSelection'; -// import { PlaybackReducer } from '../../react-redux&middleware/redux/reducers/apiReducers'; -// controls what api to send and what to do when error handling. - -// NOTES: this needs to do everything I think. Handler should be returned which allows -// event call like stop and the event should be returned... (maybe the recognition? idk.) - -/* -* === * === * DO NOT DELETE IN ANY CIRCUMSTANCE * === * === * -* === * TRIBUTE TO THE ORIGINAL AUTHOR OF THIS CODE: Will * === * -DO NOT DELETE IN ANY CIRCUMSTANCE -export const returnRecogAPI = (api : ApiStatus, control : ControlStatus, azure : AzureStatus) => { - // const apiStatus = useSelector((state: RootState) => { - // return state.APIStatusReducer as ApiStatus; - // }) - // const control = useSelector((state: RootState) => { - // return state.ControlReducer as ControlStatus; - // }); - // const azureStatus = useSelector((state: RootState) => { - // return state.AzureReducer as AzureStatus; - // }) - const recognition : Promise = getRecognition(api.currentApi, control, azure); - const useRecognition : Object = makeRecognition(api.currentApi); - // const recogHandler : Function = handler(api.currentApi); - - - return ({ useRecognition, recognition }); -} -* === * === * DO NOT DELETE IN ANY CIRCUMSTANCE * === * === * -* === * TRIBUTE TO THE ORIGINAL AUTHOR OF THIS CODE: Will * === * -*/ - - - -function toWhisperCode(bcp47: string): string { - // Accept "en", "en-US", etc. Return "en" for "en-US". 
- if (!bcp47) return "en"; - const base = bcp47.split('-')[0].toLowerCase(); - const supported = new Set([ - 'en','es','fr','de','it','pt','nl','sv','da','nb','fi','pl','cs','sk','sl','hr','sr','bg','ro', - 'hu','el','tr','ru','uk','ar','he','fa','ur','hi','bn','ta','te','ml','mr','gu','kn','pa', - 'id','ms','vi','th','zh','ja','ko' - ]); - return supported.has(base) ? base : 'en'; - } - -const createRecognizer = (currentApi: number, control: ControlStatus, azure: AzureStatus, streamTextConfig: StreamTextStatus, scribearServerStatus: ScribearServerStatus, selectedModelOption: SelectedOption, playbackStatus: PlaybackStatus): Recognizer => { - if (currentApi === API.SCRIBEAR_SERVER) { - return new ScribearRecognizer(scribearServerStatus, selectedModelOption, control.speechLanguage.CountryCode); - } else if (currentApi === API.PLAYBACK) { - return new PlaybackRecognizer(playbackStatus); - } - if (currentApi === API.WEBSPEECH) { - return new WebSpeechRecognizer(null, control.speechLanguage.CountryCode); - } else if (currentApi === API.AZURE_TRANSLATION) { - return new AzureRecognizer(null, control.speechLanguage.CountryCode, azure); - } - else if (currentApi === API.AZURE_CONVERSATION) { - throw new Error("Not implemented"); - } - else if (currentApi === API.STREAM_TEXT) { - // Placeholder - this is just WebSpeech for now - return new StreamTextRecognizer(streamTextConfig.streamTextEvent, 'en', streamTextConfig.startPosition); - } else if (currentApi === API.WHISPER) { - return new WhisperRecognizer( - null, - toWhisperCode(control.speechLanguage.CountryCode), - 4 - ); - } else { - throw new Error(`Unexpcted API_CODE: ${currentApi}`); - } +import { selectSelectedModel } from '../../react-redux&middleware/redux/reducers/modelSelectionReducers'; + +/** Only two tiny models */ +function getWhisperModelKey(selected: any): 'tiny-en-q5_1' | 'tiny-q5_1' { + const allowed = new Set(['tiny-en-q5_1', 'tiny-q5_1']); + if (typeof selected === 'string') { + if (selected === 'tiny') return 'tiny-en-q5_1'; + return (allowed.has(selected) ? selected : 'tiny-en-q5_1') as any; + } + if (selected && typeof selected === 'object') { + const k = (selected as any).model_key || (selected as any).key; + if (k === 'tiny') return 'tiny-en-q5_1'; + return (allowed.has(k) ? k : 'tiny-en-q5_1') as any; + } + return 'tiny-en-q5_1'; } -/** - * Make a callback function that updates the Redux transcript using new final blocks and new - * in-progress block - * - * We have to do things in this roundabout way to have access to dispatch in a callback function, - * see https://stackoverflow.com/questions/59456816/how-to-call-usedispatch-in-a-callback - * @param dispatch A Redux dispatch function - */ -const updateTranscript = (dispatch: Dispatch) => (newFinalBlocks: Array, newInProgressBlock: TranscriptBlock): void => { - // console.log(`Updating transcript using these blocks: `, newFinalBlocks, newInProgressBlock) - // batch makes these dispatches only cause one re-rendering - batch(() => { - for (const block of newFinalBlocks) { - dispatch({ type: "transcript/new_final_block", payload: block }); - } - dispatch({ type: 'transcript/update_in_progress_block', payload: newInProgressBlock }); - }) +/** Accept BCP-47 *or* plain-language labels; default wisely. */ +function toWhisperCodeLoose(input: string | undefined, selectedWhisperModel: any): string { + const model = getWhisperModelKey(selectedWhisperModel); + const fallback = model === 'tiny-q5_1' ? 
'auto' : 'en'; + + if (!input || typeof input !== 'string') return fallback; + const s = input.toLowerCase().trim(); + + // If it already looks like a code (e.g., "zh-cn"), reduce to base. + const base = s.split(/[^a-z]/)[0] || s; + const codes = new Set([ + 'en','es','fr','de','it','pt','nl','sv','da','nb','fi','pl','cs','sk','sl','hr','sr','bg','ro', + 'hu','el','tr','ru','uk','ar','he','fa','ur','hi','bn','ta','te','ml','mr','gu','kn','pa', + 'id','ms','vi','th','zh','ja','ko','auto' + ]); + if (codes.has(base)) return base; + + // Heuristic for common menu labels + const map: Array<[string,string]> = [ + ['english', 'en'], + ['chinese', 'zh'], + ['mandarin', 'zh'], + ['cantonese', 'zh'], + ['japanese', 'ja'], + ['korean', 'ko'], + ['spanish', 'es'], + ['french', 'fr'], + ['german', 'de'], + ['portuguese', 'pt'], + ['russian', 'ru'], + ['arabic', 'ar'], + ['hindi', 'hi'], + ['thai', 'th'], + ['vietnamese', 'vi'], + ]; + for (const [kw, code] of map) if (s.includes(kw)) return code; + + return fallback; } -/** - * Syncs up the recognizer with the API selection and listening status - * - Creates new recognizer and stop old ones when API is changed - * - Start / stop recognizer as listening changes - * - Feed any phrase list updates to azure recognizer - * - * @param recog - * @param api - * @param control - * @param azure - * - * @return transcripts, resetTranscript, recogHandler - */ -export const useRecognition = (sRecog: SRecognition, api: ApiStatus, control: ControlStatus, - azure: AzureStatus, streamTextConfig: StreamTextStatus, scribearServerStatus, selectedModelOption: SelectedOption, playbackStatus: PlaybackStatus) => { - - const [recognizer, setRecognizer] = useState(); - // TODO: Add a reset button to utitlize resetTranscript - // const [resetTranscript, setResetTranscript] = useState<() => string>(() => () => dispatch('RESET_TRANSCRIPT')); - const dispatch = useDispatch(); - - // Register service worker for whisper on launch - useEffect(() => { - installCOIServiceWorker(); - }, []) - - // Change recognizer, if api changed - useEffect(() => { - console.log("UseRecognition, switching to new recognizer: ", api.currentApi); - - let newRecognizer: Recognizer | null; - try { - // Create new recognizer, and subscribe to its events - newRecognizer = createRecognizer(api.currentApi, control, azure, streamTextConfig, scribearServerStatus, selectedModelOption, playbackStatus); - newRecognizer.onTranscribed(updateTranscript(dispatch)); - setRecognizer(newRecognizer) - - // Start new recognizer if necessary - if (control.listening) { - console.log("UseRecognition, attempting to start recognizer after switching") - newRecognizer.start() - } - } catch (e) { - console.log("UseRecognition, failed to switch to new recognizer: ", e); - } - - return () => { - // Stop current recognizer when switching to another one, if possible - newRecognizer?.stop(); - } - }, [api.currentApi, azure, control, streamTextConfig, playbackStatus, scribearServerStatus, selectedModelOption]); +const createRecognizer = ( + currentApi: number, + control: ControlStatus, + azure: AzureStatus, + streamTextConfig: StreamTextStatus, + scribearServerStatus: ScribearServerStatus, + selectedModelOption: SelectedOption, + selectedWhisperModel: any, + playbackStatus: PlaybackStatus +): Recognizer => { + if (currentApi === API.SCRIBEAR_SERVER) { + return new ScribearRecognizer( + scribearServerStatus, + selectedModelOption, + control.speechLanguage.CountryCode + ); + } else if (currentApi === API.PLAYBACK) { + return new 
PlaybackRecognizer(playbackStatus); + } + + if (currentApi === API.WEBSPEECH) { + return new WebSpeechRecognizer(null, control.speechLanguage.CountryCode); + } else if (currentApi === API.AZURE_TRANSLATION) { + return new AzureRecognizer(null, control.speechLanguage.CountryCode, azure); + } else if (currentApi === API.AZURE_CONVERSATION) { + throw new Error("Not implemented"); + } else if (currentApi === API.STREAM_TEXT) { + return new StreamTextRecognizer( + streamTextConfig.streamTextEvent, + 'en', + streamTextConfig.startPosition + ); + } else if (currentApi === API.WHISPER) { + const lang = toWhisperCodeLoose(control.speechLanguage.CountryCode, selectedWhisperModel); + return new WhisperRecognizer(null, lang, 4, getWhisperModelKey(selectedWhisperModel)); + } else { + throw new Error(`Unexpcted API_CODE: ${currentApi}`); + } +} - // Start / stop recognizer, if listening toggled - useEffect(() => { - if (!recognizer) { // whipser won't have recogHandler - return; +const updateTranscript = (dispatch: Dispatch) => + (newFinalBlocks: Array, newInProgressBlock: TranscriptBlock): void => { + batch(() => { + for (const block of newFinalBlocks) { + dispatch({ type: "transcript/new_final_block", payload: block }); } - if (control.listening) { - console.log("UseRecognition, sending start signal to recognizer") - recognizer.start(); - } else if (!control.listening) { - console.log("UseRecognition, sending stop signal to recognizer") - recognizer.stop(); + dispatch({ type: 'transcript/update_in_progress_block', payload: newInProgressBlock }); + }); + }; + +export const useRecognition = ( + sRecog: SRecognition, + api: ApiStatus, + control: ControlStatus, + azure: AzureStatus, + streamTextConfig: StreamTextStatus, + scribearServerStatus: ScribearServerStatus, + selectedModelOption: SelectedOption, + playbackStatus: PlaybackStatus +) => { + const [recognizer, setRecognizer] = useState(); + const dispatch = useDispatch(); + const selectedWhisperModel = useSelector(selectSelectedModel); + const startingRef = useRef(false); + + useEffect(() => { installCOIServiceWorker(); }, []); + + useEffect(() => { + let newRecognizer: Recognizer | null = null; + try { + if (api.currentApi === API.WHISPER) { + recognizer?.stop(); + setRecognizer(undefined); + return () => {}; } - }, [control.listening]); - - // Update domain phrases for azure recognizer - useEffect(() => { - console.log("UseRecognition, changing azure phrases", azure.phrases); - if (api.currentApi === API.AZURE_TRANSLATION && recognizer) { - (recognizer as AzureRecognizer).addPhrases(azure.phrases); + newRecognizer = createRecognizer( + api.currentApi, control, azure, streamTextConfig, scribearServerStatus, + selectedModelOption, selectedWhisperModel, playbackStatus + ); + newRecognizer.onTranscribed(updateTranscript(dispatch)); + setRecognizer(newRecognizer); + if (control.listening) newRecognizer.start(); + } catch (e) { + console.log("UseRecognition, failed to switch to new recognizer: ", e); + } + return () => { newRecognizer?.stop(); }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [ + api.currentApi, + control.speechLanguage.CountryCode, + selectedWhisperModel + ]); + + useEffect(() => { + if (api.currentApi !== API.WHISPER) { + if (!recognizer) return; + if (control.listening) recognizer.start(); else recognizer.stop(); + return; + } + + if (control.listening) { + if (!recognizer && !startingRef.current) { + startingRef.current = true; + try { + const r = createRecognizer( + api.currentApi, control, azure, streamTextConfig, 
scribearServerStatus, + selectedModelOption, selectedWhisperModel, playbackStatus + ); + r.onTranscribed(updateTranscript(dispatch)); + setRecognizer(r); + setTimeout(() => { r.start(); startingRef.current = false; }, 0); + } catch (e) { + console.log('Whisper lazy-create failed: ', e); + startingRef.current = false; + } + } else if (recognizer) { + recognizer.start(); } - }, [azure.phrases]); - - // TODO: whisper's transcript is not in redux store but only in sessionStorage at the moment. - let transcript: string = useSelector((state: RootState) => { - return state.TranscriptReducer.transcripts[0].toString() - }); - // if (api.currentApi === API.WHISPER) { - // // TODO: inefficient to get it from sessionStorage everytime - // // TODO: add whisper_transcript to redux store after integrating "whisper" folder (containing stream.js) into ScribeAR - // transcript = sessionStorage.getItem('whisper_transcript') || ''; - // return transcript; - // } - - return transcript; + } else { + recognizer?.stop(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [ + api.currentApi, + control.listening, + selectedWhisperModel, + control.speechLanguage.CountryCode + ]); + + // Live language push (handles non-BCP-47 labels too) + useEffect(() => { + if (api.currentApi === API.WHISPER && recognizer) { + (recognizer as any)?.setLanguage?.( + toWhisperCodeLoose(control.speechLanguage.CountryCode, selectedWhisperModel) + ); + } + }, [api.currentApi, control.speechLanguage.CountryCode, recognizer, selectedWhisperModel]); + + useEffect(() => { + if (api.currentApi === API.AZURE_TRANSLATION && recognizer) { + (recognizer as AzureRecognizer).addPhrases(azure.phrases); + } + }, [azure.phrases, api.currentApi, recognizer]); + + const transcript: string = useSelector((state: RootState) => + state.TranscriptReducer.transcripts[0].toString() + ); + + return transcript; } diff --git a/src/components/api/whisper/indexedDB.js b/src/components/api/whisper/indexedDB.js index e906a4f0..e7d05a91 100644 --- a/src/components/api/whisper/indexedDB.js +++ b/src/components/api/whisper/indexedDB.js @@ -1,169 +1,168 @@ -// HACK: moving global variables from index.html to here for loadRemote - -let dbVersion = 1 -let dbName = 'whisper.ggerganov.com'; -let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB - -// fetch a remote file from remote URL using the Fetch API -async function fetchRemote(url, cbProgress, cbPrint) { - cbPrint('fetchRemote: downloading with fetch()...'); - - const response = await fetch( - url, - { - method: 'GET', - } - ); +// Robust, silent caching + URL fallbacks for Whisper model binaries + +var dbVersion = 1; +var dbName = 'whisper.ggerganov.com'; +var _indexedDB = + (window.indexedDB || + window.mozIndexedDB || + window.webkitIndexedDB || + window.msIndexedDB); + +// --- tiny helpers ----------------------------------------------------------- + +function log(cbPrint, msg){ try { cbPrint && cbPrint(msg); } catch(_){} } + +function isValidModelBytes(buf, contentType) { + if (!buf) return false; + // Reject obvious HTML/JSON or tiny payloads (typical CRA index.html ~2–5 KB) + if (buf.byteLength < 4096) return false; + if (!contentType) return true; + const ct = String(contentType).toLowerCase(); + if (ct.includes('text/html')) return false; + if (ct.includes('json')) return false; + return true; +} - if (!response.ok) { - cbPrint('fetchRemote: failed to fetch ' + url); - return; +// Fetch bytes from a URL (no prompt, no cache), return Uint8Array 
or null +async function fetchBytes(url, cbProgress, cbPrint) { + log(cbPrint, 'fetchRemote: GET ' + url); + const res = await fetch(url, { cache: 'no-cache', method: 'GET' }); + if (!res.ok) { + log(cbPrint, 'fetchRemote: HTTP ' + res.status + ' @ ' + url); + return null; + } + const totalHdr = res.headers.get('content-length'); + const total = totalHdr ? parseInt(totalHdr, 10) : undefined; + + // If stream not available, do arrayBuffer directly + if (!res.body || !res.body.getReader) { + const buf = await res.arrayBuffer(); + if (!isValidModelBytes(buf, res.headers.get('content-type'))) return null; + cbProgress && cbProgress(1); + return new Uint8Array(buf); + } + + const reader = res.body.getReader(); + let received = 0, chunks = [], lastReport = -1; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + chunks.push(value); + received += value.length; + if (total) { + const frac = received / total; + cbProgress && cbProgress(frac); + const bucket = Math.round(frac * 10); + if (bucket !== lastReport) { + log(cbPrint, 'fetchRemote: fetching ' + (bucket * 10) + '% ...'); + lastReport = bucket; + } } + } + + // Concat + const out = new Uint8Array(received); + let pos = 0; + for (let i = 0; i < chunks.length; i++) { + out.set(chunks[i], pos); + pos += chunks[i].length; + } + + // Final sanity + const ct = res.headers.get('content-type'); + if (!isValidModelBytes(out, ct)) return null; + if (!total) { cbProgress && cbProgress(1); } + + return out; +} - const contentLength = response.headers.get('content-length'); - const total = parseInt(contentLength, 10); - const reader = response.body.getReader(); - - var chunks = []; - var receivedLength = 0; - var progressLast = -1; +// IDB helpers +function idbGet(db, key) { + return new Promise(resolve => { + try { + const tx = db.transaction(['models'], 'readonly'); + const os = tx.objectStore('models'); + const rq = os.get(key); + rq.onsuccess = () => resolve(rq.result || null); + rq.onerror = () => resolve(null); + } catch (_) { resolve(null); } + }); +} - while (true) { - const { done, value } = await reader.read(); +function idbPut(db, key, bytes, cbPrint) { + return new Promise(resolve => { + try { + const tx = db.transaction(['models'], 'readwrite'); + const os = tx.objectStore('models'); + const rq = os.put(bytes, key); + rq.onsuccess = () => { log(cbPrint, 'loadRemote: stored in IDB: ' + key); resolve(true); }; + rq.onerror = () => { log(cbPrint, 'loadRemote: IDB put failed (non-fatal)'); resolve(false); }; + } catch (e) { + log(cbPrint, 'loadRemote: IDB exception: ' + e); + resolve(false); + } + }); +} - if (done) { - break; - } +function openDB(cbPrint) { + return new Promise(resolve => { + const rq = _indexedDB.open(dbName, dbVersion); + rq.onupgradeneeded = (ev) => { + const db = ev.target.result; + if (!db.objectStoreNames.contains('models')) { + db.createObjectStore('models', { autoIncrement: false }); + log(cbPrint, 'loadRemote: created IDB store'); + } + }; + rq.onsuccess = () => resolve(rq.result); + rq.onerror = () => resolve(null); + rq.onblocked = () => resolve(null); + rq.onabort = () => resolve(null); + }); +} - chunks.push(value); - receivedLength += value.length; +// --- PUBLIC: try a list of URLs, use cache per-URL, stop on first good one --- +export async function loadRemoteWithFallbacks(urls, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) { + try { + if (navigator.storage?.estimate) { + const est = await navigator.storage.estimate(); + log(cbPrint, 'loadRemote: storage quota: ' + 
est.quota + ' bytes'); + log(cbPrint, 'loadRemote: storage usage: ' + est.usage + ' bytes'); + } + } catch (_) {} - if (contentLength) { - cbProgress(receivedLength/total); + // De-dup & filter falsy + const list = Array.from(new Set((urls || []).filter(Boolean))); + if (!list.length) { cbCancel && cbCancel(); return; } - var progressCur = Math.round((receivedLength / total) * 10); - if (progressCur != progressLast) { - cbPrint('fetchRemote: fetching ' + 10*progressCur + '% ...'); - progressLast = progressCur; - } - } - } + const db = await openDB(cbPrint); - var position = 0; - var chunksAll = new Uint8Array(receivedLength); + for (let i = 0; i < list.length; i++) { + const url = list[i]; - for (var chunk of chunks) { - chunksAll.set(chunk, position); - position += chunk.length; + // 1) Cache hit? + if (db) { + const cached = await idbGet(db, url); + if (cached && cached.byteLength > 4096) { + log(cbPrint, `loadRemote: cache hit for ${url}`); + cbReady && cbReady(dst, cached instanceof Uint8Array ? cached : new Uint8Array(cached)); + return; + } } - return chunksAll; -} - -// load remote data -// - check if the data is already in the IndexedDB -// - if not, fetch it from the remote URL and store it in the IndexedDB -export function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) { - if (!navigator.storage || !navigator.storage.estimate) { - cbPrint('loadRemote: navigator.storage.estimate() is not supported'); + log(cbPrint, `loadRemote: cache miss; downloading ~${size_mb} MB`); + const bytes = await fetchBytes(url, cbProgress, cbPrint); + if (bytes && bytes.byteLength > 4096) { + if (db) { await idbPut(db, url, bytes, cbPrint); } + cbReady && cbReady(dst, bytes); + return; } else { - // query the storage quota and print it - navigator.storage.estimate().then(function (estimate) { - cbPrint('loadRemote: storage quota: ' + estimate.quota + ' bytes'); - cbPrint('loadRemote: storage usage: ' + estimate.usage + ' bytes'); - }); + log(cbPrint, `fetchWithFallbacks: "${url}" did not look like a model, trying next...`); } + } - // check if the data is already in the IndexedDB - var rq = indexedDB.open(dbName, dbVersion); - - rq.onupgradeneeded = function (event) { - var db = event.target.result; - if (db.version == 1) { - var os = db.createObjectStore('models', { autoIncrement: false }); - cbPrint('loadRemote: created IndexedDB ' + db.name + ' version ' + db.version); - } else { - // clear the database - var os = event.currentTarget.transaction.objectStore('models'); - os.clear(); - cbPrint('loadRemote: cleared IndexedDB ' + db.name + ' version ' + db.version); - } - }; - - rq.onsuccess = function (event) { - var db = event.target.result; - var tx = db.transaction(['models'], 'readonly'); - var os = tx.objectStore('models'); - var rq = os.get(url); - - rq.onsuccess = function (event) { - if (rq.result) { - cbPrint('loadRemote: "' + url + '" is already in the IndexedDB'); - cbReady(dst, rq.result); - } else { - // data is not in the IndexedDB - cbPrint('loadRemote: "' + url + '" is not in the IndexedDB'); - - // alert and ask the user to confirm - if (!confirm( - 'You are about to download ' + size_mb + ' MB of data.\n' + - 'The model data will be cached in the browser for future use.\n\n' + - 'Press OK to continue.')) { - cbCancel(); - return; - } - - fetchRemote(url, cbProgress, cbPrint).then(function (data) { - if (data) { - // store the data in the IndexedDB - var rq = indexedDB.open(dbName, dbVersion); - rq.onsuccess = function (event) { - var db = event.target.result; - 
var tx = db.transaction(['models'], 'readwrite'); - var os = tx.objectStore('models'); - - var rq = null; - try { - var rq = os.put(data, url); - } catch (e) { - cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e); - cbCancel(); - return; - } - - rq.onsuccess = function (event) { - cbPrint('loadRemote: "' + url + '" stored in the IndexedDB'); - cbReady(dst, data); - }; - - rq.onerror = function (event) { - cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB'); - cbCancel(); - }; - }; - } - }); - } - }; - - rq.onerror = function (event) { - cbPrint('loadRemote: failed to get data from the IndexedDB'); - cbCancel(); - }; - }; - - rq.onerror = function (event) { - cbPrint('loadRemote: failed to open IndexedDB'); - cbCancel(); - }; - - rq.onblocked = function (event) { - cbPrint('loadRemote: failed to open IndexedDB: blocked'); - cbCancel(); - }; - - rq.onabort = function (event) { - cbPrint('loadRemote: failed to open IndexedDB: abort'); - cbCancel(); - }; + // All failed + log(cbPrint, 'loadRemote: all fetch attempts failed'); + cbCancel && cbCancel(); } diff --git a/src/components/api/whisper/whisperRecognizer.tsx b/src/components/api/whisper/whisperRecognizer.tsx index ad61e21f..7711545a 100644 --- a/src/components/api/whisper/whisperRecognizer.tsx +++ b/src/components/api/whisper/whisperRecognizer.tsx @@ -1,281 +1,260 @@ import { Recognizer } from '../recognizer'; import { TranscriptBlock } from '../../../react-redux&middleware/redux/types/TranscriptTypes'; import makeWhisper from './libstream'; -import { loadRemote } from './indexedDB' +import { loadRemoteWithFallbacks } from './indexedDB'; import { SIPOAudioBuffer } from './sipo-audio-buffer'; -import RecordRTC, {StereoAudioRecorder} from 'recordrtc'; +import RecordRTC, { StereoAudioRecorder } from 'recordrtc'; import WavDecoder from 'wav-decoder'; -// https://stackoverflow.com/questions/4554252/typed-arrays-in-gecko-2-float32array-concatenation-and-expansion -function Float32Concat(first, second) { - const firstLength = first.length; - const result = new Float32Array(firstLength + second.length); +type ModelKey = 'tiny-en-q5_1' | 'tiny-q5_1'; - result.set(first); - result.set(second, firstLength); +// concat helper +function Float32Concat(a: Float32Array, b: Float32Array) { + const out = new Float32Array(a.length + b.length); + out.set(a, 0); + out.set(b, a.length); + return out; +} - return result; +// Build all plausible URLs where the model might live +function resolveModelUrls(filename: string): string[] { + const pub = (process.env.PUBLIC_URL || '').replace(/\/$/, ''); + const origin = window.location.origin; + const curPath = window.location.pathname.replace(/\/$/, ''); + return [ + `${pub}/models/${filename}`, + `${origin}${pub}/models/${filename}`, + `/models/${filename}`, + `${origin}/models/${filename}`, + `${curPath}/models/${filename}`, + ]; } -/** - * Wrapper for Web Assembly implementation of whisper.cpp - */ export class WhisperRecognizer implements Recognizer { - - // - // Audio params - // - private kSampleRate = 16000; - // Length of the suffix of the captured audio that whisper processes each time (in seconds) - private kWindowLength = 5; - // Number of samples in each audio chunk the audio worklet receives - private kChunkLength = 128; - - /** - * Audio context and buffer used to capture speech - */ - private context: AudioContext; - private audio_buffer: SIPOAudioBuffer; - private recorder?: RecordRTC; - - /** - * Instance of the whisper wasm module, and its variables - */ 
- private whisper: any = null; - private model_name: string = ""; - private model_index: Number = -1; - private language: string = "en"; - private num_threads: number; - - private transcribed_callback: ((newFinalBlocks: Array, newInProgressBlock: TranscriptBlock) => void) | null = null; - - /** - * Creates an Whisper recognizer instance that listens to the default microphone - * and expects speech in the given language - * @param audioSource Not implemented yet - * @param language Not implemented yet - * @parem num_threads Number of worker threads that whisper uses - */ - constructor(audioSource: any, language: string, num_threads: number = 4, model: string = "tiny-en-q5_1") { - this.num_threads = num_threads; - this.language = language; - this.model_name = model; - - const num_chunks = this.kWindowLength * this.kSampleRate / this.kChunkLength; - this.audio_buffer = new SIPOAudioBuffer(num_chunks, this.kChunkLength); - - this.context = new AudioContext({ - sampleRate: this.kSampleRate, - }); - } - - private print = (text: string) => { - if (this.transcribed_callback != null) { - let block = new TranscriptBlock(); - block.text = text; - this.transcribed_callback([block], new TranscriptBlock()); - } - } - - private printDebug = (text: string) => { - console.log(text); - } - - private storeFS(fname, buf) { - // write to WASM file using FS_createDataFile - // if the file exists, delete it - try { - this.whisper.FS_unlink(fname); - } catch (e) { - // ignore - } - this.whisper.FS_createDataFile("/", fname, buf, true, true); - this.printDebug('storeFS: stored model: ' + fname + ' size: ' + buf.length); - } - - - /** - * Async load the WASM module, ggml model, and Audio Worklet needed for whisper to work - */ - public async loadWhisper() { - // Load wasm and ggml - this.whisper = await makeWhisper({ - print: this.print, - printErr: this.printDebug, - setStatus: function(text) { - this.printErr('js: ' + text); - }, - monitorRunDependencies: function(left) { - } - }) - await this.load_model(this.model_name); - this.model_index = this.whisper.init('whisper.bin'); - - console.log("Whisper: Done instantiating whisper", this.whisper, this.model_index); - - // Set up audio source - let mic_stream = await navigator.mediaDevices.getUserMedia({audio: true, video: false}); - // let source = this.context.createMediaStreamSource(mic_stream); - - let last_suffix = new Float32Array(0); - this.recorder = new RecordRTC(mic_stream, { - type: 'audio', - mimeType: 'audio/wav', - desiredSampRate: this.kSampleRate, - timeSlice: 250, - ondataavailable: async (blob: Blob) => { - // Convert wav chunk to PCM - const array_buffer = await blob.arrayBuffer(); - const {channelData} = await WavDecoder.decode(array_buffer); - // Should be 16k, float32, stereo pcm data - // Just get 1 channel - let pcm_data = channelData[0]; - - // Prepend previous suffix and update with current suffix - pcm_data = Float32Concat(last_suffix, pcm_data); - last_suffix = pcm_data.slice(-(pcm_data.length % 128)) - - // Feed process_recorder_message audio in 128 sample chunks - for (let i = 0; i < pcm_data.length - 127; i+= 128) { - const audio_chunk = pcm_data.subarray(i, i + 128) - - this.process_recorder_message(audio_chunk); - } - - }, - recorderType: StereoAudioRecorder, - numberOfAudioChannels: 1, - }); - - this.recorder.startRecording(); - console.log("Whisper: Done setting up audio context"); - } - - private async load_model(model: string) { - let urls = { - 'tiny.en': 'ggml-model-whisper-tiny.en.bin', - 'tiny': 'ggml-model-whisper-tiny.bin', 
- 'base.en': 'ggml-model-whisper-base.en.bin', - 'base': 'ggml-model-whisper-base.bin', - 'small.en': 'ggml-model-whisper-small.en.bin', - 'small': 'ggml-model-whisper-small.bin', - - 'tiny-en-q5_1': 'ggml-model-whisper-tiny.en-q5_1.bin', - 'tiny-q5_1': 'ggml-model-whisper-tiny-q5_1.bin', - 'base-en-q5_1': 'ggml-model-whisper-base.en-q5_1.bin', - 'base-q5_1': 'ggml-model-whisper-base-q5_1.bin', - 'small-en-q5_1': 'ggml-model-whisper-small.en-q5_1.bin', - 'small-q5_1': 'ggml-model-whisper-small-q5_1.bin', - 'medium-en-q5_0':'ggml-model-whisper-medium.en-q5_0.bin', - 'medium-q5_0': 'ggml-model-whisper-medium-q5_0.bin', - 'large-q5_0': 'ggml-model-whisper-large-q5_0.bin', - }; - let sizes = { - 'tiny.en': 75, - 'tiny': 75, - 'base.en': 142, - 'base': 142, - 'small.en': 466, - 'small': 466, - - 'tiny-en-q5_1': 31, - 'tiny-q5_1': 31, - 'base-en-q5_1': 57, - 'base-q5_1': 57, - 'small-en-q5_1': 182, - 'small-q5_1': 182, - 'medium-en-q5_0': 515, - 'medium-q5_0': 515, - 'large-q5_0': 1030, - }; - - let url = process.env.PUBLIC_URL + "/models/" + urls[model]; - let dst = 'whisper.bin'; - let size_mb = sizes[model]; - - // HACK: turn loadRemote into a promise so that we can chain .then - let that = this; - return new Promise((resolve, reject) => { - loadRemote(url, - dst, - size_mb, - (text) => {}, - (fname, buf) => { - that.storeFS(fname, buf); - resolve(); - }, - () => { - reject(); - }, - that.printDebug, - ); - }) + private kSampleRate = 16000; + private kWindowLength = 5; // sec + private kChunkLength = 128; // samples + + private context: AudioContext; + private audio_buffer: SIPOAudioBuffer; + private recorder?: RecordRTC; + + private whisper: any = null; + private model_name: ModelKey = 'tiny-en-q5_1'; + private model_index = -1; + + /** + * language semantics: + * - 'zh' -> assume Chinese speech; translate -> English + * - 'en' -> assume English; transcribe -> English (no translation flag) + * - 'auto'-> enable detect_language in wrapper; transcribe in detected language + * - others (e.g., 'ja','ko',...) -> transcribe in that language + */ + private language = 'en'; + private num_threads: number; + + private transcribed_callback: + ((finalBlocks: Array, inProg: TranscriptBlock) => void) | null = null; + + constructor( + _audioSource: any, + language: string, + num_threads: number = 4, + model: ModelKey = 'tiny-en-q5_1' + ) { + this.num_threads = num_threads; + this.language = language || 'en'; + this.model_name = model; + + const num_chunks = (this.kWindowLength * this.kSampleRate) / this.kChunkLength; + this.audio_buffer = new SIPOAudioBuffer(num_chunks, this.kChunkLength); + + this.context = new AudioContext({ sampleRate: this.kSampleRate }); + + // Warn if English-only model is used for non-English scenarios + if (this.model_name === 'tiny-en-q5_1' && this.language !== 'en') { + console.warn( + '[Whisper] Non-English scenario requested with tiny-en model. ' + + 'Use the multilingual model (tiny-q5_1) for best results.' 
+ ); } - - /** - * Helper method that stores audio chunks from the raw recorder in buffer - * @param audio_chunk Float32Array containing an audio chunk - */ - private process_recorder_message(audio_chunk: Float32Array) { - this.audio_buffer.push(audio_chunk); - if (this.audio_buffer.isFull()) { - this.whisper.set_audio(this.model_index, this.audio_buffer.getAll()); - this.audio_buffer.clear(); - } + } + + /** Centralized application of language/translate/task flags to the WASM wrapper */ + private applyLanguageSettings = () => { + if (!this.whisper || this.model_index < 0) return; + + // Decide behavior + const isAuto = this.language === 'auto'; + const shouldTranslateToEnglish = (this.language === 'zh'); // <β€” core requirement + + try { + // Auto-detect if 'auto', otherwise pin the language + this.whisper.set_detect_language?.(this.model_index, isAuto ? 1 : 0); + if (!isAuto) { + this.whisper.set_language?.(this.model_index, this.language); + } + + // Translate only for Chinese -> English; otherwise transcribe in source language + this.whisper.set_translate?.(this.model_index, shouldTranslateToEnglish ? 1 : 0); + + // If wrapper exposes task, 0=transcribe, 1=translate (common whisper.cpp bindings) + this.whisper.set_task?.(this.model_index, shouldTranslateToEnglish ? 1 : 0); + + // Optional: threads if supported by wrapper + this.whisper.set_threads?.(this.model_index, this.num_threads); + } catch (e) { + // optional methods may not exist in some builds + // tslint:disable-next-line:no-console + console.debug('[Whisper] applyLanguageSettings (some setters may be missing):', e); } - - private process_analyzer_result(features: any) { - - } - - /** - * Makes the Whisper recognizer start transcribing speech, if not already started - * Throws exception if recognizer fails to start - */ - start() { - console.log("trying to start whisper"); - let that = this; - if (this.whisper == null || this.model_index == -1) { - this.loadWhisper().then(() => { - that.whisper.setStatus(""); - that.context.resume(); - }) - } else { - this.whisper.setStatus(""); - this.context.resume(); - } + }; + + /** Allow runtime language change (e.g., UI dropdown) */ + public setLanguage = (lang: string) => { + if (!lang) return; + // Normalize common variants to 2-letter codes whisper.cpp expects + const base = (lang.toLowerCase().match(/[a-z]+/) || ['en'])[0]; + // keep 'zh' and 'en' explicit; allow 'auto' + const code = base === 'zh' ? 'zh' : base === 'en' ? 'en' : (base === 'auto' ? 
'auto' : base); + + this.language = code; + console.log('[Whisper] setLanguage ->', this.language); + + this.applyLanguageSettings(); + }; + + private print = (text: string) => { + if (this.transcribed_callback) { + const block = new TranscriptBlock(); + block.text = text; + this.transcribed_callback([block], new TranscriptBlock()); } - - /** - * Makes the Whisper recognizer stop transcribing speech asynchronously - * Throws exception if recognizer fails to stop - */ - stop() { - if (this.whisper == null || this.model_index === -1) { + }; + private printDebug = (t: string) => console.log(t); + + private storeFS(fname: string, buf: Uint8Array) { + try { this.whisper.FS_unlink(fname); } catch {} + this.whisper.FS_createDataFile('/', fname, buf, true, true); + this.printDebug('storeFS: stored model: ' + fname + ' size: ' + buf.length); + } + + private async load_model(model: ModelKey) { + // Be resilient to common filename variants + const fileCandidates: Record = { + 'tiny-en-q5_1': [ + 'ggml-model-whisper-tiny.en-q5_1.bin', + 'ggml-tiny.en-q5_1.bin' + ], + 'tiny-q5_1': [ + 'ggml-model-whisper-tiny-q5_1.bin', + 'ggml-tiny-q5_1.bin' + ], + }; + const sizes: Record = { 'tiny-en-q5_1': 31, 'tiny-q5_1': 31 }; + + const names = fileCandidates[model] || []; + const candidates = names.flatMap(n => resolveModelUrls(n)); + const dst = 'whisper.bin'; + const size_mb = sizes[model]; + + return new Promise((resolve, reject) => { + loadRemoteWithFallbacks( + candidates, + dst, + size_mb, + (_p) => {}, + (fname, buf) => { + if (!(buf && buf.length > 4096)) { + this.printDebug('load_model: fetched buffer too small (' + (buf ? buf.length : 0) + ')'); + reject('model-bytes-too-small'); return; + } + this.storeFS(fname, buf); + resolve(); + }, + () => reject('load-cancel'), + this.printDebug + ); + }); + } + + public async loadWhisper() { + this.whisper = await makeWhisper({ + print: this.print, + printErr: this.printDebug, + setStatus: (t: string) => this.printDebug('js: ' + t), + monitorRunDependencies: function (_left: number) {} + }); + + await this.load_model(this.model_name); + + // Initialize the model in WASM FS + this.model_index = this.whisper.init('whisper.bin'); + + // Apply language/translate/task config + this.applyLanguageSettings(); + + console.log('[Whisper] initialized with language =', this.language, 'model =', this.model_name); + + // Microphone + const mic = await navigator.mediaDevices.getUserMedia({ audio: true, video: false }); + + let last_suffix = new Float32Array(0); + this.recorder = new RecordRTC(mic, { + type: 'audio', + mimeType: 'audio/wav', + desiredSampRate: this.kSampleRate, + timeSlice: 250, + ondataavailable: async (blob: Blob) => { + const ab = await blob.arrayBuffer(); + const decoded = await WavDecoder.decode(ab); + let pcm = decoded.channelData[0]; // mono + + pcm = Float32Concat(last_suffix, pcm); + last_suffix = pcm.slice(-(pcm.length % this.kChunkLength)); + + for (let i = 0; i <= pcm.length - this.kChunkLength; i += this.kChunkLength) { + this.process_recorder_message(pcm.subarray(i, i + this.kChunkLength)); } - this.whisper.set_status("paused"); - this.context.suspend(); - this.recorder?.stopRecording(); + }, + recorderType: StereoAudioRecorder, + numberOfAudioChannels: 1, + }); + + this.recorder.startRecording(); + this.printDebug('Whisper: audio recorder started'); + } + + private process_recorder_message(audio_chunk: Float32Array) { + this.audio_buffer.push(audio_chunk); + if (this.audio_buffer.isFull()) { + this.whisper.set_audio(this.model_index, 
this.audio_buffer.getAll()); + this.audio_buffer.clear(); } - - /** - * Subscribe a callback function to the transcript update event, which is usually triggered - * when the recognizer has processed more speech or finalized some in-progress part - * @param callback A callback function called with updates to the transcript - */ - onTranscribed(callback: (newFinalBlocks: Array, newInProgressBlock: TranscriptBlock) => void) { - this.transcribed_callback = callback; + } + + start() { + this.printDebug('trying to start whisper'); + if (!this.whisper || this.model_index === -1) { + this.loadWhisper() + .then(() => { this.whisper.setStatus?.(''); this.context.resume(); }) + .catch(e => this.printDebug('loadWhisper failed: ' + e)); + } else { + this.whisper.setStatus?.(''); + this.context.resume(); } - - /** - * Subscribe a callback function to the error event, which is triggered - * when the recognizer has encountered an error - * @param callback A callback function called with the error object when the event is triggered - */ - onError(callback: (e: Error) => void) { - - } - + } + + stop() { + if (!this.whisper || this.model_index === -1) return; + try { this.whisper.set_status?.('paused'); } catch {} + this.context.suspend(); + this.recorder?.stopRecording(); + } + + onTranscribed(cb: (finalBlocks: Array, inProg: TranscriptBlock) => void) { + this.transcribed_callback = cb; + } + onError(_cb: (e: Error) => void) {} } - diff --git a/src/components/navbars/topbar/api/WhisperDropdown.tsx b/src/components/navbars/topbar/api/WhisperDropdown.tsx index c18229b6..6850da1f 100644 --- a/src/components/navbars/topbar/api/WhisperDropdown.tsx +++ b/src/components/navbars/topbar/api/WhisperDropdown.tsx @@ -1,13 +1,13 @@ +// src/components/navbars/topbar/api/WhisperDropdown.tsx import React, { useState } from 'react'; import { useDispatch, useSelector } from 'react-redux'; -// Pull UI pieces from the same aggregator used elsewhere import { IconButton, Tooltip, Menu, MenuItem, - SettingsIcon, // this should be re-exported by your muiImports like other icons + SettingsIcon, } from '../../../../muiImports'; import { selectSelectedModel } from @@ -15,64 +15,70 @@ import { selectSelectedModel } from type Props = { onPicked?: () => void }; -type ModelKey = 'tiny' | 'base'; - -// Helper to normalize current selection into 'tiny' | 'base' | undefined -function normalizeSelected(selected: unknown): ModelKey | undefined { - if (typeof selected === 'string') { - return selected === 'tiny' || selected === 'base' ? selected : undefined; - } - if (selected && typeof selected === 'object' && 'key' in (selected as any)) { - const k = (selected as any).key; - return k === 'tiny' || k === 'base' ? k : undefined; - } - return undefined; -} +// Only keep tiny models (keep legacy 'tiny' for backward-compat) +type ModelKey = 'tiny-en-q5_1' | 'tiny-q5_1' | 'tiny'; export default function WhisperDropdown({ onPicked }: Props) { const dispatch = useDispatch(); - const selected = useSelector(selectSelectedModel); - const selectedKey = normalizeSelected(selected); const [anchorEl, setAnchorEl] = useState(null); - const open = Boolean(anchorEl); + const isShown = Boolean(anchorEl); + + const selected = useSelector(selectSelectedModel); + const selectedKey: ModelKey = React.useMemo(() => { + const allowed = new Set(['tiny-en-q5_1', 'tiny-q5_1']); + if (typeof selected === 'string') { + if (selected === 'tiny') return 'tiny-en-q5_1'; + return (allowed.has(selected) ? 
selected : 'tiny-en-q5_1') as ModelKey; + } + if (selected && typeof selected === 'object') { + const k = (selected as any).model_key || (selected as any).key; + if (k === 'tiny') return 'tiny-en-q5_1'; + return (allowed.has(k) ? k : 'tiny-en-q5_1') as ModelKey; + } + return 'tiny-en-q5_1'; + }, [selected]); - const handleOpen = (e: React.MouseEvent) => setAnchorEl(e.currentTarget); + const showPopup = (e: React.MouseEvent) => setAnchorEl(e.currentTarget); const handleClose = () => setAnchorEl(null); const pick = (which: ModelKey) => { - // Keep the existing flags your loader is watching - sessionStorage.setItem('isDownloadTiny', String(which === 'tiny')); - sessionStorage.setItem('isDownloadBase', String(which === 'base')); + // Maintain legacy session flags for any old code paths + sessionStorage.setItem('isDownloadTiny', 'true'); + sessionStorage.setItem('isDownloadBase', 'false'); - // Update Redux – if you have a real action creator, use it here dispatch({ type: 'SET_SELECTED_MODEL', payload: which as any }); - - handleClose(); onPicked?.(); + handleClose(); }; return ( <> - {/* Right-aligned gear like other providers */} - - + + - pick('tiny')}> - TINY (75 MB) + pick('tiny-en-q5_1')} + > + TINY (EN, q5_1) ~31 MB - pick('base')}> - BASE (145 MB) + pick('tiny-q5_1')} + > + TINY (Multi, q5_1) ~31 MB diff --git a/src/components/navbars/topbar/api/WhisperSettings.tsx b/src/components/navbars/topbar/api/WhisperSettings.tsx index 19e5f518..5ef129b9 100644 --- a/src/components/navbars/topbar/api/WhisperSettings.tsx +++ b/src/components/navbars/topbar/api/WhisperSettings.tsx @@ -15,7 +15,7 @@ import { import { selectSelectedModel } from '../../../../react-redux&middleware/redux/reducers/modelSelectionReducers'; -type ModelKey = 'tiny' | 'base'; +type ModelKey = 'tiny-en-q5_1' | 'tiny-q5_1' | 'tiny'; export default function WhisperSettings() { const [anchorEl, setAnchorEl] = React.useState(null); @@ -24,36 +24,33 @@ export default function WhisperSettings() { const dispatch = useDispatch(); const selected = useSelector(selectSelectedModel); - // Normalize current selection to 'tiny' | 'base' | undefined - const selectedKey: ModelKey | undefined = React.useMemo(() => { + const selectedKey: ModelKey = React.useMemo(() => { + const allowed = new Set(['tiny-en-q5_1', 'tiny-q5_1']); if (typeof selected === 'string') { - return selected === 'tiny' || selected === 'base' ? selected : undefined; + if (selected === 'tiny') return 'tiny-en-q5_1'; + return allowed.has(selected) ? (selected as ModelKey) : 'tiny-en-q5_1'; } - if (selected && typeof selected === 'object' && 'key' in (selected as any)) { - const k = (selected as any).key; - return k === 'tiny' || k === 'base' ? k : undefined; + if (selected && typeof selected === 'object') { + const k = (selected as any).model_key || (selected as any).key; + if (k === 'tiny') return 'tiny-en-q5_1'; + return allowed.has(k) ? 
(k as ModelKey) : 'tiny-en-q5_1'; } - return undefined; + return 'tiny-en-q5_1'; }, [selected]); const showPopup = (e: React.MouseEvent) => setAnchorEl(e.currentTarget); const closePopup = () => setAnchorEl(null); const pick = (which: ModelKey) => { - // Keep flags if your loader watches them - sessionStorage.setItem('isDownloadTiny', (which === 'tiny').toString()); - sessionStorage.setItem('isDownloadBase', (which === 'base').toString()); - - // Update redux + sessionStorage.setItem('isDownloadTiny', 'true'); + sessionStorage.setItem('isDownloadBase', 'false'); dispatch({ type: 'SET_SELECTED_MODEL', payload: which as any }); - closePopup(); }; return ( <> - {/* Right-side gear (same pattern as Azure/ScribeAR Server) */} - + @@ -61,43 +58,33 @@ export default function WhisperSettings() { anchorEl={anchorEl} open={isShown} onClose={closePopup} - PaperProps={{ - elevation: 0, - sx: { - position: 'unset', - ml: '25vw', - width: '50vw', - mt: '25vh', - height: '50vh', - filter: 'drop-shadow(0px 2px 8px rgba(0,0,0,0.32))', - }, - }} - transformOrigin={{ horizontal: 'center', vertical: 'top' }} - anchorOrigin={{ horizontal: 'center', vertical: 'bottom' }} + anchorOrigin={{ horizontal: 'right', vertical: 'bottom' }} + transformOrigin={{ horizontal: 'right', vertical: 'top' }} > - - - - - - - - - + + + + + + + + + + ); diff --git a/src/components/navbars/topbar/api/pickApi.tsx b/src/components/navbars/topbar/api/pickApi.tsx index bba360f8..9577f816 100644 --- a/src/components/navbars/topbar/api/pickApi.tsx +++ b/src/components/navbars/topbar/api/pickApi.tsx @@ -28,7 +28,7 @@ import { ListItemText, ThemeProvider, createTheme, - Chip, // from muiImports barrel + Chip, } from '../../../../muiImports'; import { ListItem } from '@mui/material'; @@ -83,15 +83,36 @@ const IconStatus = (currentApi: any) => { } }; -function getSelectedModelKey(selected: unknown): 'tiny' | 'base' | undefined { +/** Normalize various saved shapes into one of our TWO tiny model keys; default to tiny-en-q5_1 */ +function getSelectedModelKey( + selected: unknown +): 'tiny-en-q5_1' | 'tiny-q5_1' { + const allowed = new Set(['tiny-en-q5_1', 'tiny-q5_1']); + if (typeof selected === 'string') { - return selected === 'tiny' || selected === 'base' ? selected : undefined; + if (selected === 'tiny') return 'tiny-en-q5_1'; + // if someone still has 'base' persisted, fall back to tiny-en + if (selected === 'base') return 'tiny-en-q5_1'; + return allowed.has(selected) ? (selected as any) : 'tiny-en-q5_1'; + } + + if (selected && typeof selected === 'object') { + const k = (selected as any).model_key ?? (selected as any).key; + if (k === 'tiny') return 'tiny-en-q5_1'; + if (k === 'base') return 'tiny-en-q5_1'; + return allowed.has(k) ? (k as any) : 'tiny-en-q5_1'; } - if (selected && typeof selected === 'object' && 'key' in (selected as any)) { - const k = (selected as any).key; - return k === 'tiny' || k === 'base' ? 
k : undefined; + + return 'tiny-en-q5_1'; +} + +/** Short label for topbar chip */ +function getSelectedModelLabel(k?: string) { + switch (k) { + case 'tiny-en-q5_1': return 'TINY-EN'; + case 'tiny-q5_1': return 'TINY-MULTI'; + default: return 'TINY-EN'; } - return undefined; } export default function PickApi() { @@ -261,7 +282,7 @@ export default function PickApi() { - + @@ -306,7 +327,7 @@ export default function PickApi() { )} diff --git a/src/ml/inference.js b/src/ml/inference.js index d3a824e7..5cb42117 100644 --- a/src/ml/inference.js +++ b/src/ml/inference.js @@ -3,57 +3,109 @@ /*global BigInt64Array */ import { loadTokenizer } from './bert_tokenizer.ts'; -// import * as wasmFeatureDetect from 'wasm-feature-detect'; - -//Setup onnxruntime +// Setup onnxruntime const ort = require('onnxruntime-web'); -//requires Cross-Origin-*-policy headers https://web.dev/coop-coep/ -/** -const simdResolver = wasmFeatureDetect.simd().then(simdSupported => { - console.log("simd is supported? "+ simdSupported); - if (simdSupported) { - ort.env.wasm.numThreads = 3; - ort.env.wasm.simd = true; - } else { - ort.env.wasm.numThreads = 1; - ort.env.wasm.simd = false; - } -}); -*/ +// --- CONFIG ----------------------------------------------------------------- + +// We resolve URLs against PUBLIC_URL (prod) or root (dev) +function resolvePublicUrl(path) { + const base = + (process.env.NODE_ENV === 'development' + ? '' + : (process.env.PUBLIC_URL || '') + ).replace(/\/$/, ''); + return `${base}${path.startsWith('/') ? '' : '/'}${path}`; +} -const options = { - executionProviders: ['wasm'], - graphOptimizationLevel: 'all' +// Place these files in public/models/onnx/ +const LM_MODEL_URL = resolvePublicUrl('/models/onnx/xtremedistill-go-emotion-int8.onnx'); +// Note: parentheses are valid in URLs; we encode to be safe. +const GRU_MODEL_URL = resolvePublicUrl('/models/onnx/' + encodeURI('gru_embedder(1,40,80).onnx')); + +const ORT_OPTIONS = { + executionProviders: ['wasm'], + graphOptimizationLevel: 'all', }; -var downLoadingModel = true; -const model = "./xtremedistill-go-emotion-int8.onnx"; -// const gruModel = "./gru_embedder(1,21,80).onnx"; -const gruModel = "./gru_embedder(1,40,80).onnx"; -const gruSession = ort.InferenceSession.create(gruModel, options); +// --- INTERNAL STATE --------------------------------------------------------- + +let downloadingModel = false; + +let lmSessionPromise = null; // Promise +let gruSessionPromise = null; // Promise + +let lmDisabled = false; +let gruDisabled = false; + +// tokenizer promise (unchanged) +const tokenizer = loadTokenizer(); + +// --- HELPERS ---------------------------------------------------------------- + +async function fetchModelBytes(url) { + // Fetch to ArrayBuffer first so we can reject HTML/JSON and tiny files + const res = await fetch(url, { cache: 'no-cache' }); + if (!res.ok) throw new Error(`HTTP ${res.status} @ ${url}`); + const ct = (res.headers.get('content-type') || '').toLowerCase(); + const buf = await res.arrayBuffer(); + + // Reject clearly-wrong payloads (HTML/JSON or super tiny files) + if (buf.byteLength < 2048 || ct.includes('text/html') || ct.includes('json')) { + throw new Error(`Invalid model payload (type=${ct}, size=${buf.byteLength}) @ ${url}`); + } + return buf; +} + +function sortResult(a, b) { + if (a[1] === b[1]) return 0; + return (a[1] < b[1]) ? 
1 : -1; +} + +function sigmoid(t) { + return 1 / (1 + Math.exp(-t)); +} + +// Build BERT inputs (BigInt) +function create_model_input(encoded) { + let input_ids = new Array(encoded.length + 2); + let attention_mask = new Array(encoded.length + 2); + let token_type_ids = new Array(encoded.length + 2); + input_ids[0] = BigInt(101); + attention_mask[0] = BigInt(1); + token_type_ids[0] = BigInt(0); -const session = ort.InferenceSession.create(model, options); -session.then(t => { - downLoadingModel = false; - //warmup the VM - for(var i = 0; i < 10; i++) { - console.log("Inference warmup " + i); - lm_inference("this is a warmup inference"); + let i = 0; + for (; i < encoded.length; i++) { + input_ids[i + 1] = BigInt(encoded[i]); + attention_mask[i + 1] = BigInt(1); + token_type_ids[i + 1] = BigInt(0); } -}); -const tokenizer = loadTokenizer() + input_ids[i + 1] = BigInt(102); + attention_mask[i + 1] = BigInt(1); + token_type_ids[i + 1] = BigInt(0); + + const seq = input_ids.length; + + input_ids = new ort.Tensor('int64', BigInt64Array.from(input_ids), [1, seq]); + attention_mask = new ort.Tensor('int64', BigInt64Array.from(attention_mask), [1, seq]); + token_type_ids = new ort.Tensor('int64', BigInt64Array.from(token_type_ids), [1, seq]); + + return { input_ids, attention_mask, token_type_ids }; +} + +// --- CONSTANTS -------------------------------------------------------------- const EMOJI_DEFAULT_DISPLAY = [ - ["Emotion", "Score"], - ['admiration πŸ‘',0], - ['amusement πŸ˜‚', 0], - ['neutral 😐',0], - ['approval πŸ‘',0], - ['joy πŸ˜ƒ',0], - ['gratitude πŸ™',0], + ['Emotion', 'Score'], + ['admiration πŸ‘', 0], + ['amusement πŸ˜‚', 0], + ['neutral 😐', 0], + ['approval πŸ‘', 0], + ['joy πŸ˜ƒ', 0], + ['gratitude πŸ™', 0], ]; const EMOJIS = [ @@ -80,112 +132,128 @@ const EMOJIS = [ 'optimism 🀞', 'pride 😌', 'realization πŸ’‘', - 'reliefπŸ˜…', - 'remorse 😞', + 'relief πŸ˜…', + 'remorse 😞', 'sadness 😞', 'surprise 😲', - 'neutral 😐' + 'neutral 😐', ]; -function isDownloading() { - return downLoadingModel; -} +// --- LAZY SESSION INITIALIZERS --------------------------------------------- -function sortResult(a, b) { - if (a[1] === b[1]) { - return 0; - } - else { - return (a[1] < b[1]) ? 
1 : -1; - } +async function ensureLmSession() { + if (lmDisabled) return null; + if (lmSessionPromise) return lmSessionPromise; + + downloadingModel = true; + lmSessionPromise = (async () => { + try { + const bytes = await fetchModelBytes(LM_MODEL_URL); + const session = await ort.InferenceSession.create(bytes, ORT_OPTIONS); + return session; + } catch (err) { + console.warn('[inference] LM disabled:', err); + lmDisabled = true; + return null; + } finally { + downloadingModel = false; + } + })(); + + return lmSessionPromise; } -function sigmoid(t) { - return 1/(1+Math.pow(Math.E, -t)); +async function ensureGruSession() { + if (gruDisabled) return null; + if (gruSessionPromise) return gruSessionPromise; + + downloadingModel = true; + gruSessionPromise = (async () => { + try { + const bytes = await fetchModelBytes(GRU_MODEL_URL); + const session = await ort.InferenceSession.create(bytes, ORT_OPTIONS); + return session; + } catch (err) { + console.warn('[inference] GRU disabled:', err); + gruDisabled = true; + return null; + } finally { + downloadingModel = false; + } + })(); + + return gruSessionPromise; } -function create_model_input(encoded) { - var input_ids = new Array(encoded.length+2); - var attention_mask = new Array(encoded.length+2); - var token_type_ids = new Array(encoded.length+2); - input_ids[0]Β = BigInt(101); - attention_mask[0]Β = BigInt(1); - token_type_ids[0]Β = BigInt(0); - var i = 0; - for(; i < encoded.length; i++) { - input_ids[i+1] = BigInt(encoded[i]); - attention_mask[i+1] = BigInt(1); - token_type_ids[i+1] = BigInt(0); - } - input_ids[i+1]Β = BigInt(102); - attention_mask[i+1]Β = BigInt(1); - token_type_ids[i+1]Β = BigInt(0); - const sequence_length = input_ids.length; - input_ids = new ort.Tensor('int64', BigInt64Array.from(input_ids), [1,sequence_length]); - attention_mask = new ort.Tensor('int64', BigInt64Array.from(attention_mask), [1,sequence_length]); - token_type_ids = new ort.Tensor('int64', BigInt64Array.from(token_type_ids), [1,sequence_length]); - return { - input_ids: input_ids, - attention_mask: attention_mask, - token_type_ids:token_type_ids - } +// --- PUBLIC API ------------------------------------------------------------- + +function isDownloading() { + return downloadingModel; } async function lm_inference(text) { - try { - const encoded_ids = await tokenizer.then(t => { - return t.tokenize(text); - }); - if(encoded_ids.length === 0) { + try { + const session = await ensureLmSession(); + if (!session) return [0.0, EMOJI_DEFAULT_DISPLAY]; + + const encoded_ids = await tokenizer.then(t => t.tokenize(text)); + if (!encoded_ids || encoded_ids.length === 0) { return [0.0, EMOJI_DEFAULT_DISPLAY]; } + const start = performance.now(); - const model_input = create_model_input(encoded_ids); - const output = await session.then(s => { return s.run(model_input,['output_0'])}); + const feeds = create_model_input(encoded_ids); + const output = await session.run(feeds, ['output_0']); + const duration = (performance.now() - start).toFixed(1); - const probs = output['output_0'].data.map(sigmoid).map( t => Math.floor(t*100)); - // console.log(147, 'probs: ', probs); + const probs = output['output_0'].data + .map(sigmoid) + .map(t => Math.floor(t * 100)); + const result = []; - for(var i = 0; i < EMOJIS.length;i++) { - const t = [EMOJIS[i], probs[i]]; - result[i] = t; + for (let i = 0; i < EMOJIS.length; i++) { + result[i] = [EMOJIS[i], probs[i]]; } - result.sort(sortResult); - + result.sort(sortResult); + const result_list = []; - result_list[0] = ["Emotion", 
"Score"]; - for(i = 0; i < 6; i++) { - result_list[i+1] = result[i]; - } - return [duration,result_list]; - } catch (e) { - return [0.0,EMOJI_DEFAULT_DISPLAY]; + result_list[0] = ['Emotion', 'Score']; + for (let i = 0; i < 6; i++) result_list[i + 1] = result[i]; + + return [duration, result_list]; + } catch (_e) { + // Swallow errors to keep UI smooth + return [0.0, EMOJI_DEFAULT_DISPLAY]; } -} +} async function gru_inference(melLogSpectrogram) { try { + const session = await ensureGruSession(); + if (!session) return null; + + // Expecting shape [1, 40, 80] + const input = new ort.Tensor( + 'float32', + melLogSpectrogram.flat(), + [1, 40, 80] + ); - const inputs = ort.InferenceSession.FeedsType = { - // input: new ort.Tensor('float32', melLogSpectrogram.flat(), [1, 21, 80]) - input: new ort.Tensor('float32', melLogSpectrogram.flat(), [1, 40, 80]) - }; - // console.log(171, 'input: ', inputs.input.data); - // console.log('input sum: ', inputs.inpumot.data.reduce((a, b) => a + b, 0)); - // console.log(inputs.input); - const output = await gruSession.then(s => { return s.run(inputs)}); - // console.log(175, 'output: ', output.output.data); - // sum all data - // const sum = output.output.data.reduce((a, b) => a + b, 0); - // console.log('output sum: ', sum); - return output; + const outputs = await session.run({ input }); + return outputs; } catch (e) { - console.log(e); + console.warn('[inference] GRU inference error:', e); + return null; } } +// Named exports kept for compatibility with your codebase +export let intent_inference = lm_inference; +export let columnNames = EMOJI_DEFAULT_DISPLAY; +export let modelDownloadInProgress = isDownloading; +export let gruInference = gru_inference; -export let intent_inference = lm_inference -export let columnNames = EMOJI_DEFAULT_DISPLAY -export let modelDownloadInProgress = isDownloading -export let gruInference = gru_inference +// Optional: explicit initializer if you ever want to prefetch after a user action +export async function initOnnx() { + await Promise.all([ensureLmSession(), ensureGruSession()]); +}