diff --git a/crates/braintrust-llm-router/src/catalog/model_list.json b/crates/braintrust-llm-router/src/catalog/model_list.json index e78b90fa..9507767c 100644 --- a/crates/braintrust-llm-router/src/catalog/model_list.json +++ b/crates/braintrust-llm-router/src/catalog/model_list.json @@ -2991,7 +2991,7 @@ "max_output_tokens": 4000 }, "anthropic.claude-sonnet-4-5-20250929-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3005,7 +3005,7 @@ "max_output_tokens": 64000 }, "us.anthropic.claude-sonnet-4-5-20250929-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3020,7 +3020,7 @@ "max_output_tokens": 64000 }, "anthropic.claude-sonnet-4-20250514-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3034,7 +3034,7 @@ "max_output_tokens": 64000 }, "us.anthropic.claude-sonnet-4-20250514-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3049,7 +3049,7 @@ "max_output_tokens": 64000 }, "anthropic.claude-3-7-sonnet-20250219-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3063,7 +3063,7 @@ "max_output_tokens": 8192 }, "us.anthropic.claude-3-7-sonnet-20250219-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3078,7 +3078,7 @@ "max_output_tokens": 8192 }, "anthropic.claude-haiku-4-5-20251001-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 0.8, @@ -3090,7 +3090,7 @@ "max_output_tokens": 64000 }, "us.anthropic.claude-haiku-4-5-20251001-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": 
true, "input_cost_per_mil_tokens": 0.8, @@ -3103,7 +3103,7 @@ "max_output_tokens": 64000 }, "anthropic.claude-3-5-haiku-20241022-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 0.8, @@ -3115,7 +3115,7 @@ "max_output_tokens": 8192 }, "us.anthropic.claude-3-5-haiku-20241022-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 0.8, @@ -3128,7 +3128,7 @@ "max_output_tokens": 8192 }, "anthropic.claude-3-5-sonnet-20241022-v2:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3140,7 +3140,7 @@ "max_output_tokens": 8192 }, "us.anthropic.claude-3-5-sonnet-20241022-v2:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3153,7 +3153,7 @@ "max_output_tokens": 8192 }, "apac.anthropic.claude-3-5-sonnet-20241022-v2:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3162,7 +3162,7 @@ "parent": "anthropic.claude-3-5-sonnet-20241022-v2:0" }, "anthropic.claude-3-5-sonnet-20240620-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3172,7 +3172,7 @@ "max_output_tokens": 4096 }, "us.anthropic.claude-3-5-sonnet-20240620-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3183,7 +3183,7 @@ "max_output_tokens": 4096 }, "apac.anthropic.claude-3-5-sonnet-20240620-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3192,7 +3192,7 @@ "parent": "anthropic.claude-3-5-sonnet-20240620-v1:0" }, "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": { - "format": "converse", + "format": 
"anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3203,7 +3203,7 @@ "max_output_tokens": 4096 }, "anthropic.claude-opus-4-1-20250805-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 15, @@ -3217,7 +3217,7 @@ "max_output_tokens": 32000 }, "us.anthropic.claude-opus-4-1-20250805-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 15, @@ -3232,7 +3232,7 @@ "max_output_tokens": 32000 }, "anthropic.claude-opus-4-20250514-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 15, @@ -3246,7 +3246,7 @@ "max_output_tokens": 32000 }, "us.anthropic.claude-opus-4-20250514-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 15, @@ -3261,7 +3261,7 @@ "max_output_tokens": 32000 }, "anthropic.claude-3-opus-20240229-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 15, @@ -3271,7 +3271,7 @@ "max_output_tokens": 4096 }, "us.anthropic.claude-3-opus-20240229-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 15, @@ -3282,7 +3282,7 @@ "max_output_tokens": 4096 }, "anthropic.claude-3-sonnet-20240229-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3292,7 +3292,7 @@ "max_output_tokens": 4096 }, "us.anthropic.claude-3-sonnet-20240229-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3303,7 +3303,7 @@ "max_output_tokens": 4096 }, "apac.anthropic.claude-3-sonnet-20240229-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": 
true, "input_cost_per_mil_tokens": 3, @@ -3312,7 +3312,7 @@ "parent": "anthropic.claude-3-sonnet-20240229-v1:0" }, "eu.anthropic.claude-3-sonnet-20240229-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 3, @@ -3323,7 +3323,7 @@ "max_output_tokens": 4096 }, "anthropic.claude-3-haiku-20240307-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 0.25, @@ -3333,7 +3333,7 @@ "max_output_tokens": 4096 }, "us.anthropic.claude-3-haiku-20240307-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 0.25, @@ -3344,7 +3344,7 @@ "max_output_tokens": 4096 }, "apac.anthropic.claude-3-haiku-20240307-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 0.25, @@ -3353,7 +3353,7 @@ "parent": "anthropic.claude-3-haiku-20240307-v1:0" }, "eu.anthropic.claude-3-haiku-20240307-v1:0": { - "format": "converse", + "format": "anthropic", "flavor": "chat", "multimodal": true, "input_cost_per_mil_tokens": 0.25, diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index 862b9386..00afeda0 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -18,14 +18,25 @@ use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; use crate::providers::ClientHeaders; -use crate::streaming::{bedrock_event_stream, single_bytes_stream, RawResponseStream}; +use crate::streaming::{ + bedrock_event_stream, bedrock_messages_event_stream, single_bytes_stream, RawResponseStream, +}; use lingua::ProviderFormat; +const BEDROCK_ANTHROPIC_VERSION: &str = "bedrock-2023-05-31"; + +#[derive(Debug, Clone, Copy, 
PartialEq, Eq)] +pub enum BedrockMode { + Converse, + AnthropicMessages, +} + #[derive(Debug, Clone)] pub struct BedrockConfig { pub endpoint: Url, pub service: String, pub timeout: Option, + pub anthropic_version: String, } impl Default for BedrockConfig { @@ -35,6 +46,7 @@ impl Default for BedrockConfig { .expect("valid Bedrock endpoint"), service: "bedrock".to_string(), timeout: None, + anthropic_version: BEDROCK_ANTHROPIC_VERSION.to_string(), } } } @@ -90,11 +102,30 @@ impl BedrockProvider { Self::new(config) } - fn invoke_url(&self, model: &str, stream: bool) -> Result { - let path = if stream { - format!("model/{model}/converse-stream") + fn determine_mode(&self, model: &str) -> BedrockMode { + if is_anthropic_model(model) { + BedrockMode::AnthropicMessages } else { - format!("model/{model}/converse") + BedrockMode::Converse + } + } + + fn invoke_url(&self, model: &str, stream: bool, mode: BedrockMode) -> Result { + let path = match mode { + BedrockMode::Converse => { + if stream { + format!("model/{model}/converse-stream") + } else { + format!("model/{model}/converse") + } + } + BedrockMode::AnthropicMessages => { + if stream { + format!("model/{model}/invoke-with-response-stream") + } else { + format!("model/{model}/invoke") + } + } }; self.config .endpoint @@ -102,6 +133,25 @@ impl BedrockProvider { .map_err(|e| Error::InvalidRequest(format!("failed to build invoke url: {e}"))) } + fn prepare_anthropic_payload(&self, payload: Bytes) -> Result { + let mut body: Value = lingua::serde_json::from_slice(&payload) + .map_err(|e| Error::InvalidRequest(format!("invalid JSON payload: {e}")))?; + if let Some(obj) = body.as_object_mut() { + obj.insert( + "anthropic_version".to_string(), + Value::String(self.config.anthropic_version.clone()), + ); + // Bedrock Messages API does not accept `model` or `stream` in the body; + // the model is specified in the URL path and streaming is controlled by + // the endpoint choice (/invoke vs /invoke-with-response-stream). 
+ obj.remove("model"); + obj.remove("stream"); + } + let bytes = lingua::serde_json::to_vec(&body) + .map_err(|e| Error::InvalidRequest(format!("failed to serialize payload: {e}")))?; + Ok(Bytes::from(bytes)) + } + fn sign_request(&self, url: &Url, body: &[u8], auth: &AuthConfig) -> Result { let (access_key, secret_key, session_token, region, service) = auth .aws_credentials() @@ -205,23 +255,29 @@ impl crate::providers::Provider for BedrockProvider { spec: &ModelSpec, _client_headers: &ClientHeaders, ) -> Result { - let url = self.invoke_url(&spec.model, false)?; + let mode = self.determine_mode(&spec.model); + let final_payload = match mode { + BedrockMode::AnthropicMessages => self.prepare_anthropic_payload(payload)?, + BedrockMode::Converse => payload, + }; + let url = self.invoke_url(&spec.model, false, mode)?; #[cfg(feature = "tracing")] tracing::debug!( target: "bt.router.provider.http", llm_provider = "bedrock", http_url = %url, + bedrock_mode = ?mode, "sending request to Bedrock" ); - let headers = self.build_headers(&url, payload.as_ref(), auth)?; + let headers = self.build_headers(&url, final_payload.as_ref(), auth)?; let response = self .client .post(url) .headers(headers) - .body(payload) + .body(final_payload) .send() .await?; @@ -267,8 +323,12 @@ impl crate::providers::Provider for BedrockProvider { return Ok(single_bytes_stream(response)); } - // Router should have already added stream options to payload - let url = self.invoke_url(&spec.model, true)?; + let mode = self.determine_mode(&spec.model); + let final_payload = match mode { + BedrockMode::AnthropicMessages => self.prepare_anthropic_payload(payload)?, + BedrockMode::Converse => payload, + }; + let url = self.invoke_url(&spec.model, true, mode)?; #[cfg(feature = "tracing")] tracing::debug!( @@ -276,16 +336,17 @@ impl crate::providers::Provider for BedrockProvider { llm_provider = "bedrock", http_url = %url, llm_streaming = true, + bedrock_mode = ?mode, "sending streaming request to Bedrock" 
); - let headers = self.build_headers(&url, payload.as_ref(), auth)?; + let headers = self.build_headers(&url, final_payload.as_ref(), auth)?; let response = self .client .post(url) .headers(headers) - .body(payload) + .body(final_payload) .send() .await?; @@ -317,7 +378,10 @@ impl crate::providers::Provider for BedrockProvider { }); } - Ok(bedrock_event_stream(response)) + match mode { + BedrockMode::AnthropicMessages => Ok(bedrock_messages_event_stream(response)), + BedrockMode::Converse => Ok(bedrock_event_stream(response)), + } } async fn health_check(&self, auth: &AuthConfig) -> Result<()> { @@ -351,6 +415,10 @@ impl crate::providers::Provider for BedrockProvider { } } +fn is_anthropic_model(model: &str) -> bool { + model.starts_with("anthropic.") || model.contains(".anthropic.") +} + fn extract_retry_after(status: StatusCode, _body: &str) -> Option { if status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() { Some(Duration::from_secs(2)) @@ -358,3 +426,198 @@ fn extract_retry_after(status: StatusCode, _body: &str) -> Option { None } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn determine_mode_routes_models_correctly() { + let provider = BedrockProvider::new(BedrockConfig::default()).unwrap(); + + let cases: &[(&str, BedrockMode)] = &[ + // Anthropic models -> AnthropicMessages + ("anthropic.claude-3-5-sonnet-20241022-v2:0", BedrockMode::AnthropicMessages), + ("anthropic.claude-haiku-4-5-20251001-v1:0", BedrockMode::AnthropicMessages), + // Region-prefixed inference profiles + ("us.anthropic.claude-3-5-sonnet-20241022-v2:0", BedrockMode::AnthropicMessages), + ("eu.anthropic.claude-3-5-sonnet-20241022-v2:0", BedrockMode::AnthropicMessages), + ("apac.anthropic.claude-3-5-sonnet-20241022-v2:0", BedrockMode::AnthropicMessages), + ("apac.anthropic.claude-3-haiku-20240307-v1:0", BedrockMode::AnthropicMessages), + // Non-Anthropic models -> Converse + ("amazon.nova-micro-v1:0", BedrockMode::Converse), + ("meta.llama3-8b-instruct-v1:0", 
BedrockMode::Converse), + ("cohere.command-r-v1:0", BedrockMode::Converse), + // Substrings that should NOT match + ("notanthropic.claude-3-5-sonnet", BedrockMode::Converse), + ("myanthropic.model", BedrockMode::Converse), + ("", BedrockMode::Converse), + ]; + + for (model, expected) in cases { + assert_eq!( + provider.determine_mode(model), + *expected, + "wrong mode for model: {model}" + ); + } + } + + #[test] + fn invoke_url_converse_mode() { + let provider = BedrockProvider::new(BedrockConfig::default()).unwrap(); + + let url = provider + .invoke_url("amazon.nova-micro-v1:0", false, BedrockMode::Converse) + .unwrap(); + assert!(url + .as_str() + .contains("/model/amazon.nova-micro-v1:0/converse")); + + let url = provider + .invoke_url("amazon.nova-micro-v1:0", true, BedrockMode::Converse) + .unwrap(); + assert!(url + .as_str() + .contains("/model/amazon.nova-micro-v1:0/converse-stream")); + } + + #[test] + fn invoke_url_anthropic_messages_mode() { + let provider = BedrockProvider::new(BedrockConfig::default()).unwrap(); + + let url = provider + .invoke_url( + "anthropic.claude-3-5-sonnet-20241022-v2:0", + false, + BedrockMode::AnthropicMessages, + ) + .unwrap(); + assert!(url + .as_str() + .contains("/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke")); + + let url = provider + .invoke_url( + "anthropic.claude-3-5-sonnet-20241022-v2:0", + true, + BedrockMode::AnthropicMessages, + ) + .unwrap(); + assert!(url.as_str().contains( + "/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke-with-response-stream" + )); + } + + #[test] + fn prepare_anthropic_payload_injects_version_and_strips_model() { + let provider = BedrockProvider::new(BedrockConfig::default()).unwrap(); + + let payload = Bytes::from( + r#"{"model":"anthropic.claude-3-5-sonnet","messages":[{"role":"user","content":"hi"}],"max_tokens":100,"stream":true}"#, + ); + let result = provider.prepare_anthropic_payload(payload).unwrap(); + let body: Value = 
lingua::serde_json::from_slice(&result).unwrap(); + + assert_eq!( + body.get("anthropic_version").and_then(Value::as_str), + Some(BEDROCK_ANTHROPIC_VERSION) + ); + assert!( + body.get("model").is_none(), + "model field should be stripped for Bedrock Messages API" + ); + assert!( + body.get("stream").is_none(), + "stream field should be stripped for Bedrock Messages API" + ); + assert_eq!(body.get("max_tokens").and_then(Value::as_u64), Some(100)); + } + + #[test] + fn prepare_anthropic_payload_without_model_or_stream() { + let provider = BedrockProvider::new(BedrockConfig::default()).unwrap(); + + let payload = Bytes::from( + r#"{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}"#, + ); + let result = provider.prepare_anthropic_payload(payload).unwrap(); + let body: Value = lingua::serde_json::from_slice(&result).unwrap(); + + assert_eq!( + body.get("anthropic_version").and_then(Value::as_str), + Some(BEDROCK_ANTHROPIC_VERSION) + ); + assert!(body.get("model").is_none()); + assert!(body.get("stream").is_none()); + assert_eq!(body.get("max_tokens").and_then(Value::as_u64), Some(100)); + } + + #[test] + fn prepare_anthropic_payload_preserves_other_fields() { + let provider = BedrockProvider::new(BedrockConfig::default()).unwrap(); + + let payload = Bytes::from( + r#"{"model":"anthropic.claude-3-5-sonnet","messages":[{"role":"user","content":"hi"}],"max_tokens":4096,"temperature":0.7,"top_p":0.9,"stop_sequences":["END"],"stream":false}"#, + ); + let result = provider.prepare_anthropic_payload(payload).unwrap(); + let body: Value = lingua::serde_json::from_slice(&result).unwrap(); + + assert_eq!( + body.get("anthropic_version").and_then(Value::as_str), + Some(BEDROCK_ANTHROPIC_VERSION) + ); + assert!(body.get("model").is_none()); + assert!(body.get("stream").is_none()); + assert_eq!(body.get("max_tokens").and_then(Value::as_u64), Some(4096)); + assert_eq!( + body.get("temperature").and_then(Value::as_f64), + Some(0.7) + ); + 
assert_eq!(body.get("top_p").and_then(Value::as_f64), Some(0.9)); + assert!(body.get("stop_sequences").is_some()); + assert!(body.get("messages").is_some()); + } + + #[test] + fn prepare_anthropic_payload_custom_version() { + let mut config = BedrockConfig::default(); + config.anthropic_version = "custom-2024-01-01".to_string(); + let provider = BedrockProvider::new(config).unwrap(); + + let payload = Bytes::from( + r#"{"model":"anthropic.claude-3-5-sonnet","messages":[],"max_tokens":100}"#, + ); + let result = provider.prepare_anthropic_payload(payload).unwrap(); + let body: Value = lingua::serde_json::from_slice(&result).unwrap(); + + assert_eq!( + body.get("anthropic_version").and_then(Value::as_str), + Some("custom-2024-01-01") + ); + } + + #[test] + fn invoke_url_does_not_include_converse_stream_suffix_for_non_streaming_anthropic() { + let provider = BedrockProvider::new(BedrockConfig::default()).unwrap(); + + let url = provider + .invoke_url( + "anthropic.claude-3-5-sonnet-20241022-v2:0", + false, + BedrockMode::AnthropicMessages, + ) + .unwrap(); + + assert!( + !url.as_str().contains("converse"), + "Anthropic mode should not use converse endpoints" + ); + assert!( + !url.as_str().contains("response-stream"), + "non-streaming should not include response-stream" + ); + assert!(url.as_str().ends_with("/invoke")); + } + +} diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 5fca83ea..40722099 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -18,7 +18,7 @@ use crate::retry::{RetryPolicy, RetryStrategy}; use crate::streaming::{transform_stream, ResponseStream}; use lingua::serde_json::Value; use lingua::ProviderFormat; -use lingua::{TransformError, TransformResult}; +use lingua::TransformResult; // Re-export for convenience in dependent crates pub use lingua::{extract_request_hints, RequestHints}; @@ -149,11 +149,10 @@ impl Router { client_headers: 
&ClientHeaders, ) -> Result { let (provider, auth, spec, strategy) = self.resolve_provider(model)?; - let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model)) - { + let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. }) => bytes, - Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), + Err(e) if e.is_unsupported_target_format() => body.clone(), Err(e) => return Err(e.into()), }; @@ -205,11 +204,10 @@ impl Router { client_headers: &ClientHeaders, ) -> Result { let (provider, auth, spec, _) = self.resolve_provider(model)?; - let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model)) - { + let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) { Ok(TransformResult::PassThrough(bytes)) => bytes, Ok(TransformResult::Transformed { bytes, .. }) => bytes, - Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(), + Err(e) if e.is_unsupported_target_format() => body.clone(), Err(e) => return Err(e.into()), }; @@ -368,6 +366,19 @@ impl RouterBuilder { self } + /// Register an existing provider alias for an additional format. + /// + /// This allows a single provider to serve models with different catalog formats. + /// The provider must already have been added via `add_provider` or `add_provider_arc`. 
+ pub fn add_provider_for_format( + mut self, + alias: impl Into, + format: ProviderFormat, + ) -> Self { + self.formats.insert(format, alias.into()); + self + } + pub fn add_auth(mut self, alias: impl Into, auth: AuthConfig) -> Self { self.auth_configs.insert(alias.into(), auth); self diff --git a/crates/braintrust-llm-router/src/streaming.rs b/crates/braintrust-llm-router/src/streaming.rs index 848654f0..f642cdf1 100644 --- a/crates/braintrust-llm-router/src/streaming.rs +++ b/crates/braintrust-llm-router/src/streaming.rs @@ -6,6 +6,8 @@ use futures::Stream; use reqwest::Response; use crate::error::{Error, Result}; +#[cfg(feature = "provider-bedrock")] +use lingua::serde_json::Value; use lingua::ProviderFormat; use lingua::TransformResult; @@ -305,6 +307,142 @@ pub fn bedrock_event_stream(response: Response) -> RawResponseStream { Box::pin(RawBedrockEventStream::new(response.bytes_stream())) } +/// Bedrock Messages API event stream that yields raw Anthropic JSON payloads. +/// +/// Uses the same AWS binary event stream decoder as the Converse stream but +/// emits the payload bytes directly without wrapping in `{"eventType": payload}`. +/// The payloads are already valid Anthropic streaming JSON events. 
+#[cfg(feature = "provider-bedrock")] +struct RawBedrockMessagesEventStream<S> +where + S: Stream<Item = reqwest::Result<Bytes>> + Unpin + Send + 'static, +{ + inner: S, + buffer: BytesMut, + decoder: aws_smithy_eventstream::frame::MessageFrameDecoder, + finished: bool, +} + +#[cfg(feature = "provider-bedrock")] +impl<S> RawBedrockMessagesEventStream<S> +where + S: Stream<Item = reqwest::Result<Bytes>> + Unpin + Send + 'static, +{ + fn new(inner: S) -> Self { + Self { + inner, + buffer: BytesMut::new(), + decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(), + finished: false, + } + } +} + +#[cfg(feature = "provider-bedrock")] +impl<S> Stream for RawBedrockMessagesEventStream<S> +where + S: Stream<Item = reqwest::Result<Bytes>> + Unpin + Send + 'static, +{ + type Item = Result<Bytes>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + use aws_smithy_eventstream::frame::DecodedFrame; + + let this = self.get_mut(); + + if this.finished { + return Poll::Ready(None); + } + + loop { + match this.decoder.decode_frame(&mut this.buffer) { + Ok(DecodedFrame::Complete(message)) => { + let payload = message.payload(); + if payload.is_empty() { + continue; + } + + // The invoke-with-response-stream payload is + // {"bytes":"<base64-encoded JSON>"} + // We need to extract and decode the bytes field. 
+ let json_bytes = match extract_bedrock_invoke_payload(payload) { + Ok(Some(decoded)) => decoded, + Ok(None) => continue, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + + return Poll::Ready(Some(Ok(json_bytes))); + } + Ok(DecodedFrame::Incomplete) => { + // Need more data, fall through to poll inner stream + } + Err(e) => { + return Poll::Ready(Some(Err(Error::Provider { + provider: "bedrock".to_string(), + source: anyhow::anyhow!("Event stream decode error: {}", e), + retry_after: None, + http: None, + }))); + } + } + + match Pin::new(&mut this.inner).poll_next(cx) { + Poll::Ready(Some(Ok(bytes))) => { + this.buffer.extend_from_slice(&bytes); + } + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))), + Poll::Ready(None) => { + this.finished = true; + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + } + } + } +} + +/// Create a Bedrock Messages API event stream that yields raw Anthropic JSON payloads. +/// +/// Uses the AWS binary event stream decoder but emits payloads directly +/// (no `{"eventType": payload}` wrapping). For use with Anthropic models on Bedrock. +#[cfg(feature = "provider-bedrock")] +pub fn bedrock_messages_event_stream(response: Response) -> RawResponseStream { + Box::pin(RawBedrockMessagesEventStream::new(response.bytes_stream())) +} + +/// Extract the Anthropic JSON payload from a Bedrock invoke-with-response-stream event. +/// +/// The event payload has the shape `{"bytes":""}`. +/// Returns `Ok(Some(decoded_bytes))` on success, `Ok(None)` if the payload +/// should be skipped, or `Err` on decode failure. 
+#[cfg(feature = "provider-bedrock")] +fn extract_bedrock_invoke_payload(raw: &[u8]) -> Result<Option<Bytes>> { + use base64::Engine; + + let wrapper: Value = lingua::serde_json::from_slice(raw).map_err(|e| Error::Provider { + provider: "bedrock".to_string(), + source: anyhow::anyhow!("failed to parse invoke stream event: {}", e), + retry_after: None, + http: None, + })?; + + let b64 = match wrapper.get("bytes").and_then(Value::as_str) { + Some(s) => s, + None => return Ok(None), + }; + + let decoded = base64::engine::general_purpose::STANDARD + .decode(b64) + .map_err(|e| Error::Provider { + provider: "bedrock".to_string(), + source: anyhow::anyhow!("failed to base64-decode invoke stream event: {}", e), + retry_after: None, + http: None, + })?; + + Ok(Some(Bytes::from(decoded))) +} + fn split_event(buffer: &BytesMut) -> Option<(Bytes, BytesMut)> { // Check for \r\n\r\n first (4-byte CRLF delimiter) if let Some(index) = buffer.windows(4).position(|w| w == b"\r\n\r\n") { @@ -350,4 +488,77 @@ mod tests { buffer = rest; assert!(!buffer.is_empty()); } + + #[cfg(feature = "provider-bedrock")] + mod bedrock_messages_stream { + use super::*; + + #[test] + fn extract_bedrock_invoke_payload_decodes_base64() { + use base64::Engine; + + let inner_json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#; + let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json); + let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded); + + let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap(); + assert!(result.is_some()); + let decoded = result.unwrap(); + assert_eq!(decoded.as_ref(), inner_json.as_bytes()); + } + + #[test] + fn extract_bedrock_invoke_payload_returns_none_without_bytes_field() { + let payload = br#"{"other_field": "value"}"#; + let result = extract_bedrock_invoke_payload(payload).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn extract_bedrock_invoke_payload_errors_on_invalid_json() { + let payload = 
b"not json"; + let result = extract_bedrock_invoke_payload(payload); + assert!(result.is_err()); + } + + #[test] + fn extract_bedrock_invoke_payload_errors_on_invalid_base64() { + let payload = br#"{"bytes": "!!!not-valid-base64!!!"}"#; + let result = extract_bedrock_invoke_payload(payload); + assert!(result.is_err()); + } + + #[test] + fn extract_bedrock_invoke_payload_handles_message_start_event() { + use base64::Engine; + + let inner_json = r#"{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#; + let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json); + let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded); + + let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap(); + assert!(result.is_some()); + let decoded = result.unwrap(); + let decoded_str = std::str::from_utf8(&decoded).unwrap(); + assert!(decoded_str.contains("message_start")); + assert!(decoded_str.contains("msg_123")); + } + + #[test] + fn extract_bedrock_invoke_payload_handles_message_stop_event() { + use base64::Engine; + + let inner_json = + r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":15}}"#; + let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json); + let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded); + + let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap(); + assert!(result.is_some()); + let decoded = result.unwrap(); + let decoded_str = std::str::from_utf8(&decoded).unwrap(); + assert!(decoded_str.contains("message_delta")); + assert!(decoded_str.contains("end_turn")); + } + } } diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs index 605ec3df..63bf2d7e 100644 --- a/crates/braintrust-llm-router/tests/router.rs +++ b/crates/braintrust-llm-router/tests/router.rs @@ -383,3 +383,256 
@@ async fn router_retries_and_propagates_terminal_error() { assert!(matches!(err, Error::Timeout)); assert_eq!(attempts.load(Ordering::SeqCst), 3); } + +#[derive(Clone)] +struct PayloadCapturingProvider { + received: Arc>>, +} + +#[async_trait] +impl Provider for PayloadCapturingProvider { + fn id(&self) -> &'static str { + "capturing" + } + + fn format(&self) -> ProviderFormat { + ProviderFormat::OpenAI + } + + async fn complete( + &self, + payload: Bytes, + _auth: &AuthConfig, + _spec: &ModelSpec, + _client_headers: &ClientHeaders, + ) -> braintrust_llm_router::Result { + *self.received.lock().unwrap() = Some(payload); + + let response = json!({ + "id": "test", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "hello"}], + "model": "test", + "stop_reason": "end_turn", + "usage": {"input_tokens": 1, "output_tokens": 1} + }); + Ok(Bytes::from( + braintrust_llm_router::serde_json::to_vec(&response).unwrap(), + )) + } + + async fn complete_stream( + &self, + _payload: Bytes, + _auth: &AuthConfig, + _spec: &ModelSpec, + _client_headers: &ClientHeaders, + ) -> braintrust_llm_router::Result { + Ok(Box::pin(tokio_stream::empty())) + } + + async fn health_check(&self, _auth: &AuthConfig) -> braintrust_llm_router::Result<()> { + Ok(()) + } +} + +#[tokio::test] +async fn router_supports_multi_format_provider() { + let mut catalog = ModelCatalog::empty(); + catalog.insert( + "openai-model".into(), + ModelSpec { + model: "openai-model".into(), + format: ProviderFormat::OpenAI, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + }, + ); + catalog.insert( + "anthropic-model".into(), + ModelSpec { + model: "anthropic-model".into(), + format: ProviderFormat::Anthropic, 
+ flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + }, + ); + catalog.insert( + "converse-model".into(), + ModelSpec { + model: "converse-model".into(), + format: ProviderFormat::Converse, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + }, + ); + let catalog = Arc::new(catalog); + + let received = Arc::new(std::sync::Mutex::new(None)); + let provider = PayloadCapturingProvider { + received: Arc::clone(&received), + }; + + let router = RouterBuilder::new() + .with_catalog(catalog) + .add_provider("multi", provider) + .add_provider_for_format("multi", ProviderFormat::Anthropic) + .add_provider_for_format("multi", ProviderFormat::Converse) + .add_auth( + "multi", + AuthConfig::ApiKey { + key: "test".into(), + header: None, + prefix: None, + }, + ) + .build() + .expect("router builds"); + + // Send OpenAI-format request targeting an Anthropic-format model. + // The router should transform the request to Anthropic format using spec.format. 
+ let body = to_body(json!({ + "model": "anthropic-model", + "messages": [{"role": "user", "content": "hello"}] + })); + + let _ = router + .complete( + body, + "anthropic-model", + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await; + + let payload_bytes = received + .lock() + .unwrap() + .take() + .expect("provider received payload"); + let payload: Value = braintrust_llm_router::serde_json::from_slice(&payload_bytes).unwrap(); + + // The payload should have been transformed to Anthropic format + assert!( + payload.get("max_tokens").is_some(), + "payload should be in Anthropic format (has max_tokens)" + ); + + // Now test OpenAI-format model through the same provider + *received.lock().unwrap() = None; + let body = to_body(json!({ + "model": "openai-model", + "messages": [{"role": "user", "content": "hello"}] + })); + + let _ = router + .complete( + body, + "openai-model", + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await; + + let payload_bytes = received + .lock() + .unwrap() + .take() + .expect("provider received payload"); + let payload: Value = braintrust_llm_router::serde_json::from_slice(&payload_bytes).unwrap(); + + // The payload should pass through as OpenAI format (no max_tokens added by Anthropic transform) + assert!( + payload.get("model").is_some(), + "payload should still be in OpenAI format" + ); + + // Now test Converse-format model through the same provider + *received.lock().unwrap() = None; + let body = to_body(json!({ + "model": "converse-model", + "messages": [{"role": "user", "content": "hello"}] + })); + + let _ = router + .complete( + body, + "converse-model", + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await; + + let payload_bytes = received + .lock() + .unwrap() + .take() + .expect("provider received payload"); + let payload: Value = braintrust_llm_router::serde_json::from_slice(&payload_bytes).unwrap(); + + // The payload should have been transformed to Converse format + assert!( + 
payload.get("modelId").is_some(), + "payload should be in Converse format (has modelId)" + ); + + // Now test Converse → Converse passthrough + *received.lock().unwrap() = None; + let body = to_body(json!({ + "modelId": "converse-model", + "messages": [{"role": "user", "content": [{"text": "hello"}]}] + })); + + let _ = router + .complete( + body, + "converse-model", + ProviderFormat::Converse, + &ClientHeaders::default(), + ) + .await; + + let payload_bytes = received + .lock() + .unwrap() + .take() + .expect("provider received payload"); + let payload: Value = braintrust_llm_router::serde_json::from_slice(&payload_bytes).unwrap(); + + // The payload should pass through as Converse format + assert!( + payload.get("modelId").is_some(), + "Converse passthrough should preserve modelId" + ); +} diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs index 4fb3b489..68acd934 100644 --- a/crates/lingua/src/processing/transform.rs +++ b/crates/lingua/src/processing/transform.rs @@ -79,6 +79,10 @@ impl TransformError { | TransformError::FromUniversalFailed(_) ) } + + pub fn is_unsupported_target_format(&self) -> bool { + matches!(self, TransformError::UnsupportedTargetFormat(_)) + } } impl From for TransformError { diff --git a/payloads/cases/advanced.ts b/payloads/cases/advanced.ts index 2ea18699..d3a4d243 100644 --- a/payloads/cases/advanced.ts +++ b/payloads/cases/advanced.ts @@ -4,7 +4,7 @@ import { OPENAI_CHAT_COMPLETIONS_MODEL, OPENAI_RESPONSES_MODEL, ANTHROPIC_MODEL, - BEDROCK_MODEL, + BEDROCK_ANTH_MODEL, } from "./models"; const IMAGE_BASE64 = @@ -100,7 +100,7 @@ export const advancedCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ { role: "user", @@ -178,7 +178,7 @@ export const advancedCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ { role: "user", @@ -236,7 +236,7 @@ export const 
advancedCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ { role: "user", @@ -364,7 +364,7 @@ export const advancedCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ { role: "user", diff --git a/payloads/cases/models.ts b/payloads/cases/models.ts index e42a3eb5..ecca1d52 100644 --- a/payloads/cases/models.ts +++ b/payloads/cases/models.ts @@ -7,4 +7,5 @@ export const ANTHROPIC_MODEL = "claude-sonnet-4-20250514"; // For Anthropic structured outputs (requires Sonnet 4.5+ for JSON schema output_format) export const ANTHROPIC_STRUCTURED_OUTPUT_MODEL = "claude-sonnet-4-5-20250929"; export const GOOGLE_MODEL = "gemini-2.5-flash"; -export const BEDROCK_MODEL = "us.anthropic.claude-haiku-4-5-20251001-v1:0"; +export const BEDROCK_ANTH_MODEL = "us.anthropic.claude-haiku-4-5-20251001-v1:0"; +export const BEDROCK_CONVERSE_MODEL = "amazon.nova-micro-v1:0"; diff --git a/payloads/cases/simple.ts b/payloads/cases/simple.ts index 8fbdf0f6..bdab119b 100644 --- a/payloads/cases/simple.ts +++ b/payloads/cases/simple.ts @@ -4,7 +4,7 @@ import { OPENAI_CHAT_COMPLETIONS_MODEL, OPENAI_RESPONSES_MODEL, ANTHROPIC_MODEL, - BEDROCK_MODEL, + BEDROCK_ANTH_MODEL, } from "./models"; // Simple test cases - basic functionality testing @@ -54,7 +54,7 @@ export const simpleCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ { role: "user", @@ -113,7 +113,7 @@ export const simpleCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ { role: "user", @@ -181,7 +181,7 @@ export const simpleCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ { role: "user", @@ -313,7 +313,7 @@ export const simpleCases: TestCaseCollection = { }, bedrock: { - modelId: BEDROCK_MODEL, + modelId: BEDROCK_ANTH_MODEL, messages: [ 
{ role: "user", diff --git a/payloads/scripts/validation/index.ts b/payloads/scripts/validation/index.ts index 91b569ff..5f1788a6 100644 --- a/payloads/scripts/validation/index.ts +++ b/payloads/scripts/validation/index.ts @@ -14,7 +14,8 @@ import { OPENAI_CHAT_COMPLETIONS_MODEL, ANTHROPIC_STRUCTURED_OUTPUT_MODEL, GOOGLE_MODEL, - BEDROCK_MODEL, + BEDROCK_ANTH_MODEL, + BEDROCK_CONVERSE_MODEL, } from "../../cases/models"; import { proxyCases, @@ -226,7 +227,8 @@ const PROVIDER_REGISTRY: Record = { openai: OPENAI_CHAT_COMPLETIONS_MODEL, anthropic: ANTHROPIC_STRUCTURED_OUTPUT_MODEL, google: GOOGLE_MODEL, - bedrock: BEDROCK_MODEL, + bedrock: BEDROCK_ANTH_MODEL, + "bedrock-converse": BEDROCK_CONVERSE_MODEL, }; /** @@ -451,22 +453,16 @@ export async function runValidation( return result; } - // Override model only for cross-provider testing - // OpenAI formats (chat-completions, responses) with non-OpenAI providers + // Override model for cross-provider testing (any format with non-default provider) if ( providerAlias !== "default" && providerAlias !== "openai" && // Don't override for OpenAI - tests have correct models PROVIDER_REGISTRY[providerAlias] ) { - const isOpenAIFormat = - format === "chat-completions" || format === "responses"; - if (isOpenAIFormat) { - // Override for cross-provider translation testing - request = { - ...request, - model: PROVIDER_REGISTRY[providerAlias], - }; - } + request = { + ...request, + model: PROVIDER_REGISTRY[providerAlias], + }; } // Execute through proxy