diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4c26e33..87cedb9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -46,10 +46,13 @@ add_ai_example(openrouter_example openrouter_example.cpp) # Tool calling examples add_ai_example(tool_calling_basic tool_calling_basic.cpp) -add_ai_example(tool_calling_multistep tool_calling_multistep.cpp) +add_ai_example(tool_calling_multistep tool_calling_multistep.cpp) add_ai_example(tool_calling_async tool_calling_async.cpp) add_ai_example(test_tool_integration test_tool_integration.cpp) +# Embeddings example +add_ai_example(embeddings_example embeddings_example.cpp) + # Component-specific examples add_subdirectory(components/openai) add_subdirectory(components/anthropic) diff --git a/examples/embeddings_example.cpp b/examples/embeddings_example.cpp new file mode 100644 index 0000000..05b4462 --- /dev/null +++ b/examples/embeddings_example.cpp @@ -0,0 +1,286 @@ +/** + * Embeddings Example - AI SDK C++ + * + * This example demonstrates how to use the embeddings API with the AI SDK. 
+ * It shows how to: + * - Generate embeddings for single and multiple texts + * - Use different embedding models and dimensions + * - Calculate cosine similarity between embeddings + * - Handle errors and display results + * + * Usage: + * export OPENAI_API_KEY=your_key_here + * ./embeddings_example + */ + +#include +#include +#include +#include +#include + +#include +#include + +// Helper function to calculate cosine similarity between two embeddings +double cosine_similarity(const std::vector& a, + const std::vector& b) { + if (a.size() != b.size()) { + return 0.0; + } + + double dot_product = 0.0; + double norm_a = 0.0; + double norm_b = 0.0; + + for (size_t i = 0; i < a.size(); ++i) { + dot_product += a[i] * b[i]; + norm_a += a[i] * a[i]; + norm_b += b[i] * b[i]; + } + + if (norm_a == 0.0 || norm_b == 0.0) { + return 0.0; + } + + return dot_product / (std::sqrt(norm_a) * std::sqrt(norm_b)); +} + +// Helper function to extract embedding as a vector of doubles +std::vector extract_embedding(const nlohmann::json& data, + size_t index) { + std::vector embedding; + if (data.is_array() && index < data.size()) { + for (const auto& val : data[index]["embedding"]) { + embedding.push_back(val.get()); + } + } + return embedding; +} + +int main() { + std::cout << "AI SDK C++ - Embeddings Example\n"; + std::cout << "================================\n\n"; + + // Create OpenAI client + auto client = ai::openai::create_client(); + if (!client.is_valid()) { + std::cerr << "Error: Failed to create OpenAI client. Make sure " + "OPENAI_API_KEY is set.\n"; + return 1; + } + + // Example 1: Basic single text embedding + std::cout << "1. 
Single Text Embedding:\n"; + std::cout << "Text: \"Hello, world!\"\n\n"; + + nlohmann::json input1 = "Hello, world!"; + ai::EmbeddingOptions options1("text-embedding-3-small", input1); + auto result1 = client.embeddings(options1); + + if (result1) { + auto embedding = result1.data[0]["embedding"]; + std::cout << "✓ Successfully generated embedding\n"; + std::cout << " Dimensions: " << embedding.size() << "\n"; + std::cout << " Token usage: " << result1.usage.total_tokens << " tokens\n"; + std::cout << " First 5 values: ["; + for (size_t i = 0; i < std::min(size_t(5), embedding.size()); ++i) { + std::cout << std::fixed << std::setprecision(6) + << embedding[i].get(); + if (i < 4) + std::cout << ", "; + } + std::cout << ", ...]\n\n"; + } else { + std::cout << "✗ Error: " << result1.error_message() << "\n\n"; + } + + // Example 2: Multiple texts embedding + std::cout << "2. Multiple Texts Embedding:\n"; + nlohmann::json input2 = nlohmann::json::array( + {"sunny day at the beach", "rainy afternoon in the city", + "snowy night in the mountains"}); + + ai::EmbeddingOptions options2("text-embedding-3-small", input2); + auto result2 = client.embeddings(options2); + + if (result2) { + std::cout << "✓ Successfully generated " << result2.data.size() + << " embeddings\n"; + std::cout << " Token usage: " << result2.usage.total_tokens << " tokens\n"; + for (size_t i = 0; i < result2.data.size(); ++i) { + std::cout << " Embedding " << i + 1 + << " dimensions: " << result2.data[i]["embedding"].size() + << "\n"; + } + std::cout << "\n"; + } else { + std::cout << "✗ Error: " << result2.error_message() << "\n\n"; + } + + // Example 3: Embedding with custom dimensions + std::cout << "3. 
Custom Dimensions (512 instead of default 1536):\n"; + nlohmann::json input3 = "Testing custom dimensions"; + ai::EmbeddingOptions options3("text-embedding-3-small", input3, 512); + auto result3 = client.embeddings(options3); + + if (result3) { + auto embedding = result3.data[0]["embedding"]; + std::cout << "✓ Successfully generated embedding with custom dimensions\n"; + std::cout << " Dimensions: " << embedding.size() << " (requested: 512)\n"; + std::cout << " Token usage: " << result3.usage.total_tokens + << " tokens\n\n"; + } else { + std::cout << "✗ Error: " << result3.error_message() << "\n\n"; + } + + // Example 4: Semantic similarity between texts + std::cout << "4. Calculating Semantic Similarity:\n"; + nlohmann::json input4 = nlohmann::json::array( + {"cat", "kitten", "dog", "puppy", "car", "automobile"}); + + ai::EmbeddingOptions options4("text-embedding-3-small", input4); + auto result4 = client.embeddings(options4); + + if (result4) { + std::cout << "✓ Generated embeddings for similarity comparison\n\n"; + + // Extract embeddings + std::vector texts = {"cat", "kitten", "dog", + "puppy", "car", "automobile"}; + std::vector> embeddings; + + for (size_t i = 0; i < result4.data.size(); ++i) { + embeddings.push_back(extract_embedding(result4.data, i)); + } + + // Calculate and display similarities + std::cout << " Similarity scores (cosine similarity):\n"; + std::cout << " ----------------------------------------\n"; + std::cout << " cat ↔ kitten: " << std::fixed << std::setprecision(4) + << cosine_similarity(embeddings[0], embeddings[1]) << "\n"; + std::cout << " dog ↔ puppy: " << std::fixed << std::setprecision(4) + << cosine_similarity(embeddings[2], embeddings[3]) << "\n"; + std::cout << " car ↔ automobile: " << std::fixed << std::setprecision(4) + << cosine_similarity(embeddings[4], embeddings[5]) << "\n"; + std::cout << " cat ↔ dog: " << std::fixed << std::setprecision(4) + << cosine_similarity(embeddings[0], embeddings[2]) << "\n"; + std::cout << " 
cat ↔ car: " << std::fixed << std::setprecision(4) + << cosine_similarity(embeddings[0], embeddings[4]) << "\n\n"; + + std::cout + << " Note: Similar concepts have similarity scores closer to 1.0\n\n"; + } else { + std::cout << "✗ Error: " << result4.error_message() << "\n\n"; + } + + // Example 5: Using different embedding models + std::cout << "5. Comparing Different Embedding Models:\n"; + nlohmann::json input5 = "Artificial intelligence and machine learning"; + + // text-embedding-3-small + ai::EmbeddingOptions options5a("text-embedding-3-small", input5); + auto result5a = client.embeddings(options5a); + + if (result5a) { + std::cout << " text-embedding-3-small:\n"; + std::cout << " Dimensions: " << result5a.data[0]["embedding"].size() + << "\n"; + std::cout << " Token usage: " << result5a.usage.total_tokens + << " tokens\n"; + } + + // text-embedding-3-large + ai::EmbeddingOptions options5b("text-embedding-3-large", input5); + auto result5b = client.embeddings(options5b); + + if (result5b) { + std::cout << " text-embedding-3-large:\n"; + std::cout << " Dimensions: " << result5b.data[0]["embedding"].size() + << "\n"; + std::cout << " Token usage: " << result5b.usage.total_tokens + << " tokens\n"; + } + + std::cout << "\n"; + + // Example 6: Practical use case - Finding similar items + std::cout << "6. 
Practical Use Case - Finding Most Similar Item:\n"; + + std::string query = "I need a programming language for web development"; + std::vector documents = { + "Python is great for data science and machine learning", + "JavaScript is the language of the web and runs in browsers", + "C++ is perfect for high-performance systems programming", + "Java is widely used for enterprise applications", + "TypeScript adds types to JavaScript for better development"}; + + // Add query at the beginning + nlohmann::json input6 = nlohmann::json::array(); + input6.push_back(query); + for (const auto& doc : documents) { + input6.push_back(doc); + } + + ai::EmbeddingOptions options6("text-embedding-3-small", input6); + auto result6 = client.embeddings(options6); + + if (result6) { + std::cout << " Query: \"" << query << "\"\n\n"; + std::cout << " Similarity to documents:\n"; + std::cout << " ----------------------------------------\n"; + + // Extract query embedding + auto query_embedding = extract_embedding(result6.data, 0); + + // Calculate similarity to each document + std::vector> similarities; + for (size_t i = 0; i < documents.size(); ++i) { + auto doc_embedding = extract_embedding(result6.data, i + 1); + double sim = cosine_similarity(query_embedding, doc_embedding); + similarities.push_back({i, sim}); + } + + // Sort by similarity (highest first) + std::sort(similarities.begin(), similarities.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + + // Display results + for (size_t i = 0; i < similarities.size(); ++i) { + size_t idx = similarities[i].first; + double sim = similarities[i].second; + std::cout << " " << (i + 1) << ". [" << std::fixed + << std::setprecision(4) << sim << "] " << documents[idx] + << "\n"; + } + std::cout << "\n"; + } else { + std::cout << "✗ Error: " << result6.error_message() << "\n\n"; + } + + // Example 7: Error handling + std::cout << "7. 
Error Handling:\n"; + + // Test with invalid model + nlohmann::json input7 = "Test error handling"; + ai::EmbeddingOptions options7("invalid-model-name", input7); + auto result7 = client.embeddings(options7); + + if (!result7) { + std::cout << "✓ Error properly handled for invalid model:\n"; + std::cout << " Error message: " << result7.error_message() << "\n\n"; + } + + std::cout << "\nExample completed!\n"; + std::cout << "\nTips:\n"; + std::cout + << " - text-embedding-3-small: 1536 dimensions, faster and cheaper\n"; + std::cout << " - text-embedding-3-large: 3072 dimensions, higher quality\n"; + std::cout << " - Use custom dimensions to reduce vector storage size\n"; + std::cout << " - Cosine similarity scores closer to 1.0 indicate more " + "similar texts\n"; + std::cout << "\nMake sure to set your API key:\n"; + std::cout << " export OPENAI_API_KEY=your_openai_key\n"; + + return 0; +} diff --git a/include/ai/openai.h b/include/ai/openai.h index 6a9bd92..1277e5d 100644 --- a/include/ai/openai.h +++ b/include/ai/openai.h @@ -47,6 +47,7 @@ constexpr const char* kChatGpt4oLatest = "chatgpt-4o-latest"; /// Default model used when none is specified constexpr const char* kDefaultModel = kGpt4o; + } // namespace models /// Create an OpenAI client with default configuration @@ -82,4 +83,4 @@ Client create_client(const std::string& api_key, std::optional try_create_client(); } // namespace openai -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/include/ai/types/client.h b/include/ai/types/client.h index dd5084d..a1193a5 100644 --- a/include/ai/types/client.h +++ b/include/ai/types/client.h @@ -1,5 +1,6 @@ #pragma once +#include "embedding_options.h" #include "generate_options.h" #include "stream_options.h" #include "stream_result.h" @@ -31,6 +32,12 @@ class Client { return GenerateResult("Client not initialized"); } + virtual EmbeddingResult embeddings(const EmbeddingOptions& options) { + if (pimpl_) + return pimpl_->embeddings(options); + 
return EmbeddingResult("Client not initialized"); + } + virtual StreamResult stream_text(const StreamOptions& options) { if (pimpl_) return pimpl_->stream_text(options); diff --git a/include/ai/types/embedding_options.h b/include/ai/types/embedding_options.h new file mode 100644 index 0000000..4ebb199 --- /dev/null +++ b/include/ai/types/embedding_options.h @@ -0,0 +1,78 @@ +#pragma once + +#include "enums.h" +#include "message.h" +#include "model.h" +#include "tool.h" +#include "usage.h" + +#include +#include +#include +#include + +namespace ai { + +struct EmbeddingOptions { + std::string model; + nlohmann::json input; + std::optional dimensions; + std::optional encoding_format; + std::optional user; // Optional user identifier for OpenAI + + EmbeddingOptions() = default; + + EmbeddingOptions(std::string model_name, nlohmann::json input_) + : model(std::move(model_name)), input(std::move(input_)) {} + + EmbeddingOptions(std::string model_name, + nlohmann::json input_, + int dimensions_) + : model(std::move(model_name)), + input(std::move(input_)), + dimensions(dimensions_) {} + + EmbeddingOptions(std::string model_name, + nlohmann::json input_, + int dimensions_, + std::string encoding_format_) + : model(std::move(model_name)), + input(std::move(input_)), + dimensions(dimensions_), + encoding_format(std::move(encoding_format_)) {} + + bool is_valid() const { return !model.empty() && !input.is_null(); } + + bool has_input() const { return !input.is_null(); } +}; + +struct EmbeddingResult { + nlohmann::json data; + Usage usage; + + /// Additional metadata (like TypeScript SDK) + std::optional model; + + /// Error handling + std::optional error; + std::optional is_retryable; + + /// Provider-specific metadata + std::optional provider_metadata; + + EmbeddingResult() = default; + + // EmbeddingResult(std::string data_, Usage token_usage) + // : data(std::move(data_)), usage(token_usage) {} + + explicit EmbeddingResult(std::optional error_message) + : 
error(std::move(error_message)) {} + + bool is_success() const { return !error.has_value(); } + + explicit operator bool() const { return is_success(); } + + std::string error_message() const { return error.value_or(""); } +}; + +} // namespace ai diff --git a/include/ai/types/generate_options.h b/include/ai/types/generate_options.h index 8176d4e..bfa27f8 100644 --- a/include/ai/types/generate_options.h +++ b/include/ai/types/generate_options.h @@ -18,6 +18,7 @@ struct GenerateOptions { std::string system; std::string prompt; Messages messages; + std::optional response_format{}; std::optional max_tokens; std::optional temperature; std::optional top_p; @@ -46,6 +47,15 @@ struct GenerateOptions { system(std::move(system_prompt)), prompt(std::move(user_prompt)) {} + GenerateOptions(std::string model_name, + std::string system_prompt, + std::string user_prompt, + std::optional response_format_) + : model(std::move(model_name)), + system(std::move(system_prompt)), + prompt(std::move(user_prompt)), + response_format(std::move(response_format_)) {} + GenerateOptions(std::string model_name, Messages conversation) : model(std::move(model_name)), messages(std::move(conversation)) {} diff --git a/src/providers/anthropic/anthropic_client.cpp b/src/providers/anthropic/anthropic_client.cpp index 7684717..bedc098 100644 --- a/src/providers/anthropic/anthropic_client.cpp +++ b/src/providers/anthropic/anthropic_client.cpp @@ -18,7 +18,8 @@ AnthropicClient::AnthropicClient(const std::string& api_key, providers::ProviderConfig{ .api_key = api_key, .base_url = base_url, - .endpoint_path = "/v1/messages", + .completions_endpoint_path = "/v1/messages", + .embeddings_endpoint_path = "/v1/embeddings", .auth_header_name = "x-api-key", .auth_header_prefix = "", .extra_headers = {{"anthropic-version", "2023-06-01"}}}, @@ -44,8 +45,8 @@ StreamResult AnthropicClient::stream_text(const StreamOptions& options) { // Create stream implementation auto impl = std::make_unique(); - 
impl->start_stream(config_.base_url + config_.endpoint_path, headers, - request_json); + impl->start_stream(config_.base_url + config_.completions_endpoint_path, + headers, request_json); ai::logger::log_info("Text streaming started - model: {}", options.model); @@ -77,4 +78,4 @@ std::string AnthropicClient::default_model() const { } } // namespace anthropic -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/src/providers/anthropic/anthropic_client.h b/src/providers/anthropic/anthropic_client.h index bb9bb08..f66e432 100644 --- a/src/providers/anthropic/anthropic_client.h +++ b/src/providers/anthropic/anthropic_client.h @@ -29,4 +29,4 @@ class AnthropicClient : public providers::BaseProviderClient { }; } // namespace anthropic -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/src/providers/anthropic/anthropic_request_builder.cpp b/src/providers/anthropic/anthropic_request_builder.cpp index 8a29c45..2933c58 100644 --- a/src/providers/anthropic/anthropic_request_builder.cpp +++ b/src/providers/anthropic/anthropic_request_builder.cpp @@ -157,19 +157,28 @@ nlohmann::json AnthropicRequestBuilder::build_request_json( return request; } +nlohmann::json AnthropicRequestBuilder::build_request_json( + const EmbeddingOptions& options) { + // Note: Anthropic does not currently offer embeddings API + // This is a placeholder for future compatibility or custom endpoints + nlohmann::json request{{"model", options.model}, {"input", options.input}}; + return request; +} + httplib::Headers AnthropicRequestBuilder::build_headers( const providers::ProviderConfig& config) { httplib::Headers headers = { - {config.auth_header_name, config.auth_header_prefix + config.api_key}, - {"Content-Type", "application/json"}}; + {config.auth_header_name, config.auth_header_prefix + config.api_key}}; // Add any extra headers for (const auto& [key, value] : config.extra_headers) { headers.emplace(key, value); } + // Note: Content-Type is passed 
separately to httplib::Post() as content_type + // parameter return headers; } } // namespace anthropic -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/src/providers/anthropic/anthropic_request_builder.h b/src/providers/anthropic/anthropic_request_builder.h index bab9fb1..464dc41 100644 --- a/src/providers/anthropic/anthropic_request_builder.h +++ b/src/providers/anthropic/anthropic_request_builder.h @@ -1,5 +1,6 @@ #pragma once +#include "ai/types/embedding_options.h" #include "ai/types/generate_options.h" #include "providers/base_provider_client.h" @@ -11,6 +12,7 @@ namespace anthropic { class AnthropicRequestBuilder : public providers::RequestBuilder { public: nlohmann::json build_request_json(const GenerateOptions& options) override; + nlohmann::json build_request_json(const EmbeddingOptions& options) override; httplib::Headers build_headers( const providers::ProviderConfig& config) override; }; diff --git a/src/providers/anthropic/anthropic_response_parser.cpp b/src/providers/anthropic/anthropic_response_parser.cpp index 1235713..4789730 100644 --- a/src/providers/anthropic/anthropic_response_parser.cpp +++ b/src/providers/anthropic/anthropic_response_parser.cpp @@ -6,7 +6,7 @@ namespace ai { namespace anthropic { -GenerateResult AnthropicResponseParser::parse_success_response( +GenerateResult AnthropicResponseParser::parse_success_completion_response( const nlohmann::json& response) { ai::logger::log_debug("Parsing Anthropic messages response"); @@ -86,12 +86,52 @@ GenerateResult AnthropicResponseParser::parse_success_response( return result; } -GenerateResult AnthropicResponseParser::parse_error_response( +GenerateResult AnthropicResponseParser::parse_error_completion_response( int status_code, const std::string& body) { return utils::parse_standard_error_response("Anthropic", status_code, body); } +EmbeddingResult AnthropicResponseParser::parse_success_embedding_response( + const nlohmann::json& response) { + 
ai::logger::log_debug("Parsing Anthropic embeddings response"); + + EmbeddingResult result; + + // Extract basic fields + result.model = response.value("model", ""); + + // Extract choices + if (response.contains("data") && !response["data"].empty()) { + result.data = std::move(response["data"]); + } + + // Extract usage + if (response.contains("usage")) { + auto& usage = response["usage"]; + result.usage.prompt_tokens = usage.value("prompt_tokens", 0); + result.usage.completion_tokens = usage.value("completion_tokens", 0); + result.usage.total_tokens = usage.value("total_tokens", 0); + ai::logger::log_debug("Token usage - prompt: {}, completion: {}, total: {}", + result.usage.prompt_tokens, + result.usage.completion_tokens, + result.usage.total_tokens); + } + + // Store full metadata + result.provider_metadata = response.dump(); + + return result; +} + +EmbeddingResult AnthropicResponseParser::parse_error_embedding_response( + int status_code, + const std::string& body) { + auto generate_result = + utils::parse_standard_error_response("Anthropic", status_code, body); + return EmbeddingResult(generate_result.error); +} + FinishReason AnthropicResponseParser::parse_stop_reason( const std::string& reason) { if (reason == "end_turn") { diff --git a/src/providers/anthropic/anthropic_response_parser.h b/src/providers/anthropic/anthropic_response_parser.h index 822cd70..1df91d6 100644 --- a/src/providers/anthropic/anthropic_response_parser.h +++ b/src/providers/anthropic/anthropic_response_parser.h @@ -10,10 +10,16 @@ namespace anthropic { class AnthropicResponseParser : public providers::ResponseParser { public: - GenerateResult parse_success_response( + GenerateResult parse_success_completion_response( const nlohmann::json& response) override; - GenerateResult parse_error_response(int status_code, - const std::string& body) override; + GenerateResult parse_error_completion_response( + int status_code, + const std::string& body) override; + EmbeddingResult 
parse_success_embedding_response( + const nlohmann::json& response) override; + EmbeddingResult parse_error_embedding_response( + int status_code, + const std::string& body) override; private: static FinishReason parse_stop_reason(const std::string& reason); diff --git a/src/providers/base_provider_client.cpp b/src/providers/base_provider_client.cpp index e734fae..9484724 100644 --- a/src/providers/base_provider_client.cpp +++ b/src/providers/base_provider_client.cpp @@ -24,8 +24,10 @@ BaseProviderClient::BaseProviderClient( http_handler_ = std::make_unique(http_config); ai::logger::log_debug( - "BaseProviderClient initialized - base_url: {}, endpoint: {}", - config.base_url, config.endpoint_path); + R"(BaseProviderClient initialized - base_url: {}, + completions_endpoint: {}, embeddings_endpoint: {})", + config.base_url, config.completions_endpoint_path, + config.embeddings_endpoint_path); } GenerateResult BaseProviderClient::generate_text( @@ -64,14 +66,14 @@ GenerateResult BaseProviderClient::generate_text_single_step( auto headers = request_builder_->build_headers(config_); // Make the request - auto result = - http_handler_->post(config_.endpoint_path, headers, json_body); + auto result = http_handler_->post(config_.completions_endpoint_path, + headers, json_body); if (!result.is_success()) { // Parse error response using provider-specific parser if (result.provider_metadata.has_value()) { int status_code = std::stoi(result.provider_metadata.value()); - return response_parser_->parse_error_response( + return response_parser_->parse_error_completion_response( status_code, result.error.value_or("")); } return result; @@ -94,7 +96,7 @@ GenerateResult BaseProviderClient::generate_text_single_step( // Parse using provider-specific parser auto parsed_result = - response_parser_->parse_success_response(json_response); + response_parser_->parse_success_completion_response(json_response); if (parsed_result.has_tool_calls()) { ai::logger::log_debug("Model made {} tool 
calls", @@ -144,5 +146,55 @@ StreamResult BaseProviderClient::stream_text(const StreamOptions& options) { return StreamResult(); } +EmbeddingResult BaseProviderClient::embeddings( + const EmbeddingOptions& options) { + try { + // Build request JSON using the provider-specific builder + auto request_json = request_builder_->build_request_json(options); + std::string json_body = request_json.dump(); + ai::logger::log_debug("Request JSON built: {}", json_body); + + // Build headers + auto headers = request_builder_->build_headers(config_); + + // Make the requests + auto result = http_handler_->post(config_.embeddings_endpoint_path, headers, + json_body); + + if (!result.is_success()) { + // Parse error response using provider-specific parser + if (result.provider_metadata.has_value()) { + int status_code = std::stoi(result.provider_metadata.value()); + return response_parser_->parse_error_embedding_response( + status_code, result.error.value_or("")); + } + return EmbeddingResult(result.error); + } + + // Parse the response JSON from result.text + nlohmann::json json_response; + try { + json_response = nlohmann::json::parse(result.text); + } catch (const nlohmann::json::exception& e) { + ai::logger::log_error("Failed to parse response JSON: {}", e.what()); + ai::logger::log_debug("Raw response text: {}", result.text); + return EmbeddingResult("Failed to parse response: " + + std::string(e.what())); + } + + ai::logger::log_info("Embeddings successful - model: {}, response_id: {}", + options.model, json_response.value("id", "unknown")); + + // Parse using provider-specific parser + auto parsed_result = + response_parser_->parse_success_embedding_response(json_response); + return parsed_result; + + } catch (const std::exception& e) { + ai::logger::log_error("Exception during embeddings: {}", e.what()); + return EmbeddingResult(std::string("Exception: ") + e.what()); + } +} + } // namespace providers } // namespace ai \ No newline at end of file diff --git 
a/src/providers/base_provider_client.h b/src/providers/base_provider_client.h index 34fc3cd..47f5e75 100644 --- a/src/providers/base_provider_client.h +++ b/src/providers/base_provider_client.h @@ -2,6 +2,7 @@ #include "ai/retry/retry_policy.h" #include "ai/types/client.h" +#include "ai/types/embedding_options.h" #include "ai/types/generate_options.h" #include "ai/types/stream_options.h" #include "http/http_request_handler.h" @@ -18,7 +19,8 @@ namespace providers { struct ProviderConfig { std::string api_key; std::string base_url; - std::string endpoint_path; // e.g., "/v1/chat/completions" or "/v1/messages" + std::string completions_endpoint_path; // e.g. "/v1/chat/completions" + std::string embeddings_endpoint_path; std::string auth_header_name; // e.g., "Authorization" or "x-api-key" std::string auth_header_prefix; // e.g., "Bearer " or "" httplib::Headers extra_headers; // Additional headers like anthropic-version @@ -32,6 +34,8 @@ class RequestBuilder { public: virtual ~RequestBuilder() = default; virtual nlohmann::json build_request_json(const GenerateOptions& options) = 0; + virtual nlohmann::json build_request_json( + const EmbeddingOptions& options) = 0; virtual httplib::Headers build_headers(const ProviderConfig& config) = 0; }; @@ -39,10 +43,16 @@ class RequestBuilder { class ResponseParser { public: virtual ~ResponseParser() = default; - virtual GenerateResult parse_success_response( + virtual GenerateResult parse_success_completion_response( const nlohmann::json& response) = 0; - virtual GenerateResult parse_error_response(int status_code, - const std::string& body) = 0; + virtual GenerateResult parse_error_completion_response( + int status_code, + const std::string& body) = 0; + virtual EmbeddingResult parse_success_embedding_response( + const nlohmann::json& response) = 0; + virtual EmbeddingResult parse_error_embedding_response( + int status_code, + const std::string& body) = 0; }; // Base client that uses composition to share common functionality 
@@ -55,6 +65,7 @@ class BaseProviderClient : public Client { // Implements the common flow using the composed components GenerateResult generate_text(const GenerateOptions& options) override; StreamResult stream_text(const StreamOptions& options) override; + EmbeddingResult embeddings(const EmbeddingOptions& options) override; bool is_valid() const override { return !config_.api_key.empty(); } diff --git a/src/providers/openai/openai_client.cpp b/src/providers/openai/openai_client.cpp index 28a4b9b..e2a5baf 100644 --- a/src/providers/openai/openai_client.cpp +++ b/src/providers/openai/openai_client.cpp @@ -15,12 +15,14 @@ namespace openai { OpenAIClient::OpenAIClient(const std::string& api_key, const std::string& base_url) : BaseProviderClient( - providers::ProviderConfig{.api_key = api_key, - .base_url = base_url, - .endpoint_path = "/v1/chat/completions", - .auth_header_name = "Authorization", - .auth_header_prefix = "Bearer ", - .extra_headers = {}}, + providers::ProviderConfig{ + .api_key = api_key, + .base_url = base_url, + .completions_endpoint_path = "/v1/chat/completions", + .embeddings_endpoint_path = "/v1/embeddings", + .auth_header_name = "Authorization", + .auth_header_prefix = "Bearer ", + .extra_headers = {}}, std::make_unique(), std::make_unique()) { ai::logger::log_debug("OpenAI client initialized with base_url: {}", @@ -31,13 +33,15 @@ OpenAIClient::OpenAIClient(const std::string& api_key, const std::string& base_url, const retry::RetryConfig& retry_config) : BaseProviderClient( - providers::ProviderConfig{.api_key = api_key, - .base_url = base_url, - .endpoint_path = "/v1/chat/completions", - .auth_header_name = "Authorization", - .auth_header_prefix = "Bearer ", - .extra_headers = {}, - .retry_config = retry_config}, + providers::ProviderConfig{ + .api_key = api_key, + .base_url = base_url, + .completions_endpoint_path = "/v1/chat/completions", + .embeddings_endpoint_path = "/v1/embeddings", + .auth_header_name = "Authorization", + 
.auth_header_prefix = "Bearer ", + .extra_headers = {}, + .retry_config = retry_config}, std::make_unique(), std::make_unique()) { ai::logger::log_debug( @@ -61,8 +65,8 @@ StreamResult OpenAIClient::stream_text(const StreamOptions& options) { // Create stream implementation auto impl = std::make_unique(); - impl->start_stream(config_.base_url + config_.endpoint_path, headers, - request_json); + impl->start_stream(config_.base_url + config_.completions_endpoint_path, + headers, request_json); ai::logger::log_info("Text streaming started - model: {}", options.model); @@ -94,4 +98,4 @@ std::string OpenAIClient::default_model() const { } } // namespace openai -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/src/providers/openai/openai_client.h b/src/providers/openai/openai_client.h index e3ea93e..12df8bc 100644 --- a/src/providers/openai/openai_client.h +++ b/src/providers/openai/openai_client.h @@ -33,4 +33,4 @@ class OpenAIClient : public providers::BaseProviderClient { }; } // namespace openai -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/src/providers/openai/openai_request_builder.cpp b/src/providers/openai/openai_request_builder.cpp index b9d40b2..5249567 100644 --- a/src/providers/openai/openai_request_builder.cpp +++ b/src/providers/openai/openai_request_builder.cpp @@ -11,6 +11,9 @@ nlohmann::json OpenAIRequestBuilder::build_request_json( nlohmann::json request{{"model", options.model}, {"messages", nlohmann::json::array()}}; + if (options.response_format) { + request["response_format"] = options.response_format.value(); + } // Build messages array if (!options.messages.empty()) { // Use provided messages @@ -164,6 +167,30 @@ nlohmann::json OpenAIRequestBuilder::build_request_json( return request; } +nlohmann::json OpenAIRequestBuilder::build_request_json( + const EmbeddingOptions& options) { + nlohmann::json request{{"model", options.model}, {"input", options.input}}; + + // Set encoding format 
(default to float for compatibility) + if (options.encoding_format) { + request["encoding_format"] = options.encoding_format.value(); + } else { + request["encoding_format"] = "float"; + } + + // Add dimensions if specified + if (options.dimensions && options.dimensions.value() > 0) { + request["dimensions"] = options.dimensions.value(); + } + + // Add user identifier if specified + if (options.user) { + request["user"] = options.user.value(); + } + + return request; +} + httplib::Headers OpenAIRequestBuilder::build_headers( const providers::ProviderConfig& config) { httplib::Headers headers = { @@ -174,6 +201,8 @@ httplib::Headers OpenAIRequestBuilder::build_headers( headers.emplace(key, value); } + // Note: Content-Type is passed separately to httplib::Post() as content_type + // parameter return headers; } diff --git a/src/providers/openai/openai_request_builder.h b/src/providers/openai/openai_request_builder.h index 5317769..53839c0 100644 --- a/src/providers/openai/openai_request_builder.h +++ b/src/providers/openai/openai_request_builder.h @@ -1,5 +1,6 @@ #pragma once +#include "ai/types/embedding_options.h" #include "ai/types/generate_options.h" #include "providers/base_provider_client.h" @@ -11,6 +12,7 @@ namespace openai { class OpenAIRequestBuilder : public providers::RequestBuilder { public: nlohmann::json build_request_json(const GenerateOptions& options) override; + nlohmann::json build_request_json(const EmbeddingOptions& options) override; httplib::Headers build_headers( const providers::ProviderConfig& config) override; }; diff --git a/src/providers/openai/openai_response_parser.cpp b/src/providers/openai/openai_response_parser.cpp index 90c520a..0274956 100644 --- a/src/providers/openai/openai_response_parser.cpp +++ b/src/providers/openai/openai_response_parser.cpp @@ -6,7 +6,7 @@ namespace ai { namespace openai { -GenerateResult OpenAIResponseParser::parse_success_response( +GenerateResult OpenAIResponseParser::parse_success_completion_response( 
const nlohmann::json& response) {
  ai::logger::log_debug("Parsing OpenAI chat completion response");
@@ -128,12 +128,52 @@ GenerateResult OpenAIResponseParser::parse_success_response(
   return result;
 }
 
-GenerateResult OpenAIResponseParser::parse_error_response(
+GenerateResult OpenAIResponseParser::parse_error_completion_response(
     int status_code,
     const std::string& body) {
   return utils::parse_standard_error_response("OpenAI", status_code, body);
 }
 
+EmbeddingResult OpenAIResponseParser::parse_success_embedding_response(
+    const nlohmann::json& response) {
+  ai::logger::log_debug("Parsing OpenAI embeddings response");
+
+  EmbeddingResult result;
+
+  // Extract basic fields
+  result.model = response.value("model", "");
+
+  // Extract embedding data (embeddings responses carry "data", not "choices";
+  // `response` is const, so this is a copy — std::move would be a no-op here)
+  if (response.contains("data") && !response["data"].empty()) {
+    result.data = response["data"];
+  }
+
+  // Extract usage
+  if (response.contains("usage")) {
+    auto& usage = response["usage"];
+    result.usage.prompt_tokens = usage.value("prompt_tokens", 0);
+    result.usage.completion_tokens = usage.value("completion_tokens", 0);
+    result.usage.total_tokens = usage.value("total_tokens", 0);
+    ai::logger::log_debug("Token usage - prompt: {}, completion: {}, total: {}",
+                          result.usage.prompt_tokens,
+                          result.usage.completion_tokens,
+                          result.usage.total_tokens);
+  }
+
+  // Store full metadata
+  result.provider_metadata = response.dump();
+
+  return result;
+}
+
+EmbeddingResult OpenAIResponseParser::parse_error_embedding_response(
+    int status_code,
+    const std::string& body) {
+  auto generate_result =
+      utils::parse_standard_error_response("OpenAI", status_code, body);
+  return EmbeddingResult(generate_result.error);
+}
+
 FinishReason OpenAIResponseParser::parse_finish_reason(
     const std::string& reason) {
   if (reason == "stop") {
diff --git a/src/providers/openai/openai_response_parser.h b/src/providers/openai/openai_response_parser.h
index 14d62d9..ff9923e 100644
--- a/src/providers/openai/openai_response_parser.h
+++
b/src/providers/openai/openai_response_parser.h @@ -10,10 +10,16 @@ namespace openai { class OpenAIResponseParser : public providers::ResponseParser { public: - GenerateResult parse_success_response( + GenerateResult parse_success_completion_response( const nlohmann::json& response) override; - GenerateResult parse_error_response(int status_code, - const std::string& body) override; + GenerateResult parse_error_completion_response( + int status_code, + const std::string& body) override; + EmbeddingResult parse_success_embedding_response( + const nlohmann::json& response) override; + EmbeddingResult parse_error_embedding_response( + int status_code, + const std::string& body) override; private: static FinishReason parse_finish_reason(const std::string& reason); diff --git a/src/types/embedding_options.cpp b/src/types/embedding_options.cpp new file mode 100644 index 0000000..77e54a4 --- /dev/null +++ b/src/types/embedding_options.cpp @@ -0,0 +1,7 @@ +#include "ai/types/embedding_options.h" + +namespace ai { + +// Implementation details for EmbeddingOptions if needed + +} // namespace ai \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 119240b..4e7cadd 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,14 +6,16 @@ add_executable(ai_tests unit/types_test.cpp unit/openai_stream_test.cpp unit/anthropic_stream_test.cpp - + unit/openai_embeddings_test.cpp + # Integration tests integration/openai_integration_test.cpp integration/anthropic_integration_test.cpp integration/tool_calling_integration_test.cpp integration/clickhouse_integration_test.cpp + integration/openai_embeddings_integration_test.cpp integration/multi_step_duplicate_execution_test.cpp - + # Utility classes utils/mock_openai_client.cpp utils/mock_anthropic_client.cpp diff --git a/tests/integration/openai_embeddings_integration_test.cpp b/tests/integration/openai_embeddings_integration_test.cpp new file mode 100644 index 0000000..7244ffc --- /dev/null 
+++ b/tests/integration/openai_embeddings_integration_test.cpp
@@ -0,0 +1,411 @@
+#include "../utils/test_fixtures.h"
+#include "ai/openai.h"
+#include "ai/types/embedding_options.h"
+
+#include <cmath>
+#include <cstdlib>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+// Note: These tests connect to the real OpenAI API when OPENAI_API_KEY is set
+// Otherwise they are skipped
+
+namespace ai {
+namespace test {
+
+class OpenAIEmbeddingsIntegrationTest : public AITestFixture {
+ protected:
+  void SetUp() override {
+    AITestFixture::SetUp();
+
+    // Check if we should run real API tests
+    const char* api_key = std::getenv("OPENAI_API_KEY");
+
+    if (api_key != nullptr) {
+      use_real_api_ = true;
+      client_ = ai::openai::create_client(api_key);
+    } else {
+      use_real_api_ = false;
+      // Skip tests if no API key is available
+    }
+  }
+
+  void TearDown() override { AITestFixture::TearDown(); }
+
+  // Helper to calculate cosine similarity between two embeddings
+  double cosine_similarity(const std::vector<double>& a,
+                           const std::vector<double>& b) {
+    if (a.size() != b.size()) {
+      return 0.0;
+    }
+
+    double dot_product = 0.0;
+    double norm_a = 0.0;
+    double norm_b = 0.0;
+
+    for (size_t i = 0; i < a.size(); ++i) {
+      dot_product += a[i] * b[i];
+      norm_a += a[i] * a[i];
+      norm_b += b[i] * b[i];
+    }
+
+    if (norm_a == 0.0 || norm_b == 0.0) {
+      return 0.0;
+    }
+
+    return dot_product / (std::sqrt(norm_a) * std::sqrt(norm_b));
+  }
+
+  bool use_real_api_;
+  std::optional<ai::openai::OpenAIClient> client_;
+};
+
+// Basic Embeddings Tests
+TEST_F(OpenAIEmbeddingsIntegrationTest, BasicSingleStringEmbedding) {
+  if (!use_real_api_) {
+    GTEST_SKIP() << "No OPENAI_API_KEY environment variable set";
+  }
+
+  nlohmann::json input = "Hello, world!";
+  EmbeddingOptions options("text-embedding-3-small", input);
+  auto result = client_->embeddings(options);
+
+  ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message();
+  EXPECT_FALSE(result.data.is_null());
+  EXPECT_TRUE(result.data.is_array());
+  EXPECT_EQ(result.data.size(), 1);
+
+  // Check that we got an
embedding vector + auto embedding = result.data[0]["embedding"]; + EXPECT_TRUE(embedding.is_array()); + EXPECT_GT(embedding.size(), 0); + + // text-embedding-3-small should have 1536 dimensions by default + EXPECT_EQ(embedding.size(), 1536); + + // Check token usage + EXPECT_GT(result.usage.total_tokens, 0); + EXPECT_GT(result.usage.prompt_tokens, 0); +} + +TEST_F(OpenAIEmbeddingsIntegrationTest, MultipleStringsEmbedding) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = nlohmann::json::array( + {"sunny day at the beach", "rainy afternoon in the city", + "snowy night in the mountains"}); + + EmbeddingOptions options("text-embedding-3-small", input); + auto result = client_->embeddings(options); + + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_TRUE(result.data.is_array()); + EXPECT_EQ(result.data.size(), 3); + + // Check each embedding + for (size_t i = 0; i < 3; ++i) { + auto embedding = result.data[i]["embedding"]; + EXPECT_TRUE(embedding.is_array()); + EXPECT_EQ(embedding.size(), 1536); + } + + // Check token usage + EXPECT_GT(result.usage.total_tokens, 0); +} + +TEST_F(OpenAIEmbeddingsIntegrationTest, EmbeddingWithCustomDimensions) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = "Test with custom dimensions"; + EmbeddingOptions options("text-embedding-3-small", input, 512); + auto result = client_->embeddings(options); + + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_TRUE(result.data.is_array()); + EXPECT_EQ(result.data.size(), 1); + + // Check that the embedding has the requested dimensions + auto embedding = result.data[0]["embedding"]; + EXPECT_TRUE(embedding.is_array()); + EXPECT_EQ(embedding.size(), 512); +} + +TEST_F(OpenAIEmbeddingsIntegrationTest, EmbeddingWithLargeModel) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY 
environment variable set";
+  }
+
+  nlohmann::json input = "Testing large embedding model";
+  EmbeddingOptions options("text-embedding-3-large", input);
+  auto result = client_->embeddings(options);
+
+  ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message();
+  EXPECT_TRUE(result.data.is_array());
+  EXPECT_EQ(result.data.size(), 1);
+
+  // text-embedding-3-large should have 3072 dimensions by default
+  auto embedding = result.data[0]["embedding"];
+  EXPECT_TRUE(embedding.is_array());
+  EXPECT_EQ(embedding.size(), 3072);
+}
+
+TEST_F(OpenAIEmbeddingsIntegrationTest, EmbeddingSimilarity) {
+  if (!use_real_api_) {
+    GTEST_SKIP() << "No OPENAI_API_KEY environment variable set";
+  }
+
+  nlohmann::json input =
+      nlohmann::json::array({"cat", "kitten", "dog", "puppy", "car"});
+
+  EmbeddingOptions options("text-embedding-3-small", input);
+  auto result = client_->embeddings(options);
+
+  ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message();
+  EXPECT_EQ(result.data.size(), 5);
+
+  // Convert embeddings to vectors
+  std::vector<std::vector<double>> embeddings;
+  for (const auto& item : result.data) {
+    std::vector<double> embedding;
+    for (const auto& val : item["embedding"]) {
+      embedding.push_back(val.get<double>());
+    }
+    embeddings.push_back(embedding);
+  }
+
+  // Calculate similarities
+  double cat_kitten_sim = cosine_similarity(embeddings[0], embeddings[1]);
+  double dog_puppy_sim = cosine_similarity(embeddings[2], embeddings[3]);
+  double cat_car_sim = cosine_similarity(embeddings[0], embeddings[4]);
+
+  // Similar words should have higher similarity than unrelated words
+  // Note: Single words have moderate similarity (~0.5-0.6), not as high as full
+  // sentences
+  EXPECT_GT(cat_kitten_sim, 0.5) << "cat and kitten should be similar";
+  EXPECT_GT(dog_puppy_sim, 0.5) << "dog and puppy should be similar";
+  EXPECT_LT(cat_car_sim, cat_kitten_sim)
+      << "cat and car should be less similar than cat and kitten";
+}
+
+// Error Handling Tests
+TEST_F(OpenAIEmbeddingsIntegrationTest, InvalidApiKey) { + auto invalid_client = ai::openai::create_client("sk-invalid123"); + + nlohmann::json input = "Test with invalid key"; + EmbeddingOptions options("text-embedding-3-small", input); + auto result = invalid_client.embeddings(options); + + EXPECT_FALSE(result.is_success()); + EXPECT_THAT(result.error_message(), + testing::AnyOf(testing::HasSubstr("401"), + testing::HasSubstr("Unauthorized"), + testing::HasSubstr("API key"), + testing::HasSubstr("authentication"))); +} + +TEST_F(OpenAIEmbeddingsIntegrationTest, InvalidModel) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = "Test with invalid model"; + EmbeddingOptions options("invalid-embedding-model", input); + auto result = client_->embeddings(options); + + EXPECT_FALSE(result.is_success()); + EXPECT_THAT( + result.error_message(), + testing::AnyOf(testing::HasSubstr("404"), testing::HasSubstr("model"), + testing::HasSubstr("not found"), + testing::HasSubstr("does not exist"))); +} + +// Edge Cases +TEST_F(OpenAIEmbeddingsIntegrationTest, EmptyStringEmbedding) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = ""; + EmbeddingOptions options("text-embedding-3-small", input); + auto result = client_->embeddings(options); + + // OpenAI API should handle empty strings + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_TRUE(result.data.is_array()); + EXPECT_EQ(result.data.size(), 1); +} + +TEST_F(OpenAIEmbeddingsIntegrationTest, LongTextEmbedding) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + // Create a reasonably long text (but within token limits) + std::string long_text = "This is a test sentence. "; + for (int i = 0; i < 50; ++i) { + long_text += "This is a test sentence. 
"; + } + + nlohmann::json input = long_text; + EmbeddingOptions options("text-embedding-3-small", input); + auto result = client_->embeddings(options); + + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_TRUE(result.data.is_array()); + EXPECT_EQ(result.data.size(), 1); + + auto embedding = result.data[0]["embedding"]; + EXPECT_TRUE(embedding.is_array()); + EXPECT_EQ(embedding.size(), 1536); + + // Should use more tokens for longer text + EXPECT_GT(result.usage.prompt_tokens, 50); +} + +TEST_F(OpenAIEmbeddingsIntegrationTest, SpecialCharactersEmbedding) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = + nlohmann::json::array({"Hello! How are you?", "¡Hola! ¿Cómo estás?", + "你好!你好吗?", "🌟✨🎉"}); + + EmbeddingOptions options("text-embedding-3-small", input); + auto result = client_->embeddings(options); + + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_EQ(result.data.size(), 4); + + // All embeddings should have the correct size + for (const auto& item : result.data) { + auto embedding = item["embedding"]; + EXPECT_TRUE(embedding.is_array()); + EXPECT_EQ(embedding.size(), 1536); + } +} + +// Configuration Tests +TEST_F(OpenAIEmbeddingsIntegrationTest, EmbeddingWithUserIdentifier) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = "Test with user identifier"; + EmbeddingOptions options("text-embedding-3-small", input); + options.user = "test-user-123"; + + auto result = client_->embeddings(options); + + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_TRUE(result.data.is_array()); + EXPECT_EQ(result.data.size(), 1); +} + +TEST_F(OpenAIEmbeddingsIntegrationTest, EmbeddingWithBase64Encoding) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = "Test with base64 
encoding"; + EmbeddingOptions options("text-embedding-3-small", input); + options.encoding_format = "base64"; + + auto result = client_->embeddings(options); + + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_TRUE(result.data.is_array()); + EXPECT_EQ(result.data.size(), 1); + + // With base64 encoding, the embedding should be a string + auto embedding = result.data[0]["embedding"]; + EXPECT_TRUE(embedding.is_string()); +} + +// Token Usage Tests +TEST_F(OpenAIEmbeddingsIntegrationTest, TokenUsageTracking) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = nlohmann::json::array( + {"Short text", "This is a slightly longer text with more words"}); + + EmbeddingOptions options("text-embedding-3-small", input); + auto result = client_->embeddings(options); + + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + + // Check token usage is properly tracked + EXPECT_GT(result.usage.prompt_tokens, 0); + EXPECT_GT(result.usage.total_tokens, 0); + EXPECT_EQ(result.usage.total_tokens, result.usage.prompt_tokens); + + // Longer text should use more tokens (approximately) + EXPECT_GT(result.usage.prompt_tokens, 5); +} + +// Network Error Tests +TEST_F(OpenAIEmbeddingsIntegrationTest, NetworkFailure) { + const char* api_key = std::getenv("OPENAI_API_KEY"); + if (!api_key) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + // Test with localhost on unused port to simulate connection refused + auto failing_client = + ai::openai::create_client(api_key, "http://localhost:59999"); + + nlohmann::json input = "Test network failure"; + EmbeddingOptions options("text-embedding-3-small", input); + auto result = failing_client.embeddings(options); + + EXPECT_FALSE(result.is_success()); + EXPECT_THAT(result.error_message(), + testing::AnyOf( + testing::HasSubstr("Network"), testing::HasSubstr("network"), + testing::HasSubstr("connection"), + 
testing::HasSubstr("refused"), testing::HasSubstr("failed"))); +} + +// Different Models Test +TEST_F(OpenAIEmbeddingsIntegrationTest, DifferentEmbeddingModels) { + if (!use_real_api_) { + GTEST_SKIP() << "No OPENAI_API_KEY environment variable set"; + } + + nlohmann::json input = "Test different models"; + + // Test text-embedding-3-small + { + EmbeddingOptions options("text-embedding-3-small", input); + auto result = client_->embeddings(options); + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_EQ(result.data[0]["embedding"].size(), 1536); + } + + // Test text-embedding-3-large + { + EmbeddingOptions options("text-embedding-3-large", input); + auto result = client_->embeddings(options); + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_EQ(result.data[0]["embedding"].size(), 3072); + } + + // Test text-embedding-ada-002 (legacy model) + { + EmbeddingOptions options("text-embedding-ada-002", input); + auto result = client_->embeddings(options); + ASSERT_TRUE(result.is_success()) << "Error: " << result.error_message(); + EXPECT_EQ(result.data[0]["embedding"].size(), 1536); + } +} + +} // namespace test +} // namespace ai diff --git a/tests/unit/openai_embeddings_test.cpp b/tests/unit/openai_embeddings_test.cpp new file mode 100644 index 0000000..fd626d2 --- /dev/null +++ b/tests/unit/openai_embeddings_test.cpp @@ -0,0 +1,241 @@ +#include +#include + +// Include the OpenAI client headers +#include "ai/types/embedding_options.h" + +// Include the real OpenAI client implementation for testing +#include "providers/openai/openai_client.h" + +// Test utilities +#include "../utils/test_fixtures.h" + +namespace ai { +namespace test { + +class OpenAIEmbeddingsTest : public OpenAITestFixture { + protected: + void SetUp() override { + OpenAITestFixture::SetUp(); + client_ = + std::make_unique(kTestApiKey, kTestBaseUrl); + } + + void TearDown() override { + client_.reset(); + OpenAITestFixture::TearDown(); + } 
+ + std::unique_ptr client_; +}; + +// EmbeddingOptions Constructor and Validation Tests +TEST_F(OpenAIEmbeddingsTest, EmbeddingOptionsDefaultConstructor) { + EmbeddingOptions options; + + EXPECT_FALSE(options.is_valid()); + EXPECT_FALSE(options.has_input()); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingOptionsBasicConstructor) { + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-small", input); + + EXPECT_TRUE(options.is_valid()); + EXPECT_TRUE(options.has_input()); + EXPECT_EQ(options.model, "text-embedding-3-small"); + EXPECT_EQ(options.input, input); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingOptionsWithDimensions) { + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-small", input, 512); + + EXPECT_TRUE(options.is_valid()); + EXPECT_EQ(options.dimensions, 512); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingOptionsWithEncodingFormat) { + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-small", input, 512, "float"); + + EXPECT_TRUE(options.is_valid()); + EXPECT_EQ(options.dimensions, 512); + EXPECT_EQ(options.encoding_format, "float"); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingOptionsWithArrayInput) { + nlohmann::json input = + nlohmann::json::array({"first text", "second text", "third text"}); + EmbeddingOptions options("text-embedding-3-small", input); + + EXPECT_TRUE(options.is_valid()); + EXPECT_TRUE(options.has_input()); + EXPECT_TRUE(options.input.is_array()); + EXPECT_EQ(options.input.size(), 3); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingOptionsInvalidEmptyModel) { + nlohmann::json input = "test text"; + EmbeddingOptions options("", input); + + EXPECT_FALSE(options.is_valid()); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingOptionsInvalidNullInput) { + nlohmann::json input = nlohmann::json(); + EmbeddingOptions options("text-embedding-3-small", input); + + EXPECT_FALSE(options.is_valid()); + EXPECT_FALSE(options.has_input()); +} + +// EmbeddingResult Tests 
+TEST_F(OpenAIEmbeddingsTest, EmbeddingResultDefaultConstructor) { + EmbeddingResult result; + + EXPECT_TRUE(result.is_success()); + EXPECT_TRUE(result.data.is_null()); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingResultWithError) { + EmbeddingResult result("Test error message"); + + EXPECT_FALSE(result.is_success()); + EXPECT_EQ(result.error_message(), "Test error message"); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingResultBoolConversion) { + EmbeddingResult success_result; + EXPECT_TRUE(static_cast(success_result)); + + EmbeddingResult error_result("error"); + EXPECT_FALSE(static_cast(error_result)); +} + +// Client Tests - Testing error handling without network calls +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithInvalidApiKey) { + ai::openai::OpenAIClient client("invalid-key", "https://api.openai.com"); + + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-small", input); + + // This will attempt a real call and should fail gracefully + auto result = client.embeddings(options); + + // We expect this to fail since we're using an invalid API key + EXPECT_FALSE(result.is_success()); + EXPECT_FALSE(result.error_message().empty()); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithBadUrl) { + ai::openai::OpenAIClient client( + "sk-test", "http://invalid-url-that-does-not-exist.example"); + + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-small", input); + + // This should fail due to network connectivity + auto result = client.embeddings(options); + + EXPECT_FALSE(result.is_success()); + EXPECT_FALSE(result.error_message().empty()); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithInvalidOptions) { + nlohmann::json input = nlohmann::json(); // null input + EmbeddingOptions invalid_options("", input); + + EXPECT_FALSE(invalid_options.is_valid()); +} + +// Test option validation +TEST_F(OpenAIEmbeddingsTest, ValidateEmbeddingOptionsValidation) { + // Test with empty model + nlohmann::json input = "test"; + 
EmbeddingOptions invalid_options("", input); + EXPECT_FALSE(invalid_options.is_valid()); + + // Test with null input + EmbeddingOptions invalid_input_options("text-embedding-3-small", + nlohmann::json()); + EXPECT_FALSE(invalid_input_options.is_valid()); + + // Test with valid options + EmbeddingOptions valid_options("text-embedding-3-small", input); + EXPECT_TRUE(valid_options.is_valid()); +} + +// Test different input formats +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithSingleString) { + nlohmann::json input = "This is a test string for embeddings"; + EmbeddingOptions options("text-embedding-3-small", input); + + EXPECT_TRUE(options.is_valid()); + EXPECT_TRUE(input.is_string()); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithMultipleStrings) { + nlohmann::json input = + nlohmann::json::array({"First embedding text", "Second embedding text", + "Third embedding text"}); + EmbeddingOptions options("text-embedding-3-small", input); + + EXPECT_TRUE(options.is_valid()); + EXPECT_TRUE(input.is_array()); + EXPECT_EQ(input.size(), 3); +} + +// Test optional parameters +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithOptionalUser) { + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-small", input); + options.user = "test-user-123"; + + EXPECT_TRUE(options.is_valid()); + EXPECT_TRUE(options.user.has_value()); + EXPECT_EQ(options.user.value(), "test-user-123"); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithCustomDimensions) { + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-large", input); + options.dimensions = 1024; + + EXPECT_TRUE(options.is_valid()); + EXPECT_TRUE(options.dimensions.has_value()); + EXPECT_EQ(options.dimensions.value(), 1024); +} + +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithBase64Encoding) { + nlohmann::json input = "test text"; + EmbeddingOptions options("text-embedding-3-small", input); + options.encoding_format = "base64"; + + EXPECT_TRUE(options.is_valid()); + 
EXPECT_TRUE(options.encoding_format.has_value()); + EXPECT_EQ(options.encoding_format.value(), "base64"); +} + +// Test different embedding models +TEST_F(OpenAIEmbeddingsTest, EmbeddingsWithDifferentModels) { + nlohmann::json input = "test text"; + + // Test with text-embedding-3-small + EmbeddingOptions small_options("text-embedding-3-small", input); + EXPECT_TRUE(small_options.is_valid()); + EXPECT_EQ(small_options.model, "text-embedding-3-small"); + + // Test with text-embedding-3-large + EmbeddingOptions large_options("text-embedding-3-large", input); + EXPECT_TRUE(large_options.is_valid()); + EXPECT_EQ(large_options.model, "text-embedding-3-large"); + + // Test with text-embedding-ada-002 + EmbeddingOptions ada_options("text-embedding-ada-002", input); + EXPECT_TRUE(ada_options.is_valid()); + EXPECT_EQ(ada_options.model, "text-embedding-ada-002"); +} + +} // namespace test +} // namespace ai