diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 081e1785..42c896a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,8 @@ name: CI +# TODO: This workflow should be calling make targets instead of duplicating commands. +# See `make verify` for the local equivalent of these checks. + on: push: branches: [ main ] diff --git a/CLAUDE.md b/CLAUDE.md index 6db9e23f..53694e57 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -222,6 +222,13 @@ Run the pipeline to update provider types: This process is fully automated and generates only essential types to minimize code size. +## Development workflow + +1. Write a failing test or repro +2. Make the change +3. Run tests (`cargo test`) +4. Before committing, run `make verify` (runs tests and lint) + ## Development setup **Git hooks installation**: diff --git a/Makefile b/Makefile index aef94b35..2d869322 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all lingua-wasm typescript python test clean help generate-types generate-all-providers install-hooks install-wasm-tools setup +.PHONY: all lingua-wasm typescript python test clean help generate-types generate-all-providers install-hooks install-wasm-tools setup verify clippy fmt-check all: typescript python ## Build all bindings @@ -78,6 +78,15 @@ fmt: ## Format all code @echo "Formatting TypeScript code..." cd bindings/typescript && pnpm run lint +fmt-check: ## Check formatting without modifying + cargo fmt --all -- --check + +clippy: ## Run clippy with warnings as errors (matches CI) + cargo clippy --all-targets --all-features -- -D warnings + +verify: fmt-check clippy ## Run all CI checks locally (run before committing) + RUSTFLAGS="-D warnings" $(MAKE) test-rust + install-hooks: ## Install git pre-commit hooks @echo "Installing git hooks..." 
./scripts/install-hooks.sh diff --git a/bindings/typescript/wasm-web/lingua.d.ts b/bindings/typescript/wasm-web/lingua.d.ts new file mode 100644 index 00000000..16837610 --- /dev/null +++ b/bindings/typescript/wasm-web/lingua.d.ts @@ -0,0 +1,130 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Convert array of Chat Completions messages to Lingua Messages + */ +export function chat_completions_messages_to_lingua(value: any): any; +/** + * Convert array of Lingua Messages to Chat Completions messages + */ +export function lingua_to_chat_completions_messages(value: any): any; +/** + * Convert array of Responses API messages to Lingua Messages + */ +export function responses_messages_to_lingua(value: any): any; +/** + * Convert array of Lingua Messages to Responses API messages + */ +export function lingua_to_responses_messages(value: any): any; +/** + * Convert array of Anthropic messages to Lingua Messages + */ +export function anthropic_messages_to_lingua(value: any): any; +/** + * Convert array of Lingua Messages to Anthropic messages + */ +export function lingua_to_anthropic_messages(value: any): any; +/** + * Deduplicate messages based on role and content + */ +export function deduplicate_messages(value: any): any; +/** + * Import messages from spans + */ +export function import_messages_from_spans(value: any): any; +/** + * Import and deduplicate messages from spans in a single operation + */ +export function import_and_deduplicate_messages(value: any): any; +/** + * Validate a JSON string as a Chat Completions request + */ +export function validate_chat_completions_request(json: string): any; +/** + * Validate a JSON string as a Chat Completions response + */ +export function validate_chat_completions_response(json: string): any; +/** + * Validate a JSON string as a Responses API request + */ +export function validate_responses_request(json: string): any; +/** + * Validate a JSON string as a Responses API response + */ +export function 
validate_responses_response(json: string): any; +/** + * Validate a JSON string as an OpenAI request + * @deprecated Use validate_chat_completions_request instead + */ +export function validate_openai_request(json: string): any; +/** + * Validate a JSON string as an OpenAI response + * @deprecated Use validate_chat_completions_response instead + */ +export function validate_openai_response(json: string): any; +/** + * Validate a JSON string as an Anthropic request + */ +export function validate_anthropic_request(json: string): any; +/** + * Validate a JSON string as an Anthropic response + */ +export function validate_anthropic_response(json: string): any; +/** + * Validate a JSON string as a Google request (not supported - protobuf types) + */ +export function validate_google_request(json: string): any; +/** + * Validate a JSON string as a Google response (not supported - protobuf types) + */ +export function validate_google_response(json: string): any; + +export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module; + +export interface InitOutput { + readonly memory: WebAssembly.Memory; + readonly chat_completions_messages_to_lingua: (a: number, b: number) => void; + readonly lingua_to_chat_completions_messages: (a: number, b: number) => void; + readonly responses_messages_to_lingua: (a: number, b: number) => void; + readonly lingua_to_responses_messages: (a: number, b: number) => void; + readonly anthropic_messages_to_lingua: (a: number, b: number) => void; + readonly lingua_to_anthropic_messages: (a: number, b: number) => void; + readonly deduplicate_messages: (a: number, b: number) => void; + readonly import_messages_from_spans: (a: number, b: number) => void; + readonly import_and_deduplicate_messages: (a: number, b: number) => void; + readonly validate_chat_completions_request: (a: number, b: number, c: number) => void; + readonly validate_chat_completions_response: (a: number, b: number, c: number) => void; + readonly 
validate_responses_request: (a: number, b: number, c: number) => void; + readonly validate_responses_response: (a: number, b: number, c: number) => void; + readonly validate_anthropic_request: (a: number, b: number, c: number) => void; + readonly validate_anthropic_response: (a: number, b: number, c: number) => void; + readonly validate_google_request: (a: number, b: number, c: number) => void; + readonly validate_openai_request: (a: number, b: number, c: number) => void; + readonly validate_openai_response: (a: number, b: number, c: number) => void; + readonly validate_google_response: (a: number, b: number, c: number) => void; + readonly __wbindgen_export_0: (a: number, b: number) => number; + readonly __wbindgen_export_1: (a: number, b: number, c: number, d: number) => number; + readonly __wbindgen_export_2: (a: number) => void; + readonly __wbindgen_add_to_stack_pointer: (a: number) => number; +} + +export type SyncInitInput = BufferSource | WebAssembly.Module; +/** +* Instantiates the given `module`, which can either be bytes or +* a precompiled `WebAssembly.Module`. +* +* @param {{ module: SyncInitInput }} module - Passing `SyncInitInput` directly is deprecated. +* +* @returns {InitOutput} +*/ +export function initSync(module: { module: SyncInitInput } | SyncInitInput): InitOutput; + +/** +* If `module_or_path` is {RequestInfo} or {URL}, makes a request and +* for everything else, calls `WebAssembly.instantiate` directly. +* +* @param {{ module_or_path: InitInput | Promise<InitInput> }} module_or_path - Passing `InitInput` directly is deprecated. 
+* +* @returns {Promise<InitOutput>} +*/ +export default function __wbg_init (module_or_path?: { module_or_path: InitInput | Promise<InitInput> } | InitInput | Promise<InitInput>): Promise<InitOutput>; diff --git a/bindings/typescript/wasm-web/lingua.js b/bindings/typescript/wasm-web/lingua.js new file mode 100644 index 00000000..e476f43c --- /dev/null +++ b/bindings/typescript/wasm-web/lingua.js @@ -0,0 +1,948 @@ +let wasm; + +const heap = new Array(128).fill(undefined); + +heap.push(undefined, null, true, false); + +function getObject(idx) { return heap[idx]; } + +let WASM_VECTOR_LEN = 0; + +let cachedUint8ArrayMemory0 = null; + +function getUint8ArrayMemory0() { + if (cachedUint8ArrayMemory0 === null || cachedUint8ArrayMemory0.byteLength === 0) { + cachedUint8ArrayMemory0 = new Uint8Array(wasm.memory.buffer); + } + return cachedUint8ArrayMemory0; +} + +const cachedTextEncoder = (typeof TextEncoder !== 'undefined' ? new TextEncoder('utf-8') : { encode: () => { throw Error('TextEncoder not available') } } ); + +const encodeString = (typeof cachedTextEncoder.encodeInto === 'function' + ? 
function (arg, view) { + return cachedTextEncoder.encodeInto(arg, view); +} + : function (arg, view) { + const buf = cachedTextEncoder.encode(arg); + view.set(buf); + return { + read: arg.length, + written: buf.length + }; +}); + +function passStringToWasm0(arg, malloc, realloc) { + + if (realloc === undefined) { + const buf = cachedTextEncoder.encode(arg); + const ptr = malloc(buf.length, 1) >>> 0; + getUint8ArrayMemory0().subarray(ptr, ptr + buf.length).set(buf); + WASM_VECTOR_LEN = buf.length; + return ptr; + } + + let len = arg.length; + let ptr = malloc(len, 1) >>> 0; + + const mem = getUint8ArrayMemory0(); + + let offset = 0; + + for (; offset < len; offset++) { + const code = arg.charCodeAt(offset); + if (code > 0x7F) break; + mem[ptr + offset] = code; + } + + if (offset !== len) { + if (offset !== 0) { + arg = arg.slice(offset); + } + ptr = realloc(ptr, len, len = offset + arg.length * 3, 1) >>> 0; + const view = getUint8ArrayMemory0().subarray(ptr + offset, ptr + len); + const ret = encodeString(arg, view); + + offset += ret.written; + ptr = realloc(ptr, len, offset, 1) >>> 0; + } + + WASM_VECTOR_LEN = offset; + return ptr; +} + +let cachedDataViewMemory0 = null; + +function getDataViewMemory0() { + if (cachedDataViewMemory0 === null || cachedDataViewMemory0.buffer.detached === true || (cachedDataViewMemory0.buffer.detached === undefined && cachedDataViewMemory0.buffer !== wasm.memory.buffer)) { + cachedDataViewMemory0 = new DataView(wasm.memory.buffer); + } + return cachedDataViewMemory0; +} + +let heap_next = heap.length; + +function addHeapObject(obj) { + if (heap_next === heap.length) heap.push(heap.length + 1); + const idx = heap_next; + heap_next = heap[idx]; + + heap[idx] = obj; + return idx; +} + +function handleError(f, args) { + try { + return f.apply(this, args); + } catch (e) { + wasm.__wbindgen_export_2(addHeapObject(e)); + } +} + +function dropObject(idx) { + if (idx < 132) return; + heap[idx] = heap_next; + heap_next = idx; +} + +function 
takeObject(idx) { + const ret = getObject(idx); + dropObject(idx); + return ret; +} + +function isLikeNone(x) { + return x === undefined || x === null; +} + +function debugString(val) { + // primitive types + const type = typeof val; + if (type == 'number' || type == 'boolean' || val == null) { + return `${val}`; + } + if (type == 'string') { + return `"${val}"`; + } + if (type == 'symbol') { + const description = val.description; + if (description == null) { + return 'Symbol'; + } else { + return `Symbol(${description})`; + } + } + if (type == 'function') { + const name = val.name; + if (typeof name == 'string' && name.length > 0) { + return `Function(${name})`; + } else { + return 'Function'; + } + } + // objects + if (Array.isArray(val)) { + const length = val.length; + let debug = '['; + if (length > 0) { + debug += debugString(val[0]); + } + for(let i = 1; i < length; i++) { + debug += ', ' + debugString(val[i]); + } + debug += ']'; + return debug; + } + // Test for built-in + const builtInMatches = /\[object ([^\]]+)\]/.exec(toString.call(val)); + let className; + if (builtInMatches && builtInMatches.length > 1) { + className = builtInMatches[1]; + } else { + // Failed to match the standard '[object ClassName]' + return toString.call(val); + } + if (className == 'Object') { + // we're a user defined class or Object + // JSON.stringify avoids problems with cycles, and is generally much + // easier than looping through ownProperties of `val`. + try { + return 'Object(' + JSON.stringify(val) + ')'; + } catch (_) { + return 'Object'; + } + } + // errors + if (val instanceof Error) { + return `${val.name}: ${val.message}\n${val.stack}`; + } + // TODO we could test for more things here, like `Set`s and `Map`s. + return className; +} + +const cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? 
new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } ); + +if (typeof TextDecoder !== 'undefined') { cachedTextDecoder.decode(); }; + +function getStringFromWasm0(ptr, len) { + ptr = ptr >>> 0; + return cachedTextDecoder.decode(getUint8ArrayMemory0().subarray(ptr, ptr + len)); +} +/** + * Convert array of Chat Completions messages to Lingua Messages + * @param {any} value + * @returns {any} + */ +export function chat_completions_messages_to_lingua(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.chat_completions_messages_to_lingua(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Convert array of Lingua Messages to Chat Completions messages + * @param {any} value + * @returns {any} + */ +export function lingua_to_chat_completions_messages(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.lingua_to_chat_completions_messages(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Convert array of Responses API messages to Lingua Messages + * @param {any} value + * @returns {any} + */ +export function responses_messages_to_lingua(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.responses_messages_to_lingua(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 
4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Convert array of Lingua Messages to Responses API messages + * @param {any} value + * @returns {any} + */ +export function lingua_to_responses_messages(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.lingua_to_responses_messages(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Convert array of Anthropic messages to Lingua Messages + * @param {any} value + * @returns {any} + */ +export function anthropic_messages_to_lingua(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.anthropic_messages_to_lingua(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Convert array of Lingua Messages to Anthropic messages + * @param {any} value + * @returns {any} + */ +export function lingua_to_anthropic_messages(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.lingua_to_anthropic_messages(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = 
getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Deduplicate messages based on role and content + * @param {any} value + * @returns {any} + */ +export function deduplicate_messages(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.deduplicate_messages(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Import messages from spans + * @param {any} value + * @returns {any} + */ +export function import_messages_from_spans(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.import_messages_from_spans(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Import and deduplicate messages from spans in a single operation + * @param {any} value + * @returns {any} + */ +export function import_and_deduplicate_messages(value) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + wasm.import_and_deduplicate_messages(retptr, addHeapObject(value)); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + 
wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as a Chat Completions request + * @param {string} json + * @returns {any} + */ +export function validate_chat_completions_request(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_chat_completions_request(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as a Chat Completions response + * @param {string} json + * @returns {any} + */ +export function validate_chat_completions_response(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_chat_completions_response(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as a Responses API request + * @param {string} json + * @returns {any} + */ +export function validate_responses_request(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_responses_request(retptr, ptr0, len0); + var r0 = 
getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as a Responses API response + * @param {string} json + * @returns {any} + */ +export function validate_responses_response(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_responses_response(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as an OpenAI request + * @deprecated Use validate_chat_completions_request instead + * @param {string} json + * @returns {any} + */ +export function validate_openai_request(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_chat_completions_request(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as an OpenAI response + * @deprecated Use validate_chat_completions_response instead + * @param {string} json + * @returns {any} 
+ */ +export function validate_openai_response(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_chat_completions_response(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as an Anthropic request + * @param {string} json + * @returns {any} + */ +export function validate_anthropic_request(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_anthropic_request(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as an Anthropic response + * @param {string} json + * @returns {any} + */ +export function validate_anthropic_response(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_anthropic_response(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + 
} + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as a Google request (not supported - protobuf types) + * @param {string} json + * @returns {any} + */ +export function validate_google_request(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_google_request(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +/** + * Validate a JSON string as a Google response (not supported - protobuf types) + * @param {string} json + * @returns {any} + */ +export function validate_google_response(json) { + try { + const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); + const ptr0 = passStringToWasm0(json, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len0 = WASM_VECTOR_LEN; + wasm.validate_google_request(retptr, ptr0, len0); + var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); + var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); + var r2 = getDataViewMemory0().getInt32(retptr + 4 * 2, true); + if (r2) { + throw takeObject(r1); + } + return takeObject(r0); + } finally { + wasm.__wbindgen_add_to_stack_pointer(16); + } +} + +async function __wbg_load(module, imports) { + if (typeof Response === 'function' && module instanceof Response) { + if (typeof WebAssembly.instantiateStreaming === 'function') { + try { + return await WebAssembly.instantiateStreaming(module, imports); + + } catch (e) { + if (module.headers.get('Content-Type') != 'application/wasm') { + console.warn("`WebAssembly.instantiateStreaming` failed 
because your server does not serve Wasm with `application/wasm` MIME type. Falling back to `WebAssembly.instantiate` which is slower. Original error:\n", e); + + } else { + throw e; + } + } + } + + const bytes = await module.arrayBuffer(); + return await WebAssembly.instantiate(bytes, imports); + + } else { + const instance = await WebAssembly.instantiate(module, imports); + + if (instance instanceof WebAssembly.Instance) { + return { instance, module }; + + } else { + return instance; + } + } +} + +function __wbg_get_imports() { + const imports = {}; + imports.wbg = {}; + imports.wbg.__wbg_String_8f0eb39a4a4c2f66 = function(arg0, arg1) { + const ret = String(getObject(arg1)); + const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len1 = WASM_VECTOR_LEN; + getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); + getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); + }; + imports.wbg.__wbg_buffer_609cc3eee51ed158 = function(arg0) { + const ret = getObject(arg0).buffer; + return addHeapObject(ret); + }; + imports.wbg.__wbg_call_672a4d21634d4a24 = function() { return handleError(function (arg0, arg1) { + const ret = getObject(arg0).call(getObject(arg1)); + return addHeapObject(ret); + }, arguments) }; + imports.wbg.__wbg_done_769e5ede4b31c67b = function(arg0) { + const ret = getObject(arg0).done; + return ret; + }; + imports.wbg.__wbg_entries_3265d4158b33e5dc = function(arg0) { + const ret = Object.entries(getObject(arg0)); + return addHeapObject(ret); + }; + imports.wbg.__wbg_get_67b2ba62fc30de12 = function() { return handleError(function (arg0, arg1) { + const ret = Reflect.get(getObject(arg0), getObject(arg1)); + return addHeapObject(ret); + }, arguments) }; + imports.wbg.__wbg_get_b9b93047fe3cf45b = function(arg0, arg1) { + const ret = getObject(arg0)[arg1 >>> 0]; + return addHeapObject(ret); + }; + imports.wbg.__wbg_getwithrefkey_1dc361bd10053bfe = function(arg0, arg1) { + const ret = 
getObject(arg0)[getObject(arg1)]; + return addHeapObject(ret); + }; + imports.wbg.__wbg_instanceof_ArrayBuffer_e14585432e3737fc = function(arg0) { + let result; + try { + result = getObject(arg0) instanceof ArrayBuffer; + } catch (_) { + result = false; + } + const ret = result; + return ret; + }; + imports.wbg.__wbg_instanceof_Map_f3469ce2244d2430 = function(arg0) { + let result; + try { + result = getObject(arg0) instanceof Map; + } catch (_) { + result = false; + } + const ret = result; + return ret; + }; + imports.wbg.__wbg_instanceof_Uint8Array_17156bcf118086a9 = function(arg0) { + let result; + try { + result = getObject(arg0) instanceof Uint8Array; + } catch (_) { + result = false; + } + const ret = result; + return ret; + }; + imports.wbg.__wbg_isArray_a1eab7e0d067391b = function(arg0) { + const ret = Array.isArray(getObject(arg0)); + return ret; + }; + imports.wbg.__wbg_isSafeInteger_343e2beeeece1bb0 = function(arg0) { + const ret = Number.isSafeInteger(getObject(arg0)); + return ret; + }; + imports.wbg.__wbg_iterator_9a24c88df860dc65 = function() { + const ret = Symbol.iterator; + return addHeapObject(ret); + }; + imports.wbg.__wbg_length_a446193dc22c12f8 = function(arg0) { + const ret = getObject(arg0).length; + return ret; + }; + imports.wbg.__wbg_length_e2d2a49132c1b256 = function(arg0) { + const ret = getObject(arg0).length; + return ret; + }; + imports.wbg.__wbg_new_405e22f390576ce2 = function() { + const ret = new Object(); + return addHeapObject(ret); + }; + imports.wbg.__wbg_new_5e0be73521bc8c17 = function() { + const ret = new Map(); + return addHeapObject(ret); + }; + imports.wbg.__wbg_new_78feb108b6472713 = function() { + const ret = new Array(); + return addHeapObject(ret); + }; + imports.wbg.__wbg_new_a12002a7f91c75be = function(arg0) { + const ret = new Uint8Array(getObject(arg0)); + return addHeapObject(ret); + }; + imports.wbg.__wbg_next_25feadfc0913fea9 = function(arg0) { + const ret = getObject(arg0).next; + return addHeapObject(ret); + 
}; + imports.wbg.__wbg_next_6574e1a8a62d1055 = function() { return handleError(function (arg0) { + const ret = getObject(arg0).next(); + return addHeapObject(ret); + }, arguments) }; + imports.wbg.__wbg_set_37837023f3d740e8 = function(arg0, arg1, arg2) { + getObject(arg0)[arg1 >>> 0] = takeObject(arg2); + }; + imports.wbg.__wbg_set_3f1d0b984ed272ed = function(arg0, arg1, arg2) { + getObject(arg0)[takeObject(arg1)] = takeObject(arg2); + }; + imports.wbg.__wbg_set_65595bdd868b3009 = function(arg0, arg1, arg2) { + getObject(arg0).set(getObject(arg1), arg2 >>> 0); + }; + imports.wbg.__wbg_set_8fc6bf8a5b1071d1 = function(arg0, arg1, arg2) { + const ret = getObject(arg0).set(getObject(arg1), getObject(arg2)); + return addHeapObject(ret); + }; + imports.wbg.__wbg_value_cd1ffa7b1ab794f1 = function(arg0) { + const ret = getObject(arg0).value; + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_as_number = function(arg0) { + const ret = +getObject(arg0); + return ret; + }; + imports.wbg.__wbindgen_bigint_from_i64 = function(arg0) { + const ret = arg0; + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_bigint_from_u64 = function(arg0) { + const ret = BigInt.asUintN(64, arg0); + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_bigint_get_as_i64 = function(arg0, arg1) { + const v = getObject(arg1); + const ret = typeof(v) === 'bigint' ? v : undefined; + getDataViewMemory0().setBigInt64(arg0 + 8 * 1, isLikeNone(ret) ? BigInt(0) : ret, true); + getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true); + }; + imports.wbg.__wbindgen_boolean_get = function(arg0) { + const v = getObject(arg0); + const ret = typeof(v) === 'boolean' ? (v ? 
1 : 0) : 2; + return ret; + }; + imports.wbg.__wbindgen_debug_string = function(arg0, arg1) { + const ret = debugString(getObject(arg1)); + const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + const len1 = WASM_VECTOR_LEN; + getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); + getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); + }; + imports.wbg.__wbindgen_error_new = function(arg0, arg1) { + const ret = new Error(getStringFromWasm0(arg0, arg1)); + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_in = function(arg0, arg1) { + const ret = getObject(arg0) in getObject(arg1); + return ret; + }; + imports.wbg.__wbindgen_is_bigint = function(arg0) { + const ret = typeof(getObject(arg0)) === 'bigint'; + return ret; + }; + imports.wbg.__wbindgen_is_function = function(arg0) { + const ret = typeof(getObject(arg0)) === 'function'; + return ret; + }; + imports.wbg.__wbindgen_is_object = function(arg0) { + const val = getObject(arg0); + const ret = typeof(val) === 'object' && val !== null; + return ret; + }; + imports.wbg.__wbindgen_is_string = function(arg0) { + const ret = typeof(getObject(arg0)) === 'string'; + return ret; + }; + imports.wbg.__wbindgen_is_undefined = function(arg0) { + const ret = getObject(arg0) === undefined; + return ret; + }; + imports.wbg.__wbindgen_jsval_eq = function(arg0, arg1) { + const ret = getObject(arg0) === getObject(arg1); + return ret; + }; + imports.wbg.__wbindgen_jsval_loose_eq = function(arg0, arg1) { + const ret = getObject(arg0) == getObject(arg1); + return ret; + }; + imports.wbg.__wbindgen_memory = function() { + const ret = wasm.memory; + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_number_get = function(arg0, arg1) { + const obj = getObject(arg1); + const ret = typeof(obj) === 'number' ? obj : undefined; + getDataViewMemory0().setFloat64(arg0 + 8 * 1, isLikeNone(ret) ? 
0 : ret, true); + getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true); + }; + imports.wbg.__wbindgen_number_new = function(arg0) { + const ret = arg0; + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_object_clone_ref = function(arg0) { + const ret = getObject(arg0); + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_object_drop_ref = function(arg0) { + takeObject(arg0); + }; + imports.wbg.__wbindgen_string_get = function(arg0, arg1) { + const obj = getObject(arg1); + const ret = typeof(obj) === 'string' ? obj : undefined; + var ptr1 = isLikeNone(ret) ? 0 : passStringToWasm0(ret, wasm.__wbindgen_export_0, wasm.__wbindgen_export_1); + var len1 = WASM_VECTOR_LEN; + getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); + getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); + }; + imports.wbg.__wbindgen_string_new = function(arg0, arg1) { + const ret = getStringFromWasm0(arg0, arg1); + return addHeapObject(ret); + }; + imports.wbg.__wbindgen_throw = function(arg0, arg1) { + throw new Error(getStringFromWasm0(arg0, arg1)); + }; + + return imports; +} + +function __wbg_init_memory(imports, memory) { + +} + +function __wbg_finalize_init(instance, module) { + wasm = instance.exports; + __wbg_init.__wbindgen_wasm_module = module; + cachedDataViewMemory0 = null; + cachedUint8ArrayMemory0 = null; + + + + return wasm; +} + +function initSync(module) { + if (wasm !== undefined) return wasm; + + + if (typeof module !== 'undefined') { + if (Object.getPrototypeOf(module) === Object.prototype) { + ({module} = module) + } else { + console.warn('using deprecated parameters for `initSync()`; pass a single object instead') + } + } + + const imports = __wbg_get_imports(); + + __wbg_init_memory(imports); + + if (!(module instanceof WebAssembly.Module)) { + module = new WebAssembly.Module(module); + } + + const instance = new WebAssembly.Instance(module, imports); + + return __wbg_finalize_init(instance, module); +} + +async function 
__wbg_init(module_or_path) { + if (wasm !== undefined) return wasm; + + + if (typeof module_or_path !== 'undefined') { + if (Object.getPrototypeOf(module_or_path) === Object.prototype) { + ({module_or_path} = module_or_path) + } else { + console.warn('using deprecated parameters for the initialization function; pass a single object instead') + } + } + + if (typeof module_or_path === 'undefined') { + module_or_path = new URL('lingua_bg.wasm', import.meta.url); + } + const imports = __wbg_get_imports(); + + if (typeof module_or_path === 'string' || (typeof Request === 'function' && module_or_path instanceof Request) || (typeof URL === 'function' && module_or_path instanceof URL)) { + module_or_path = fetch(module_or_path); + } + + __wbg_init_memory(imports); + + const { instance, module } = await __wbg_load(await module_or_path, imports); + + return __wbg_finalize_init(instance, module); +} + +export { initSync }; +export default __wbg_init; diff --git a/bindings/typescript/wasm-web/lingua_bg.wasm b/bindings/typescript/wasm-web/lingua_bg.wasm new file mode 100644 index 00000000..63386d19 Binary files /dev/null and b/bindings/typescript/wasm-web/lingua_bg.wasm differ diff --git a/bindings/typescript/wasm-web/lingua_bg.wasm.d.ts b/bindings/typescript/wasm-web/lingua_bg.wasm.d.ts new file mode 100644 index 00000000..658330e2 --- /dev/null +++ b/bindings/typescript/wasm-web/lingua_bg.wasm.d.ts @@ -0,0 +1,26 @@ +/* tslint:disable */ +/* eslint-disable */ +export const memory: WebAssembly.Memory; +export const chat_completions_messages_to_lingua: (a: number, b: number) => void; +export const lingua_to_chat_completions_messages: (a: number, b: number) => void; +export const responses_messages_to_lingua: (a: number, b: number) => void; +export const lingua_to_responses_messages: (a: number, b: number) => void; +export const anthropic_messages_to_lingua: (a: number, b: number) => void; +export const lingua_to_anthropic_messages: (a: number, b: number) => void; +export const 
deduplicate_messages: (a: number, b: number) => void; +export const import_messages_from_spans: (a: number, b: number) => void; +export const import_and_deduplicate_messages: (a: number, b: number) => void; +export const validate_chat_completions_request: (a: number, b: number, c: number) => void; +export const validate_chat_completions_response: (a: number, b: number, c: number) => void; +export const validate_responses_request: (a: number, b: number, c: number) => void; +export const validate_responses_response: (a: number, b: number, c: number) => void; +export const validate_anthropic_request: (a: number, b: number, c: number) => void; +export const validate_anthropic_response: (a: number, b: number, c: number) => void; +export const validate_google_request: (a: number, b: number, c: number) => void; +export const validate_openai_request: (a: number, b: number, c: number) => void; +export const validate_openai_response: (a: number, b: number, c: number) => void; +export const validate_google_response: (a: number, b: number, c: number) => void; +export const __wbindgen_export_0: (a: number, b: number) => number; +export const __wbindgen_export_1: (a: number, b: number, c: number, d: number) => number; +export const __wbindgen_export_2: (a: number) => void; +export const __wbindgen_add_to_stack_pointer: (a: number) => number; diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index 862b9386..3b1ee784 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -21,11 +21,16 @@ use crate::providers::ClientHeaders; use crate::streaming::{bedrock_event_stream, single_bytes_stream, RawResponseStream}; use lingua::ProviderFormat; +/// Default anthropic_version for Bedrock's Anthropic Messages API. 
+/// See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +pub const DEFAULT_ANTHROPIC_VERSION: &str = "bedrock-2023-05-31"; + #[derive(Debug, Clone)] pub struct BedrockConfig { pub endpoint: Url, pub service: String, pub timeout: Option, + pub anthropic_version: String, } impl Default for BedrockConfig { @@ -35,10 +40,20 @@ impl Default for BedrockConfig { .expect("valid Bedrock endpoint"), service: "bedrock".to_string(), timeout: None, + anthropic_version: DEFAULT_ANTHROPIC_VERSION.to_string(), } } } +/// Bedrock API mode - determines which endpoint to use. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BedrockMode { + /// AWS Converse API - unified format for all Bedrock models + Converse, + /// Anthropic Messages API - native format for Claude models on Bedrock + AnthropicMessages, +} + #[derive(Debug, Clone)] pub struct BedrockProvider { client: Client, @@ -64,6 +79,7 @@ impl BedrockProvider { /// Extracts Bedrock-specific options from metadata: /// - `region`: AWS region (used to construct endpoint if not provided) /// - `service`: AWS service name (defaults to "bedrock") + /// - `anthropic_version`: API version for Anthropic models (defaults to "bedrock-2023-05-31") pub fn from_config( endpoint: Option<&Url>, timeout: Option, @@ -83,6 +99,9 @@ impl BedrockProvider { if let Some(service) = metadata.get("service").and_then(Value::as_str) { config.service = service.to_string(); } + if let Some(version) = metadata.get("anthropic_version").and_then(Value::as_str) { + config.anthropic_version = version.to_string(); + } if let Some(t) = timeout { config.timeout = Some(t); } @@ -90,11 +109,32 @@ impl BedrockProvider { Self::new(config) } - fn invoke_url(&self, model: &str, stream: bool) -> Result { - let path = if stream { - format!("model/{model}/converse-stream") + /// Determine which Bedrock API mode to use based on model name. 
+ /// + /// Handles both direct model IDs (`anthropic.claude-*`) and inference profiles + /// (`us.anthropic.claude-*`, `global.anthropic.claude-*`, etc.). + pub fn determine_mode(&self, model: &str) -> BedrockMode { + if model.starts_with("anthropic.") || model.contains(".anthropic.") { + BedrockMode::AnthropicMessages } else { - format!("model/{model}/converse") + BedrockMode::Converse + } + } + + /// Build the invoke URL for a specific mode. + pub fn invoke_url_for_mode( + &self, + model: &str, + mode: &BedrockMode, + stream: bool, + ) -> Result { + let path = match (mode, stream) { + (BedrockMode::Converse, false) => format!("model/{model}/converse"), + (BedrockMode::Converse, true) => format!("model/{model}/converse-stream"), + (BedrockMode::AnthropicMessages, false) => format!("model/{model}/invoke"), + (BedrockMode::AnthropicMessages, true) => { + format!("model/{model}/invoke-with-response-stream") + } }; self.config .endpoint @@ -205,7 +245,13 @@ impl crate::providers::Provider for BedrockProvider { spec: &ModelSpec, _client_headers: &ClientHeaders, ) -> Result { - let url = self.invoke_url(&spec.model, false)?; + let mode = self.determine_mode(&spec.model); + let url = self.invoke_url_for_mode(&spec.model, &mode, false)?; + let payload = if mode == BedrockMode::AnthropicMessages { + prepare_anthropic_payload(payload, &self.config.anthropic_version)? + } else { + payload + }; #[cfg(feature = "tracing")] tracing::debug!( @@ -268,7 +314,13 @@ impl crate::providers::Provider for BedrockProvider { } // Router should have already added stream options to payload - let url = self.invoke_url(&spec.model, true)?; + let mode = self.determine_mode(&spec.model); + let url = self.invoke_url_for_mode(&spec.model, &mode, true)?; + let payload = if mode == BedrockMode::AnthropicMessages { + prepare_anthropic_payload(payload, &self.config.anthropic_version)? 
+ } else { + payload + }; #[cfg(feature = "tracing")] tracing::debug!( @@ -358,3 +410,134 @@ fn extract_retry_after(status: StatusCode, _body: &str) -> Option { None } } + +/// Prepare an Anthropic-format payload for Bedrock by adding anthropic_version. +fn prepare_anthropic_payload(payload: Bytes, anthropic_version: &str) -> Result { + let mut body: lingua::serde_json::Value = lingua::serde_json::from_slice(&payload) + .map_err(|e| Error::InvalidRequest(format!("failed to parse payload: {e}")))?; + if let Some(obj) = body.as_object_mut() { + obj.insert( + "anthropic_version".into(), + lingua::serde_json::Value::String(anthropic_version.into()), + ); + } + let bytes = lingua::serde_json::to_vec(&body) + .map_err(|e| Error::InvalidRequest(format!("failed to serialize payload: {e}")))?; + Ok(Bytes::from(bytes)) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn provider() -> BedrockProvider { + let config = BedrockConfig { + endpoint: Url::parse("https://bedrock-runtime.us-east-1.amazonaws.com/").unwrap(), + service: "bedrock".into(), + timeout: None, + anthropic_version: DEFAULT_ANTHROPIC_VERSION.to_string(), + }; + BedrockProvider::new(config).unwrap() + } + + #[test] + fn selects_anthropic_mode_for_claude_models() { + let provider = provider(); + assert!(matches!( + provider.determine_mode("anthropic.claude-3-sonnet-20240229-v1:0"), + BedrockMode::AnthropicMessages + )); + assert!(matches!( + provider.determine_mode("anthropic.claude-3-haiku-20240307-v1:0"), + BedrockMode::AnthropicMessages + )); + } + + #[test] + fn selects_anthropic_mode_for_inference_profiles() { + let provider = provider(); + // US inference profile + assert!(matches!( + provider.determine_mode("us.anthropic.claude-haiku-4-5-20251001-v1:0"), + BedrockMode::AnthropicMessages + )); + // Global inference profile + assert!(matches!( + provider.determine_mode("global.anthropic.claude-sonnet-4-5-20250929-v1:0"), + BedrockMode::AnthropicMessages + )); + // EU inference profile + assert!(matches!( 
+ provider.determine_mode("eu.anthropic.claude-3-sonnet"), + BedrockMode::AnthropicMessages + )); + } + + #[test] + fn selects_converse_mode_for_other_models() { + let provider = provider(); + assert!(matches!( + provider.determine_mode("amazon.titan-text-express-v1"), + BedrockMode::Converse + )); + assert!(matches!( + provider.determine_mode("meta.llama3-70b-instruct-v1:0"), + BedrockMode::Converse + )); + } + + #[test] + fn builds_invoke_endpoint_for_anthropic() { + let provider = provider(); + let url = provider + .invoke_url_for_mode( + "anthropic.claude-3-sonnet", + &BedrockMode::AnthropicMessages, + false, + ) + .unwrap(); + assert_eq!( + url.as_str(), + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet/invoke" + ); + } + + #[test] + fn builds_invoke_stream_endpoint_for_anthropic() { + let provider = provider(); + let url = provider + .invoke_url_for_mode( + "anthropic.claude-3-sonnet", + &BedrockMode::AnthropicMessages, + true, + ) + .unwrap(); + assert_eq!( + url.as_str(), + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet/invoke-with-response-stream" + ); + } + + #[test] + fn builds_converse_endpoint_for_others() { + let provider = provider(); + let url = provider + .invoke_url_for_mode("amazon.titan-text-express", &BedrockMode::Converse, false) + .unwrap(); + assert_eq!( + url.as_str(), + "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express/converse" + ); + } + + #[test] + fn prepares_anthropic_payload_with_version() { + let payload = Bytes::from(r#"{"model":"claude","messages":[]}"#); + let result = prepare_anthropic_payload(payload, DEFAULT_ANTHROPIC_VERSION).unwrap(); + let body: lingua::serde_json::Value = lingua::serde_json::from_slice(&result).unwrap(); + assert_eq!( + body.get("anthropic_version").unwrap(), + DEFAULT_ANTHROPIC_VERSION + ); + } +} diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 
f3a4b9e5..234b2e7c 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -149,8 +149,9 @@ impl Router { client_headers: &ClientHeaders, ) -> Result { let (provider, auth, spec, strategy) = self.resolve_provider(model)?; - let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model)) - { + // Use spec.format for transformation - this allows composite providers + // (like Bedrock) to handle multiple formats based on the model's catalog entry + let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. }) => bytes, Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), @@ -206,8 +207,9 @@ impl Router { client_headers: &ClientHeaders, ) -> Result { let (provider, auth, spec, _) = self.resolve_provider(model)?; - let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model)) - { + // Use spec.format for transformation - this allows composite providers + // (like Bedrock) to handle multiple formats based on the model's catalog entry + let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. }) => bytes, Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), @@ -369,6 +371,20 @@ impl RouterBuilder { self } + /// Register an additional format for an existing provider alias. + /// + /// This allows a single provider to handle multiple formats, which is useful + /// for composite providers like Bedrock that can handle both Converse and + /// Anthropic formats. 
+ pub fn add_provider_for_format( + mut self, + alias: impl Into, + format: ProviderFormat, + ) -> Self { + self.formats.insert(format, alias.into()); + self + } + pub fn add_auth(mut self, alias: impl Into, auth: AuthConfig) -> Self { self.auth_configs.insert(alias.into(), auth); self diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs index 605ec3df..8f3610ac 100644 --- a/crates/braintrust-llm-router/tests/router.rs +++ b/crates/braintrust-llm-router/tests/router.rs @@ -383,3 +383,111 @@ async fn router_retries_and_propagates_terminal_error() { assert!(matches!(err, Error::Timeout)); assert_eq!(attempts.load(Ordering::SeqCst), 3); } + +/// Test that a provider can be registered for multiple formats via add_provider_for_format(). +/// This enables composite providers like Bedrock that handle both Converse and Anthropic formats. +#[tokio::test] +async fn router_supports_multi_format_provider() { + let mut catalog = ModelCatalog::empty(); + // Model A uses OpenAI format (primary format of the provider) + catalog.insert( + "model-a".into(), + ModelSpec { + model: "model-a".into(), + format: ProviderFormat::OpenAI, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + }, + ); + // Model B uses Anthropic format (secondary format via add_provider_for_format) + catalog.insert( + "model-b".into(), + ModelSpec { + model: "model-b".into(), + format: ProviderFormat::Anthropic, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, 
+ supports_streaming: true, + extra: Default::default(), + }, + ); + let catalog = Arc::new(catalog); + + // StubProvider returns OpenAI format from format() + // We register it for both OpenAI (via add_provider) and Anthropic (via add_provider_for_format) + let router = RouterBuilder::new() + .with_catalog(Arc::clone(&catalog)) + .add_provider("multi", StubProvider) + .add_provider_for_format("multi", ProviderFormat::Anthropic) + .add_auth( + "multi", + AuthConfig::ApiKey { + key: "test".into(), + header: None, + prefix: None, + }, + ) + .build() + .expect("router builds"); + + // Model A (OpenAI format) should route to "multi" provider + let body = to_body(json!({ + "model": "model-a", + "messages": [{"role": "user", "content": "Ping"}] + })); + let bytes = router + .complete( + body, + "model-a", + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await + .expect("model-a should route to multi provider"); + let response: Value = + braintrust_llm_router::serde_json::from_slice(&bytes).expect("valid json"); + assert_eq!( + response.get("model").and_then(Value::as_str), + Some("model-a") + ); + + // Model B (Anthropic format) should ALSO route to "multi" provider + // This verifies add_provider_for_format() registered the provider for Anthropic format + let body = to_body(json!({ + "model": "model-b", + "messages": [{"role": "user", "content": "Ping"}] + })); + let bytes = router + .complete( + body, + "model-b", + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await + .expect("model-b should also route to multi provider via add_provider_for_format"); + let response: Value = + braintrust_llm_router::serde_json::from_slice(&bytes).expect("valid json"); + assert_eq!( + response.get("model").and_then(Value::as_str), + Some("model-b") + ); +}