From 7440725e9827b18e4d674137b75684944d4e851d Mon Sep 17 00:00:00 2001 From: Ken Jiang Date: Wed, 4 Feb 2026 21:29:05 -0500 Subject: [PATCH 1/2] refactor capabilities --- crates/lingua/src/processing/transform.rs | 8 +- crates/lingua/src/providers/openai/adapter.rs | 336 ++++-------------- .../src/providers/openai/capabilities.rs | 174 +++------ crates/lingua/src/providers/openai/mod.rs | 4 +- .../src/providers/openai/responses_adapter.rs | 18 +- 5 files changed, 118 insertions(+), 422 deletions(-) diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs index 99f16d0..4fb3b48 100644 --- a/crates/lingua/src/processing/transform.rs +++ b/crates/lingua/src/processing/transform.rs @@ -17,7 +17,7 @@ use crate::capabilities::ProviderFormat; use crate::error::ConvertError; use crate::processing::adapters::{adapter_for_format, adapters, ProviderAdapter}; #[cfg(feature = "openai")] -use crate::providers::openai::model_supports_max_tokens; +use crate::providers::openai::model_needs_transforms; use crate::serde_json::Value; use crate::universal::{UniversalResponse, UniversalStreamChunk}; use thiserror::Error; @@ -509,11 +509,9 @@ fn needs_forced_translation(payload: &Value, model: Option<&str>, target: Provid #[cfg(feature = "openai")] { - // If the model doesn't support max_tokens, we need to force translation + // Force translation if model needs any transforms (temperature stripping, max_tokens conversion, etc.) let request_model = payload.get("model").and_then(Value::as_str).or(model); - request_model - .map(|m| !model_supports_max_tokens(m)) - .unwrap_or(false) + request_model.map(model_needs_transforms).unwrap_or(false) } #[cfg(not(feature = "openai"))] diff --git a/crates/lingua/src/providers/openai/adapter.rs b/crates/lingua/src/providers/openai/adapter.rs index b77a3d3..d4b278e 100644 --- a/crates/lingua/src/providers/openai/adapter.rs +++ b/crates/lingua/src/providers/openai/adapter.rs @@ -13,17 +13,10 @@ use crate::processing::adapters::{ insert_opt_bool, insert_opt_f64, insert_opt_i64, insert_opt_value, ProviderAdapter, }; use crate::processing::transform::TransformError; -use crate::providers::openai::capabilities::{OpenAICapabilities, TargetProvider}; +use crate::providers::openai::capabilities::apply_model_transforms; use crate::providers::openai::convert::{ ChatCompletionRequestMessageExt, ChatCompletionResponseMessageExt, }; -use crate::providers::openai::generated::{ - AllowedToolsFunction, ChatCompletionRequestMessageContent, - ChatCompletionRequestMessageContentPart, ChatCompletionRequestMessageRole, - ChatCompletionToolChoiceOption, CreateChatCompletionRequestClass, File, FunctionObject, - FunctionToolChoiceClass, FunctionToolChoiceType, PurpleType, ResponseFormatType, ToolElement, - ToolType, -}; use crate::providers::openai::params::OpenAIChatParams; use crate::providers::openai::try_parse_openai; use crate::serde_json::{self, Map, Value}; @@ -36,7 +29,6 @@ use crate::universal::{ parse_stop_sequences, UniversalParams, UniversalRequest, UniversalResponse, UniversalStreamChoice, UniversalStreamChunk, UniversalUsage, PLACEHOLDER_ID, PLACEHOLDER_MODEL, }; -use crate::util::media::parse_base64_data_url; use std::convert::TryInto; /// Adapter for OpenAI Chat Completions API. 
@@ -203,7 +195,6 @@ impl ProviderAdapter for OpenAIAdapter { .map_err(|e| TransformError::SerializationFailed(e.to_string()))?, ); - // Insert params insert_opt_f64(&mut obj, "temperature", req.params.temperature); insert_opt_f64(&mut obj, "top_p", req.params.top_p); insert_opt_i64(&mut obj, "max_completion_tokens", req.params.max_tokens); @@ -282,6 +273,8 @@ impl ProviderAdapter for OpenAIAdapter { } } + apply_model_transforms(model, &mut obj); + Ok(Value::Object(obj)) } @@ -592,272 +585,6 @@ fn build_reasoning_config( }) } -// ============================================================================= -// OpenAI Target-Specific Transformations -// ============================================================================= - -/// Error type for transformation operations. -#[derive(Debug, thiserror::Error)] -pub enum OpenAITransformError { - #[error("missing required field: {field}")] - MissingField { field: &'static str }, - #[error("invalid value: {message}")] - InvalidValue { message: String }, - #[error("unsupported feature: {feature}")] - Unsupported { feature: String }, - #[error("serialization failed: {0}")] - SerializationFailed(String), -} - -/// Apply target-specific transformations to an OpenAI-format request payload. -/// -/// This function applies transformations needed to make an OpenAI-format request -/// work with different target providers (Azure, Vertex, Mistral, etc.). -/// -/// # Arguments -/// -/// * `payload` - The OpenAI-format request payload (modified in place) -/// * `target_provider` - The target provider that will receive the request -/// * `provider_metadata` - Optional provider-specific metadata (e.g., api_version for Azure) -/// -/// # Returns -/// -/// The transformed payload, or an error if transformation fails. 
-pub fn apply_target_transforms(
-    payload: &Value,
-    target_provider: TargetProvider,
-    provider_metadata: Option<&Map<String, Value>>,
-) -> Result<Value, OpenAITransformError> {
-    // Parse as OpenAI request
-    let mut request: CreateChatCompletionRequestClass = serde_json::from_value(payload.clone())
-        .map_err(|e| OpenAITransformError::SerializationFailed(e.to_string()))?;
-
-    // Detect capabilities based on request and target
-    let capabilities = OpenAICapabilities::detect(&request, target_provider);
-
-    // Apply reasoning model transformations
-    if capabilities.requires_reasoning_transforms() {
-        apply_reasoning_transforms(&mut request, &capabilities);
-    }
-
-    // Apply provider-specific field sanitization
-    apply_provider_sanitization(
-        &mut request,
-        &capabilities,
-        target_provider,
-        provider_metadata,
-    );
-
-    // Apply model name normalization if needed
-    if capabilities.requires_model_normalization {
-        apply_model_normalization(&mut request, target_provider);
-    }
-
-    // Normalize user messages (handle non-image base64 content)
-    normalize_user_messages(&mut request)?;
-
-    // Apply response format transformations
-    apply_response_format(&mut request, &capabilities)?;
-
-    // Serialize back to Value
-    serde_json::to_value(&request)
-        .map_err(|e| OpenAITransformError::SerializationFailed(e.to_string()))
-}
-
-fn apply_reasoning_transforms(
-    request: &mut CreateChatCompletionRequestClass,
-    capabilities: &OpenAICapabilities,
-) {
-    // Remove unsupported fields for reasoning models
-    request.temperature = None;
-    request.parallel_tool_calls = None;
-
-    // For legacy o1 models, convert system messages to user messages
-    if capabilities.is_legacy_o1_model {
-        for message in &mut request.messages {
-            if matches!(message.role, ChatCompletionRequestMessageRole::System) {
-                message.role = ChatCompletionRequestMessageRole::User;
-            }
-        }
-    }
-}
-
-fn apply_provider_sanitization(
-    request: &mut CreateChatCompletionRequestClass,
-    capabilities: &OpenAICapabilities,
-    target_provider: TargetProvider,
-    provider_metadata: Option<&Map<String, Value>>,
-) {
-    // Remove stream_options for providers that don't support it
-    if !capabilities.supports_stream_options {
-        request.stream_options = None;
-    }
-
-    // Remove parallel_tool_calls for providers that don't support it
-    if !capabilities.supports_parallel_tools {
-        request.parallel_tool_calls = None;
-    }
-
-    // Remove seed field for Azure with API version
-    let has_api_version = provider_metadata
-        .and_then(|meta| meta.get("api_version"))
-        .is_some();
-
-    if capabilities.should_remove_seed_for_azure(target_provider, has_api_version) {
-        request.seed = None;
-    }
-}
-
-fn apply_model_normalization(
-    request: &mut CreateChatCompletionRequestClass,
-    target_provider: TargetProvider,
-) {
-    // Normalize Vertex model names
-    if target_provider == TargetProvider::Vertex {
-        if request.model.starts_with("publishers/meta/models/") {
-            // Strip to "meta/..." format
-            request.model = request
-                .model
-                .strip_prefix("publishers/")
-                .and_then(|s| s.strip_prefix("meta/models/"))
-                .map(|s| format!("meta/{}", s))
-                .unwrap_or_else(|| request.model.clone());
-        } else if let Some(stripped) = request.model.strip_prefix("publishers/") {
-            // Strip "publishers/X/models/Y" to "Y"
-            if let Some(model_part) = stripped.split("/models/").nth(1) {
-                request.model = model_part.to_string();
-            }
-        }
-    }
-}
-
-fn normalize_user_messages(
-    request: &mut CreateChatCompletionRequestClass,
-) -> Result<(), OpenAITransformError> {
-    for message in &mut request.messages {
-        if matches!(message.role, ChatCompletionRequestMessageRole::User) {
-            if let Some(
-                ChatCompletionRequestMessageContent::ChatCompletionRequestMessageContentPartArray(
-                    parts,
-                ),
-            ) = message.content.as_mut()
-            {
-                for part in parts.iter_mut() {
-                    normalize_content_part(part)?;
-                }
-            }
-        }
-    }
-    Ok(())
-}
-
-fn normalize_content_part(
-    part: &mut ChatCompletionRequestMessageContentPart,
-) -> Result<(), OpenAITransformError> {
-    if !matches!(
-        part.chat_completion_request_message_content_part_type,
-        PurpleType::ImageUrl
-    ) {
-        return Ok(());
-    }
-
-    let Some(image_url_value) = part
-        .image_url
-        .as_ref()
-        .map(|image_url| image_url.url.clone())
-    else {
-        return Ok(());
-    };
-
-    // Handle base64 data URLs - convert non-images to file type
-    if let Some(data_url) = parse_base64_data_url(&image_url_value) {
-        if !data_url.media_type.starts_with("image/") {
-            part.chat_completion_request_message_content_part_type = PurpleType::File;
-            part.image_url = None;
-            part.file = Some(File {
-                file_data: Some(image_url_value),
-                file_id: None,
-                filename: Some(if data_url.media_type == "application/pdf" {
-                    "file_from_base64.pdf".to_string()
-                } else {
-                    "file_from_base64".to_string()
-                }),
-            });
-        }
-    }
-
-    Ok(())
-}
-
-fn apply_response_format(
-    request: &mut CreateChatCompletionRequestClass,
-    capabilities: &OpenAICapabilities,
-) -> Result<(), OpenAITransformError> {
-    let Some(response_format) = request.response_format.take() else {
-        return Ok(());
-    };
-
-    match response_format.text_type {
-        ResponseFormatType::Text => Ok(()),
-        ResponseFormatType::JsonSchema => {
-            if capabilities.supports_native_structured_output {
-                request.response_format = Some(response_format);
-                return Ok(());
-            }
-
-            // Check if tools are already being used
-            if request
-                .tools
-                .as_ref()
-                .is_some_and(|tools| !tools.is_empty())
-                || request.function_call.is_some()
-                || request.tool_choice.is_some()
-            {
-                return Err(OpenAITransformError::Unsupported {
-                    feature: "tools_with_structured_output".to_string(),
-                });
-            }
-
-            // Convert json_schema to a tool call
-            match response_format.json_schema {
-                Some(schema) => {
-                    request.tools = Some(vec![ToolElement {
-                        function: Some(FunctionObject {
-                            description: Some("Output the result in JSON format".to_string()),
-                            name: "json".to_string(),
-                            parameters: schema.schema.clone(),
-                            strict: schema.strict,
-                        }),
-                        tool_type: ToolType::Function,
-                        custom: None,
-                    }]);
-
-                    request.tool_choice =
-                        Some(ChatCompletionToolChoiceOption::FunctionToolChoiceClass(
-                            FunctionToolChoiceClass {
-                                allowed_tools: None,
-                                allowed_tools_type: FunctionToolChoiceType::Function,
-                                function: Some(AllowedToolsFunction {
-                                    name: "json".to_string(),
-                                }),
-                                custom: None,
-                            },
-                        ));
-
-                    Ok(())
-                }
-                None => Err(OpenAITransformError::InvalidValue {
-                    message: "json_schema response_format is missing schema".to_string(),
-                }),
-            }
-        }
-        ResponseFormatType::JsonObject => {
-            request.response_format = Some(response_format);
-            Ok(())
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1258,4 +985,61 @@ mod tests {
         assert_eq!(thinking.get("type").unwrap(), "enabled");
         assert_eq!(thinking.get("budget_tokens").unwrap(), 3000);
     }
+
+    // =========================================================================
+    // Temperature stripping for reasoning models
+    // =========================================================================
+
+    #[test]
+    fn test_openai_omits_temperature_for_reasoning_models() {
+        use crate::universal::message::UserContent;
+
+        let adapter = OpenAIAdapter;
+
+        // gpt-5-mini is a reasoning model - temperature should be omitted
+        let req = UniversalRequest {
+            model: Some("gpt-5-mini".to_string()),
+            messages: vec![Message::User {
+                content: UserContent::String("Hello".to_string()),
+            }],
+            params: UniversalParams {
+                temperature: Some(0.0), // User specified, but should be omitted
+                ..Default::default()
+            },
+        };
+
+        let result = adapter.request_from_universal(&req).unwrap();
+
+        assert!(
+            result.get("temperature").is_none(),
+            "Temperature should be omitted for reasoning models (gpt-5-mini)"
+        );
+    }
+
+    #[test]
+    fn test_openai_preserves_temperature_for_non_reasoning_models() {
+        use crate::universal::message::UserContent;
+
+        let adapter = OpenAIAdapter;
+
+        // gpt-4 is not a reasoning model - temperature should be preserved
+        let req = UniversalRequest {
+            model: Some("gpt-4".to_string()),
+            messages: vec![Message::User {
+                content: UserContent::String("Hello".to_string()),
+            }],
+            params: UniversalParams {
+                temperature: Some(0.7),
+                ..Default::default()
+            },
+        };
+
+        let result = adapter.request_from_universal(&req).unwrap();
+
+        assert_eq!(
+            result.get("temperature").unwrap().as_f64().unwrap(),
+            0.7,
+            "Temperature should be preserved for non-reasoning models"
+        );
+    }
 }
diff --git a/crates/lingua/src/providers/openai/capabilities.rs b/crates/lingua/src/providers/openai/capabilities.rs
index b9f47bd..727ba39 100644
--- a/crates/lingua/src/providers/openai/capabilities.rs
+++ b/crates/lingua/src/providers/openai/capabilities.rs
@@ -1,140 +1,62 @@
 /*!
 OpenAI-specific capability detection used by the transformation pipeline.
 */
+use crate::serde_json::{Map, Value};
 
-use crate::providers::openai::generated::CreateChatCompletionRequestClass;
-
-/// Target provider that will receive a translated OpenAI payload.
+/// Transforms required for specific model families.
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum TargetProvider {
-    OpenAI,
-    Azure,
-    Vertex,
-    Fireworks,
-    Mistral,
-    Databricks,
-    Lepton,
-    Other,
-}
-
-impl std::str::FromStr for TargetProvider {
-    type Err = std::convert::Infallible;
-
-    fn from_str(provider: &str) -> Result<Self, Self::Err> {
-        Ok(match provider {
-            "openai" => Self::OpenAI,
-            "azure" => Self::Azure,
-            "vertex" => Self::Vertex,
-            "fireworks" => Self::Fireworks,
-            "mistral" => Self::Mistral,
-            "databricks" => Self::Databricks,
-            "lepton" => Self::Lepton,
-            _ => Self::Other,
-        })
+pub enum ModelTransform {
+    /// Strip temperature parameter (reasoning models don't support it)
+    StripTemperature,
+    /// Convert max_tokens to max_completion_tokens
+    ForceMaxCompletionTokens,
+}
+
+use ModelTransform::*;
+
+/// Model prefixes and their required transforms.
+/// Order matters - more specific prefixes must come first.
+const MODEL_TRANSFORM_RULES: &[(&str, &[ModelTransform])] = &[
+    ("o1", &[StripTemperature, ForceMaxCompletionTokens]),
+    ("o3", &[StripTemperature, ForceMaxCompletionTokens]),
+    ("o4", &[StripTemperature, ForceMaxCompletionTokens]),
+    ("gpt-5", &[StripTemperature, ForceMaxCompletionTokens]),
+];
+
+/// Get the transforms required for a model.
+pub fn get_model_transforms(model: &str) -> &'static [ModelTransform] {
+    let lower = model.to_ascii_lowercase();
+    for (prefix, transforms) in MODEL_TRANSFORM_RULES {
+        if lower.starts_with(prefix) {
+            return transforms;
+        }
     }
+    &[]
 }
 
-/// Capability view derived from a request/model combination.
-#[derive(Debug, Clone)]
-pub struct OpenAICapabilities {
-    pub uses_reasoning_mode: bool,
-    pub is_legacy_o1_model: bool,
-    pub supports_native_structured_output: bool,
-
-    // Provider-specific limitations
-    pub supports_stream_options: bool,
-    pub supports_parallel_tools: bool,
-    pub supports_seed_field: bool,
-    pub requires_model_normalization: bool,
+/// Check if a model requires any transforms.
+pub fn model_needs_transforms(model: &str) -> bool {
+    !get_model_transforms(model).is_empty()
 }
 
-impl OpenAICapabilities {
-    pub fn detect(
-        request: &CreateChatCompletionRequestClass,
-        target: TargetProvider,
-    ) -> OpenAICapabilities {
-        let model = request.model.to_ascii_lowercase();
-        let uses_reasoning_mode =
-            request.reasoning_effort.is_some() || is_reasoning_model_name(&model);
-
-        // Provider-specific capability detection
-        let (
-            supports_stream_options,
-            supports_parallel_tools,
-            supports_seed_field,
-            requires_model_normalization,
-        ) = match target {
-            TargetProvider::Mistral => (false, false, true, false),
-            TargetProvider::Fireworks => (false, true, true, false),
-            TargetProvider::Databricks => (false, false, true, false),
-            TargetProvider::Azure => (true, false, true, false),
-            TargetProvider::Vertex => (true, true, true, true),
-            TargetProvider::OpenAI | TargetProvider::Lepton | TargetProvider::Other => {
-                (true, true, true, false)
+/// Apply all transforms for a model to a request object.
+pub fn apply_model_transforms(model: &str, obj: &mut Map<String, Value>) {
+    for transform in get_model_transforms(model) {
+        match transform {
+            StripTemperature => {
+                obj.remove("temperature");
+            }
+            ForceMaxCompletionTokens => {
+                // (Responses API) max_output_tokens is valid.
+                if obj.contains_key("max_output_tokens") {
+                    return;
+                }
+
+                // (Chat Completions API) max_tokens is deprecated - convert to max_completion_tokens.
+                if let Some(max_tokens) = obj.remove("max_tokens") {
+                    obj.entry("max_completion_tokens").or_insert(max_tokens);
+                }
             }
-        };
-
-        OpenAICapabilities {
-            uses_reasoning_mode,
-            is_legacy_o1_model: is_legacy_o1_model(&model),
-            supports_native_structured_output: supports_native_structured_output(&model, target),
-            supports_stream_options,
-            supports_parallel_tools,
-            supports_seed_field,
-            requires_model_normalization,
         }
     }
-
-    pub fn requires_reasoning_transforms(&self) -> bool {
-        self.uses_reasoning_mode
-    }
-
-    /// Check if seed field should be removed for Azure with API version
-    pub fn should_remove_seed_for_azure(
-        &self,
-        target: TargetProvider,
-        has_api_version: bool,
-    ) -> bool {
-        matches!(target, TargetProvider::Azure) && has_api_version
-    }
-}
-
-/// Model prefixes that support native structured output.
-const STRUCTURED_OUTPUT_PREFIXES: &[&str] = &["gpt", "o1", "o3"];
-
-/// Model prefixes that indicate reasoning models.
-const REASONING_MODEL_PREFIXES: &[&str] = &["o1", "o2", "o3", "o4", "gpt-5"]; - -/// Legacy o1 models that need special handling. -const LEGACY_O1_MODELS: &[&str] = &["o1-preview", "o1-mini", "o1-preview-2024-09-12"]; - -/// Model prefixes that do NOT support the `max_tokens` parameter. -/// These models require `max_completion_tokens` instead. -const MODELS_WITHOUT_MAX_TOKENS_SUPPORT: &[&str] = &["o1", "o3", "o4", "gpt-5"]; - -fn supports_native_structured_output(model: &str, target: TargetProvider) -> bool { - STRUCTURED_OUTPUT_PREFIXES - .iter() - .any(|prefix| model.starts_with(prefix)) - || matches!(target, TargetProvider::Fireworks) -} - -fn is_reasoning_model_name(model: &str) -> bool { - let lower = model.to_ascii_lowercase(); - REASONING_MODEL_PREFIXES - .iter() - .any(|prefix| lower.starts_with(prefix)) -} - -fn is_legacy_o1_model(model: &str) -> bool { - LEGACY_O1_MODELS.contains(&model) -} - -/// Check if a model supports the `max_tokens` parameter. -/// Returns false for models that require `max_completion_tokens` instead. -pub fn model_supports_max_tokens(model: &str) -> bool { - let lower = model.to_ascii_lowercase(); - !MODELS_WITHOUT_MAX_TOKENS_SUPPORT - .iter() - .any(|prefix| lower.starts_with(prefix)) } diff --git a/crates/lingua/src/providers/openai/mod.rs b/crates/lingua/src/providers/openai/mod.rs index 92c9fa1..c8eb57d 100644 --- a/crates/lingua/src/providers/openai/mod.rs +++ b/crates/lingua/src/providers/openai/mod.rs @@ -15,7 +15,7 @@ pub mod params; pub mod responses_adapter; // Re-export adapters and transformations -pub use adapter::{apply_target_transforms, OpenAIAdapter, OpenAITransformError}; +pub use adapter::OpenAIAdapter; pub use responses_adapter::ResponsesAdapter; #[cfg(test)] @@ -28,7 +28,7 @@ pub mod test_chat_completions; pub use detect::{try_parse_openai, try_parse_responses, DetectionError}; // Re-export capability functions -pub use capabilities::model_supports_max_tokens; +pub use capabilities::model_needs_transforms; // Re-export conversion functions and extension types pub use convert::{ diff --git a/crates/lingua/src/providers/openai/responses_adapter.rs b/crates/lingua/src/providers/openai/responses_adapter.rs index b394c87..634ba82 100644 --- a/crates/lingua/src/providers/openai/responses_adapter.rs +++ b/crates/lingua/src/providers/openai/responses_adapter.rs @@ -12,6 +12,7 @@ use crate::processing::adapters::{ insert_opt_bool, insert_opt_f64, insert_opt_i64, ProviderAdapter, }; use crate::processing::transform::TransformError; +use crate::providers::openai::capabilities::apply_model_transforms; use crate::providers::openai::generated::{ InputItem, InputItemContent, InputItemRole, InputItemType, Instructions, OutputItemType, }; @@ -29,14 +30,6 @@ use crate::universal::{ }; use std::convert::TryInto; -/// Check if a model is a reasoning model that doesn't support temperature. 
-fn is_reasoning_model(model: &str) -> bool {
-    let model_lower = model.to_lowercase();
-    model_lower.starts_with("o1")
-        || model_lower.starts_with("o3")
-        || model_lower.starts_with("gpt-5")
-}
-
 fn system_text(message: &Message) -> Option<&str> {
     match message {
         Message::System { content } => match content {
@@ -233,11 +226,7 @@ impl ProviderAdapter for ResponsesAdapter {
             .map_err(|e| TransformError::SerializationFailed(e.to_string()))?,
         );
 
-        // Pass temperature through for non-reasoning models
-        // Reasoning models (o1-*, o3-*, gpt-5-*) don't support temperature
-        if !is_reasoning_model(model) {
-            insert_opt_f64(&mut obj, "temperature", req.params.temperature);
-        }
+        insert_opt_f64(&mut obj, "temperature", req.params.temperature);
         insert_opt_f64(&mut obj, "top_p", req.params.top_p);
         insert_opt_i64(&mut obj, "max_output_tokens", req.params.max_tokens);
         insert_opt_i64(&mut obj, "top_logprobs", req.params.top_logprobs);
@@ -299,6 +288,9 @@ impl ProviderAdapter for ResponsesAdapter {
             }
         }
 
+        // Apply capability-based transforms (e.g., strip temperature for reasoning models)
+        apply_model_transforms(model, &mut obj);
+
         Ok(Value::Object(obj))
     }

From 684e1a28f23997b976f83b2cacc2237e8955ee1c Mon Sep 17 00:00:00 2001
From: Ken Jiang
Date: Wed, 4 Feb 2026 22:26:06 -0500
Subject: [PATCH 2/2] address pr comments

---
 .../src/providers/openai/capabilities.rs      | 153 +++++++++++++++++-
 1 file changed, 152 insertions(+), 1 deletion(-)

diff --git a/crates/lingua/src/providers/openai/capabilities.rs b/crates/lingua/src/providers/openai/capabilities.rs
index 727ba39..9e91f44 100644
--- a/crates/lingua/src/providers/openai/capabilities.rs
+++ b/crates/lingua/src/providers/openai/capabilities.rs
@@ -49,7 +49,7 @@ pub fn apply_model_transforms(model: &str, obj: &mut Map<String, Value>) {
             ForceMaxCompletionTokens => {
                 // (Responses API) max_output_tokens is valid.
                 if obj.contains_key("max_output_tokens") {
-                    return;
+                    continue;
                 }
 
                 // (Chat Completions API) max_tokens is deprecated - convert to max_completion_tokens.
@@ -60,3 +60,154 @@ pub fn apply_model_transforms(model: &str, obj: &mut Map<String, Value>) {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::serde_json::json;
+
+    #[test]
+    fn test_get_model_transforms() {
+        let cases = [
+            ("o1", &[StripTemperature, ForceMaxCompletionTokens][..]),
+            ("o1-mini", &[StripTemperature, ForceMaxCompletionTokens][..]),
+            ("o3", &[StripTemperature, ForceMaxCompletionTokens][..]),
+            (
+                "o4-preview",
+                &[StripTemperature, ForceMaxCompletionTokens][..],
+            ),
+            (
+                "gpt-5-mini",
+                &[StripTemperature, ForceMaxCompletionTokens][..],
+            ),
+            ("gpt-4", &[][..]),
+            ("gpt-4o", &[][..]),
+            ("claude-3", &[][..]),
+        ];
+        for (model, expected) in cases {
+            assert_eq!(get_model_transforms(model), expected, "model: {}", model);
+        }
+    }
+
+    #[test]
+    fn test_model_needs_transforms() {
+        let needs = ["o1", "o3", "gpt-5"];
+        let no_needs = ["gpt-4", "gpt-4o", "claude-3"];
+        for model in needs {
+            assert!(model_needs_transforms(model), "should need: {}", model);
+        }
+        for model in no_needs {
+            assert!(!model_needs_transforms(model), "should not need: {}", model);
+        }
+    }
+
+    #[test]
+    fn test_strip_temperature() {
+        let reasoning_models = ["o1", "o1-mini", "o3", "gpt-5-mini"];
+        let non_reasoning_models = ["gpt-4", "gpt-4o", "claude-3"];
+
+        // Reasoning models: temperature should be stripped
+        for model in reasoning_models {
+            let mut obj = json!({
+                "model": model,
+                "messages": [{"role": "user", "content": "Hello"}],
+                "temperature": 0.7
+            })
+            .as_object()
+            .unwrap()
+            .clone();
+            apply_model_transforms(model, &mut obj);
+            assert!(
+                !obj.contains_key("temperature"),
+                "{} should strip temperature",
+                model
+            );
+        }
+
+        // Non-reasoning models: temperature should be preserved
+        for model in non_reasoning_models {
+            let mut obj = json!({
+                "model": model,
+                "messages": [{"role": "user", "content": "Hello"}],
+                "temperature": 0.7
+            })
+            .as_object()
+            .unwrap()
+            .clone();
+            apply_model_transforms(model, &mut obj);
+            assert!(
+                obj.contains_key("temperature"),
+                "{} should preserve temperature",
+                model
+            );
+        }
+    }
+
+    #[test]
+    fn test_force_max_completion_tokens() {
+        // Reasoning models: max_tokens → max_completion_tokens
+        for model in ["o1", "gpt-5"] {
+            let mut obj = json!({
+                "model": model,
+                "messages": [{"role": "user", "content": "Hello"}],
+                "max_tokens": 100
+            })
+            .as_object()
+            .unwrap()
+            .clone();
+            apply_model_transforms(model, &mut obj);
+            assert!(
+                obj.contains_key("max_completion_tokens"),
+                "{} should add max_completion_tokens",
+                model
+            );
+            assert!(
+                !obj.contains_key("max_tokens"),
+                "{} should remove max_tokens",
+                model
+            );
+        }
+
+        // Non-reasoning models: max_tokens stays as-is
+        for model in ["gpt-4", "gpt-4o"] {
+            let mut obj = json!({
+                "model": model,
+                "messages": [{"role": "user", "content": "Hello"}],
+                "max_tokens": 100
+            })
+            .as_object()
+            .unwrap()
+            .clone();
+            apply_model_transforms(model, &mut obj);
+            assert!(
+                !obj.contains_key("max_completion_tokens"),
+                "{} should not add max_completion_tokens",
+                model
+            );
+            assert!(
+                obj.contains_key("max_tokens"),
+                "{} should preserve max_tokens",
+                model
+            );
+        }
+
+        // max_output_tokens is valid for Responses API - not converted
+        let mut obj = json!({
+            "model": "o3",
+            "input": [{"role": "user", "content": "Hello"}],
+            "max_output_tokens": 100
+        })
+        .as_object()
+        .unwrap()
+        .clone();
+        apply_model_transforms("o3", &mut obj);
+        assert!(
+            obj.contains_key("max_output_tokens"),
+            "max_output_tokens should be preserved"
+        );
+        assert!(
+            !obj.contains_key("max_completion_tokens"),
+            "should not convert max_output_tokens"
+        );
+    }
+}