diff --git a/include/ai/openai.h b/include/ai/openai.h
index 6a9bd92..1277e5d 100644
--- a/include/ai/openai.h
+++ b/include/ai/openai.h
@@ -47,6 +47,7 @@ constexpr const char* kChatGpt4oLatest = "chatgpt-4o-latest";
 
 /// Default model used when none is specified
 constexpr const char* kDefaultModel = kGpt4o;
+
 }  // namespace models
 
 /// Create an OpenAI client with default configuration
@@ -82,4 +83,4 @@ Client create_client(const std::string& api_key,
 std::optional<Client> try_create_client();
 
 }  // namespace openai
-}  // namespace ai
\ No newline at end of file
+}  // namespace ai
diff --git a/include/ai/types/client.h b/include/ai/types/client.h
index dd5084d..c49c570 100644
--- a/include/ai/types/client.h
+++ b/include/ai/types/client.h
@@ -3,6 +3,7 @@
 #include "generate_options.h"
 #include "stream_options.h"
 #include "stream_result.h"
+#include "embedding_options.h"
 
 #include <memory>
 #include <string>
@@ -31,6 +32,12 @@ class Client {
     return GenerateResult("Client not initialized");
   }
 
+  virtual EmbeddingResult embeddings(const EmbeddingOptions& options) {
+    if (pimpl_)
+      return pimpl_->embeddings(options);
+    return EmbeddingResult("Client not initialized");
+  }
+
   virtual StreamResult stream_text(const StreamOptions& options) {
     if (pimpl_)
       return pimpl_->stream_text(options);
diff --git a/include/ai/types/embedding_options.h b/include/ai/types/embedding_options.h
new file mode 100644
index 0000000..f3fd3eb
--- /dev/null
+++ b/include/ai/types/embedding_options.h
@@ -0,0 +1,84 @@
+#pragma once
+
+#include "enums.h"
+#include "message.h"
+#include "model.h"
+#include "tool.h"
+#include "usage.h"
+
+#include <nlohmann/json.hpp>
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace ai {
+
+struct EmbeddingOptions {
+  std::string model;
+  nlohmann::json input;
+  std::optional<int> dimensions;
+  std::optional<std::string> encoding_format;
+  std::optional<int> max_tokens;
+  std::optional<double> temperature;
+  std::optional<double> top_p;
+  std::optional<int> seed;
+  std::optional<double> frequency_penalty;
+  std::optional<double> presence_penalty;
+
+  EmbeddingOptions(std::string model_name, nlohmann::json input_)
+      : model(std::move(model_name)),
+        input(std::move(input_)) {}
+
+  EmbeddingOptions(std::string model_name, nlohmann::json input_,
+                   int dimensions_)
+      : model(std::move(model_name)),
+        input(std::move(input_)),
+        dimensions(dimensions_) {}
+
+  EmbeddingOptions(std::string model_name, nlohmann::json input_,
+                   int dimensions_, std::string encoding_format_)
+      : model(std::move(model_name)),
+        input(std::move(input_)),
+        dimensions(dimensions_),
+        encoding_format(std::move(encoding_format_)) {}
+
+  EmbeddingOptions() = default;
+
+  bool is_valid() const {
+    return !model.empty() && !input.empty();
+  }
+
+  bool has_input() const { return !input.empty(); }
+};
+
+struct EmbeddingResult {
+  nlohmann::json data;
+  Usage usage;
+
+  /// Additional metadata (like TypeScript SDK)
+  std::optional<std::string> model;
+
+  /// Error handling
+  std::optional<std::string> error;
+  std::optional<bool> is_retryable;
+
+  /// Provider-specific metadata
+  std::optional<std::string> provider_metadata;
+
+  EmbeddingResult() = default;
+
+  // EmbeddingResult(std::string data_, Usage token_usage)
+  //     : data(std::move(data_)), usage(token_usage) {}
+
+  explicit EmbeddingResult(std::optional<std::string> error_message)
+      : error(std::move(error_message)) {}
+
+  bool is_success() const {
+    return !error.has_value();
+  }
+
+  explicit operator bool() const { return is_success(); }
+
+  std::string error_message() const { return error.value_or(""); }
+};
+
+}  // namespace ai
diff --git a/include/ai/types/generate_options.h b/include/ai/types/generate_options.h
index 8176d4e..832c06e 100644
--- a/include/ai/types/generate_options.h
+++ b/include/ai/types/generate_options.h
@@ -18,6 +18,7 @@ struct GenerateOptions {
   std::string system;
   std::string prompt;
   Messages messages;
+  std::optional<nlohmann::json> response_format{};
   std::optional<int> max_tokens;
   std::optional<double> temperature;
   std::optional<double> top_p;
@@ -46,6 +47,15 @@ struct GenerateOptions {
         system(std::move(system_prompt)),
         prompt(std::move(user_prompt)) {}
 
+  GenerateOptions(std::string model_name,
+                  std::string system_prompt,
+                  std::string user_prompt,
+                  std::optional<nlohmann::json> response_format_)
+      : model(std::move(model_name)),
+        system(std::move(system_prompt)),
+        prompt(std::move(user_prompt)),
+        response_format(std::move(response_format_)) {}
+
   GenerateOptions(std::string model_name, Messages conversation)
       : model(std::move(model_name)),
         messages(std::move(conversation)) {}
diff --git a/src/providers/anthropic/anthropic_client.cpp b/src/providers/anthropic/anthropic_client.cpp
index 7684717..21b8f44 100644
--- a/src/providers/anthropic/anthropic_client.cpp
+++ b/src/providers/anthropic/anthropic_client.cpp
@@ -18,7 +18,8 @@ AnthropicClient::AnthropicClient(const std::string& api_key,
           providers::ProviderConfig{
               .api_key = api_key,
               .base_url = base_url,
-              .endpoint_path = "/v1/messages",
+              .completions_endpoint_path = "/v1/messages",
+              .embeddings_endpoint_path = "/v1/embeddings",
               .auth_header_name = "x-api-key",
               .auth_header_prefix = "",
              .extra_headers = {{"anthropic-version", "2023-06-01"}}},
@@ -44,8 +45,8 @@ StreamResult AnthropicClient::stream_text(const StreamOptions& options) {
   // Create stream implementation
   auto impl = std::make_unique();
 
-  impl->start_stream(config_.base_url + config_.endpoint_path, headers,
-                     request_json);
+  impl->start_stream(config_.base_url + config_.completions_endpoint_path,
+                     headers, request_json);
 
   ai::logger::log_info("Text streaming started - model: {}", options.model);
 
@@ -77,4 +78,4 @@ std::string AnthropicClient::default_model() const {
 }
 
 }  // namespace anthropic
-}  // namespace ai
\ No newline at end of file
+}  // namespace ai
diff --git a/src/providers/anthropic/anthropic_client.h b/src/providers/anthropic/anthropic_client.h
index bb9bb08..f66e432 100644
--- a/src/providers/anthropic/anthropic_client.h
+++ b/src/providers/anthropic/anthropic_client.h
@@ -29,4 +29,4 @@ class AnthropicClient : public providers::BaseProviderClient {
 };
 
 }  // namespace anthropic
-}  // namespace ai
\ No newline at end of file
+}  // namespace ai
diff --git a/src/providers/anthropic/anthropic_request_builder.cpp b/src/providers/anthropic/anthropic_request_builder.cpp
index 8a29c45..a15df27 100644
--- a/src/providers/anthropic/anthropic_request_builder.cpp
+++ b/src/providers/anthropic/anthropic_request_builder.cpp
@@ -157,6 +157,12 @@ nlohmann::json AnthropicRequestBuilder::build_request_json(
   return request;
 }
 
+nlohmann::json AnthropicRequestBuilder::build_request_json(
+    const EmbeddingOptions& options) {
+  nlohmann::json request{{"model", options.model}, {"input", options.input}};
+  return request;
+}
+
 httplib::Headers AnthropicRequestBuilder::build_headers(
     const providers::ProviderConfig& config) {
   httplib::Headers headers = {
@@ -172,4 +178,4 @@ httplib::Headers AnthropicRequestBuilder::build_headers(
 }
 
 }  // namespace anthropic
-}  // namespace ai
\ No newline at end of file
+}  // namespace ai
diff --git a/src/providers/anthropic/anthropic_request_builder.h b/src/providers/anthropic/anthropic_request_builder.h
index bab9fb1..c7f6387 100644
--- a/src/providers/anthropic/anthropic_request_builder.h
+++ b/src/providers/anthropic/anthropic_request_builder.h
@@ -11,6 +11,7 @@ namespace anthropic {
 class AnthropicRequestBuilder : public providers::RequestBuilder {
  public:
   nlohmann::json build_request_json(const GenerateOptions& options) override;
+  nlohmann::json build_request_json(const EmbeddingOptions& options) override;
   httplib::Headers build_headers(
       const providers::ProviderConfig& config) override;
 };
diff --git a/src/providers/anthropic/anthropic_response_parser.cpp b/src/providers/anthropic/anthropic_response_parser.cpp
index 1235713..0c2e7f4 100644
--- a/src/providers/anthropic/anthropic_response_parser.cpp
+++ b/src/providers/anthropic/anthropic_response_parser.cpp
@@ -6,7 +6,7 @@
 namespace ai {
 namespace anthropic {
 
-GenerateResult AnthropicResponseParser::parse_success_response(
+GenerateResult AnthropicResponseParser::parse_success_completion_response(
     const nlohmann::json& response) {
   ai::logger::log_debug("Parsing Anthropic messages response");
 
@@ -86,12 +86,49 @@ GenerateResult AnthropicResponseParser::parse_success_response(
   return result;
 }
 
-GenerateResult AnthropicResponseParser::parse_error_response(
+GenerateResult AnthropicResponseParser::parse_error_completion_response(
     int status_code, const std::string& body) {
   return utils::parse_standard_error_response("Anthropic", status_code, body);
 }
 
+EmbeddingResult AnthropicResponseParser::parse_success_embedding_response(
+    const nlohmann::json& response) {
+  ai::logger::log_debug("Parsing Anthropic embeddings response");
+
+  EmbeddingResult result;
+
+  // Extract basic fields
+  result.model = response.value("model", "");
+
+  // Extract embedding data
+  if (response.contains("data") && !response["data"].empty()) {
+    result.data = response["data"];
+  }
+
+  // Extract usage
+  if (response.contains("usage")) {
+    const auto& usage = response["usage"];
+    result.usage.prompt_tokens = usage.value("prompt_tokens", 0);
+    result.usage.completion_tokens = usage.value("completion_tokens", 0);
+    result.usage.total_tokens = usage.value("total_tokens", 0);
+    ai::logger::log_debug("Token usage - prompt: {}, completion: {}, total: {}",
+                          result.usage.prompt_tokens,
+                          result.usage.completion_tokens,
+                          result.usage.total_tokens);
+  }
+
+  // Store full metadata
+  result.provider_metadata = response.dump();
+
+  return result;
+}
+
+EmbeddingResult AnthropicResponseParser::parse_error_embedding_response(
+    int status_code, const std::string& body) {
+  auto generate_result =
+      utils::parse_standard_error_response("Anthropic", status_code, body);
+  return EmbeddingResult(generate_result.error);
+}
+
 FinishReason AnthropicResponseParser::parse_stop_reason(
     const std::string& reason) {
   if (reason == "end_turn") {
diff --git a/src/providers/anthropic/anthropic_response_parser.h b/src/providers/anthropic/anthropic_response_parser.h
index 822cd70..e4de998 100644
--- a/src/providers/anthropic/anthropic_response_parser.h
+++ b/src/providers/anthropic/anthropic_response_parser.h
@@ -10,9 +10,13 @@ namespace anthropic {
 
 class AnthropicResponseParser : public providers::ResponseParser {
  public:
-  GenerateResult parse_success_response(
+  GenerateResult parse_success_completion_response(
       const nlohmann::json& response) override;
-  GenerateResult parse_error_response(int status_code,
+  GenerateResult parse_error_completion_response(int status_code,
+                                                 const std::string& body) override;
+  EmbeddingResult parse_success_embedding_response(
+      const nlohmann::json& response) override;
+  EmbeddingResult parse_error_embedding_response(int status_code,
                                                  const std::string& body) override;
 
  private:
diff --git a/src/providers/base_provider_client.cpp b/src/providers/base_provider_client.cpp
index e734fae..5e47c2d 100644
--- a/src/providers/base_provider_client.cpp
+++ b/src/providers/base_provider_client.cpp
@@ -24,8 +24,10 @@ BaseProviderClient::BaseProviderClient(
   http_handler_ = std::make_unique(http_config);
 
   ai::logger::log_debug(
-      "BaseProviderClient initialized - base_url: {}, endpoint: {}",
-      config.base_url, config.endpoint_path);
+      "BaseProviderClient initialized - base_url: {}, "
+      "completions_endpoint: {}, embeddings_endpoint: {}",
+      config.base_url, config.completions_endpoint_path,
+      config.embeddings_endpoint_path);
 }
 
 GenerateResult BaseProviderClient::generate_text(
@@ -65,13 +67,13 @@ GenerateResult BaseProviderClient::generate_text_single_step(
 
   // Make the request
   auto result =
-      http_handler_->post(config_.endpoint_path, headers, json_body);
+      http_handler_->post(config_.completions_endpoint_path, headers, json_body);
 
   if (!result.is_success()) {
     // Parse error response using provider-specific parser
     if (result.provider_metadata.has_value()) {
       int status_code = std::stoi(result.provider_metadata.value());
-      return response_parser_->parse_error_response(
+      return response_parser_->parse_error_completion_response(
           status_code, result.error.value_or(""));
     }
     return result;
@@ -94,7 +96,7 @@ GenerateResult BaseProviderClient::generate_text_single_step(
 
   // Parse using provider-specific parser
   auto parsed_result =
-      response_parser_->parse_success_response(json_response);
+      response_parser_->parse_success_completion_response(json_response);
 
   if (parsed_result.has_tool_calls()) {
     ai::logger::log_debug("Model made {} tool calls",
@@ -144,5 +146,55 @@ StreamResult BaseProviderClient::stream_text(const StreamOptions& options) {
   return StreamResult();
 }
 
+EmbeddingResult BaseProviderClient::embeddings(const EmbeddingOptions& options) {
+  try {
+    // Build request JSON using the provider-specific builder
+    auto request_json = request_builder_->build_request_json(options);
+    std::string json_body = request_json.dump();
+    ai::logger::log_debug("Request JSON built: {}", json_body);
+
+    // Build headers
+    auto headers = request_builder_->build_headers(config_);
+
+    // Make the request
+    auto result = http_handler_->post(config_.embeddings_endpoint_path,
+                                      headers, json_body);
+
+    if (!result.is_success()) {
+      // Parse error response using provider-specific parser
+      if (result.provider_metadata.has_value()) {
+        int status_code = std::stoi(result.provider_metadata.value());
+        return response_parser_->parse_error_embedding_response(
+            status_code, result.error.value_or(""));
+      }
+      return EmbeddingResult(result.error);
+    }
+
+    // Parse the response JSON from result.text
+    nlohmann::json json_response;
+    try {
+      json_response = nlohmann::json::parse(result.text);
+    } catch (const nlohmann::json::exception& e) {
+      ai::logger::log_error("Failed to parse response JSON: {}", e.what());
+      ai::logger::log_debug("Raw response text: {}", result.text);
+      return EmbeddingResult("Failed to parse response: " +
+                             std::string(e.what()));
+    }
+
+    ai::logger::log_info("Embeddings successful - model: {}, response_id: {}",
+                         options.model, json_response.value("id", "unknown"));
+
+    // Parse using provider-specific parser
+    return response_parser_->parse_success_embedding_response(json_response);
+
+  } catch (const std::exception& e) {
+    ai::logger::log_error("Exception during embeddings: {}", e.what());
+    return EmbeddingResult(std::string("Exception: ") + e.what());
+  }
+}
+
 }  // namespace providers
 }  // namespace ai
\ No newline at end of file
diff --git a/src/providers/base_provider_client.h b/src/providers/base_provider_client.h
index 34fc3cd..6ca19a0 100644
--- a/src/providers/base_provider_client.h
+++ b/src/providers/base_provider_client.h
@@ -18,7 +18,8 @@ namespace providers {
 struct ProviderConfig {
   std::string api_key;
   std::string base_url;
-  std::string endpoint_path;  // e.g., "/v1/chat/completions" or "/v1/messages"
+  std::string completions_endpoint_path;  // e.g. "/v1/chat/completions"
+  std::string embeddings_endpoint_path;
   std::string auth_header_name;    // e.g., "Authorization" or "x-api-key"
   std::string auth_header_prefix;  // e.g., "Bearer " or ""
   httplib::Headers extra_headers;  // Additional headers like anthropic-version
@@ -32,6 +33,7 @@
 class RequestBuilder {
  public:
   virtual ~RequestBuilder() = default;
   virtual nlohmann::json build_request_json(const GenerateOptions& options) = 0;
+  virtual nlohmann::json build_request_json(const EmbeddingOptions& options) = 0;
   virtual httplib::Headers build_headers(const ProviderConfig& config) = 0;
 };
@@ -39,9 +41,13 @@ class RequestBuilder {
 
 class ResponseParser {
  public:
   virtual ~ResponseParser() = default;
-  virtual GenerateResult parse_success_response(
+  virtual GenerateResult parse_success_completion_response(
       const nlohmann::json& response) = 0;
-  virtual GenerateResult parse_error_response(int status_code,
+  virtual GenerateResult parse_error_completion_response(int status_code,
+                                                         const std::string& body) = 0;
+  virtual EmbeddingResult parse_success_embedding_response(
+      const nlohmann::json& response) = 0;
+  virtual EmbeddingResult parse_error_embedding_response(int status_code,
                                                          const std::string& body) = 0;
 };
@@ -55,6 +61,7 @@ class BaseProviderClient : public Client {
   // Implements the common flow using the composed components
   GenerateResult generate_text(const GenerateOptions& options) override;
   StreamResult stream_text(const StreamOptions& options) override;
+  EmbeddingResult embeddings(const EmbeddingOptions& options) override;
 
   bool is_valid() const override { return !config_.api_key.empty(); }
diff --git a/src/providers/openai/openai_client.cpp b/src/providers/openai/openai_client.cpp
index 28a4b9b..6876dfb 100644
--- a/src/providers/openai/openai_client.cpp
+++ b/src/providers/openai/openai_client.cpp
@@ -17,7 +17,8 @@ OpenAIClient::OpenAIClient(const std::string& api_key,
     : BaseProviderClient(
           providers::ProviderConfig{.api_key = api_key,
                                     .base_url = base_url,
-                                    .endpoint_path = "/v1/chat/completions",
+                                    .completions_endpoint_path = "/v1/chat/completions",
+                                    .embeddings_endpoint_path = "/v1/embeddings",
                                     .auth_header_name = "Authorization",
                                     .auth_header_prefix = "Bearer ",
                                     .extra_headers = {}},
@@ -33,7 +34,8 @@ OpenAIClient::OpenAIClient(const std::string& api_key,
     : BaseProviderClient(
           providers::ProviderConfig{.api_key = api_key,
                                     .base_url = base_url,
-                                    .endpoint_path = "/v1/chat/completions",
+                                    .completions_endpoint_path = "/v1/chat/completions",
+                                    .embeddings_endpoint_path = "/v1/embeddings",
                                     .auth_header_name = "Authorization",
                                     .auth_header_prefix = "Bearer ",
                                     .extra_headers = {},
@@ -61,7 +63,7 @@ StreamResult OpenAIClient::stream_text(const StreamOptions& options) {
 
   // Create stream implementation
   auto impl = std::make_unique();
-  impl->start_stream(config_.base_url + config_.endpoint_path, headers,
+  impl->start_stream(config_.base_url + config_.completions_endpoint_path, headers,
                      request_json);
ai::logger::log_info("Text streaming started - model: {}", options.model); @@ -94,4 +96,4 @@ std::string OpenAIClient::default_model() const { } } // namespace openai -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/src/providers/openai/openai_client.h b/src/providers/openai/openai_client.h index e3ea93e..12df8bc 100644 --- a/src/providers/openai/openai_client.h +++ b/src/providers/openai/openai_client.h @@ -33,4 +33,4 @@ class OpenAIClient : public providers::BaseProviderClient { }; } // namespace openai -} // namespace ai \ No newline at end of file +} // namespace ai diff --git a/src/providers/openai/openai_request_builder.cpp b/src/providers/openai/openai_request_builder.cpp index f2ed8be..8afac44 100644 --- a/src/providers/openai/openai_request_builder.cpp +++ b/src/providers/openai/openai_request_builder.cpp @@ -11,6 +11,8 @@ nlohmann::json OpenAIRequestBuilder::build_request_json( nlohmann::json request{{"model", options.model}, {"messages", nlohmann::json::array()}}; + if (options.response_format) + request["response_format"] = options.response_format.value(); // Build messages array if (!options.messages.empty()) { // Use provided messages @@ -164,6 +166,46 @@ nlohmann::json OpenAIRequestBuilder::build_request_json( return request; } +nlohmann::json OpenAIRequestBuilder::build_request_json( + const EmbeddingOptions& options) { + nlohmann::json request{{"model", options.model}, + {"input", options.input}}; + + if (options.encoding_format) { + request["encoding_format"] = options.encoding_format.value(); + } + + if (options.dimensions && options.dimensions.value()) { + request["dimensions"] = options.dimensions.value(); + } + // Add optional parameters + if (options.temperature) { + request["temperature"] = *options.temperature; + } + + if (options.max_tokens) { + request["max_completion_tokens"] = *options.max_tokens; + } + + if (options.top_p) { + request["top_p"] = *options.top_p; + } + + if (options.frequency_penalty) { + request["frequency_penalty"] = *options.frequency_penalty; + } + + if (options.presence_penalty) { + request["presence_penalty"] = *options.presence_penalty; + } + + if (options.seed) { + request["seed"] = *options.seed; + } + + return request; +} + httplib::Headers OpenAIRequestBuilder::build_headers( const providers::ProviderConfig& config) { httplib::Headers headers = { diff --git a/src/providers/openai/openai_request_builder.h b/src/providers/openai/openai_request_builder.h index 5317769..3d529c3 100644 --- a/src/providers/openai/openai_request_builder.h +++ b/src/providers/openai/openai_request_builder.h @@ -11,6 +11,7 @@ namespace openai { class OpenAIRequestBuilder : public providers::RequestBuilder { public: nlohmann::json build_request_json(const GenerateOptions& options) override; + nlohmann::json build_request_json(const EmbeddingOptions& options) override; httplib::Headers build_headers( const providers::ProviderConfig& config) override; }; diff --git a/src/providers/openai/openai_response_parser.cpp b/src/providers/openai/openai_response_parser.cpp index 90c520a..b68e3d8 100644 --- a/src/providers/openai/openai_response_parser.cpp +++ b/src/providers/openai/openai_response_parser.cpp @@ -6,7 +6,7 @@ namespace ai { namespace openai { -GenerateResult OpenAIResponseParser::parse_success_response( +GenerateResult OpenAIResponseParser::parse_success_completion_response( const nlohmann::json& response) { ai::logger::log_debug("Parsing OpenAI chat completion response"); @@ -128,12 +128,48 @@ GenerateResult 
   return result;
 }
 
-GenerateResult OpenAIResponseParser::parse_error_response(
+GenerateResult OpenAIResponseParser::parse_error_completion_response(
     int status_code, const std::string& body) {
   return utils::parse_standard_error_response("OpenAI", status_code, body);
 }
 
+EmbeddingResult OpenAIResponseParser::parse_success_embedding_response(
+    const nlohmann::json& response) {
+  ai::logger::log_debug("Parsing OpenAI embeddings response");
+
+  EmbeddingResult result;
+
+  // Extract basic fields
+  result.model = response.value("model", "");
+
+  // Extract embedding data
+  if (response.contains("data") && !response["data"].empty()) {
+    result.data = response["data"];
+  }
+
+  // Extract usage
+  if (response.contains("usage")) {
+    const auto& usage = response["usage"];
+    result.usage.prompt_tokens = usage.value("prompt_tokens", 0);
+    result.usage.completion_tokens = usage.value("completion_tokens", 0);
+    result.usage.total_tokens = usage.value("total_tokens", 0);
+    ai::logger::log_debug("Token usage - prompt: {}, completion: {}, total: {}",
+                          result.usage.prompt_tokens,
+                          result.usage.completion_tokens,
+                          result.usage.total_tokens);
+  }
+
+  // Store full metadata
+  result.provider_metadata = response.dump();
+
+  return result;
+}
+
+EmbeddingResult OpenAIResponseParser::parse_error_embedding_response(
+    int status_code, const std::string& body) {
+  auto generate_result =
+      utils::parse_standard_error_response("OpenAI", status_code, body);
+  return EmbeddingResult(generate_result.error);
+}
+
 FinishReason OpenAIResponseParser::parse_finish_reason(
     const std::string& reason) {
   if (reason == "stop") {
diff --git a/src/providers/openai/openai_response_parser.h b/src/providers/openai/openai_response_parser.h
index 14d62d9..6c2e564 100644
--- a/src/providers/openai/openai_response_parser.h
+++ b/src/providers/openai/openai_response_parser.h
@@ -10,9 +10,13 @@ namespace openai {
 
 class OpenAIResponseParser : public providers::ResponseParser {
  public:
-  GenerateResult parse_success_response(
+  GenerateResult parse_success_completion_response(
       const nlohmann::json& response) override;
-  GenerateResult parse_error_response(int status_code,
+  GenerateResult parse_error_completion_response(int status_code,
+                                                 const std::string& body) override;
+  EmbeddingResult parse_success_embedding_response(
+      const nlohmann::json& response) override;
+  EmbeddingResult parse_error_embedding_response(int status_code,
                                                  const std::string& body) override;
 
  private:
diff --git a/src/types/embedding_options.cpp b/src/types/embedding_options.cpp
new file mode 100644
index 0000000..19fde7e
--- /dev/null
+++ b/src/types/embedding_options.cpp
@@ -0,0 +1,7 @@
+#include "ai/types/embedding_options.h"
+
+namespace ai {
+
+// Implementation details for EmbeddingOptions if needed
+
+}  // namespace ai
\ No newline at end of file
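
For reviewers, a minimal end-to-end sketch of the new embeddings surface (not part of the diff). It assumes the existing ai::openai::try_create_client() factory from include/ai/openai.h; the model name "text-embedding-3-small" is illustrative, not something this patch pins down.

#include "ai/openai.h"
#include "ai/types/embedding_options.h"

#include <iostream>

int main() {
  // try_create_client() yields std::nullopt when no API key is configured.
  auto client = ai::openai::try_create_client();
  if (!client) {
    std::cerr << "no API key configured\n";
    return 1;
  }

  // EmbeddingOptions::input is nlohmann::json, so it accepts a single string
  // or an array of strings; this uses the (model, input, dimensions) ctor.
  ai::EmbeddingOptions options("text-embedding-3-small",
                               nlohmann::json::array({"hello", "world"}),
                               /*dimensions=*/256);

  auto result = client->embeddings(options);
  if (!result) {  // explicit operator bool() reports is_success()
    std::cerr << "embeddings failed: " << result.error_message() << "\n";
    return 1;
  }

  // result.data mirrors the provider's raw "data" array; each element
  // carries an "embedding" vector.
  std::cout << "embeddings: " << result.data.size()
            << ", total tokens: " << result.usage.total_tokens << "\n";
  return 0;
}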
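Likewise for the response_format pass-through on GenerateOptions: continuing with the client above, a sketch assuming OpenAI's JSON-mode convention ({"type": "json_object"}), which the builder copies verbatim into the request body.

  // New 4-argument constructor: model, system prompt, user prompt,
  // response_format (forwarded untouched as request["response_format"]).
  ai::GenerateOptions gen_options(
      "gpt-4o", "You are a helpful assistant.",
      "Return a JSON object with keys name and score.",
      nlohmann::json{{"type", "json_object"}});

  auto gen_result = client->generate_text(gen_options);
  if (gen_result.error) {
    std::cerr << "generate_text failed: " << *gen_result.error << "\n";
  }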