From 31190a02405d3053e644c50c6930c472c8124494 Mon Sep 17 00:00:00 2001
From: fastio
Date: Mon, 18 Aug 2025 21:31:45 +0800
Subject: [PATCH 1/4] Support embedding API

---
 include/ai/openai.h                           |  3 +
 include/ai/types/client.h                     |  7 ++
 include/ai/types/embedding_options.h          | 84 +++++++++++++++++++
 include/ai/types/generate_options.h           | 10 +++
 src/providers/anthropic/anthropic_client.cpp  |  5 ++
 src/providers/anthropic/anthropic_client.h    |  1 +
 .../anthropic/anthropic_request_builder.cpp   |  4 +
 .../anthropic/anthropic_request_builder.h     |  1 +
 .../anthropic/anthropic_response_parser.cpp   | 13 ++-
 .../anthropic/anthropic_response_parser.h     |  8 +-
 src/providers/base_provider_client.cpp        | 11 ++-
 src/providers/base_provider_client.h          | 10 ++-
 src/providers/openai/openai_client.cpp        | 49 +++++++++++
 src/providers/openai/openai_client.h          |  1 +
 .../openai/openai_request_builder.cpp         | 42 ++++++++++
 src/providers/openai/openai_request_builder.h |  1 +
 .../openai/openai_response_parser.cpp         | 40 ++++++++-
 src/providers/openai/openai_response_parser.h |  8 +-
 src/types/embedding_options.cpp               |  7 ++
 19 files changed, 293 insertions(+), 12 deletions(-)
 create mode 100644 include/ai/types/embedding_options.h
 create mode 100644 src/types/embedding_options.cpp

diff --git a/include/ai/openai.h b/include/ai/openai.h
index 6a9bd92..35d1404 100644
--- a/include/ai/openai.h
+++ b/include/ai/openai.h
@@ -47,6 +47,9 @@ constexpr const char* kChatGpt4oLatest = "chatgpt-4o-latest";
 
 /// Default model used when none is specified
 constexpr const char* kDefaultModel = kGpt4o;
+
+constexpr const char* kCompletions = "/v1/chat/completions";
+constexpr const char* kEmbeddings = "/v1/embeddings";
 } // namespace models
 
 /// Create an OpenAI client with default configuration
diff --git a/include/ai/types/client.h b/include/ai/types/client.h
index dd5084d..51f383d 100644
--- a/include/ai/types/client.h
+++ b/include/ai/types/client.h
@@ -3,6 +3,7 @@
 #include "generate_options.h"
 #include "stream_options.h"
 #include "stream_result.h"
+#include "embedding_options.h"
 
 #include <memory>
 #include <string>
@@ -31,6 +32,12 @@ class Client {
     return GenerateResult("Client not initialized");
   }
 
+  virtual EmbeddingResult embedding(const EmbeddingOptions& options) {
+    if (pimpl_)
+      return pimpl_->embedding(options);
+    return EmbeddingResult("Client not initialized");
+  }
+
   virtual StreamResult stream_text(const StreamOptions& options) {
     if (pimpl_)
       return pimpl_->stream_text(options);
diff --git a/include/ai/types/embedding_options.h b/include/ai/types/embedding_options.h
new file mode 100644
index 0000000..f3fd3eb
--- /dev/null
+++ b/include/ai/types/embedding_options.h
@@ -0,0 +1,84 @@
+#pragma once
+
+#include "enums.h"
+#include "message.h"
+#include "model.h"
+#include "tool.h"
+#include "usage.h"
+
+#include <nlohmann/json.hpp>
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace ai {
+
+struct EmbeddingOptions {
+  std::string model;
+  nlohmann::json input;
+  std::optional<int> dimensions;
+  std::optional<std::string> encoding_format;
+  std::optional<int> max_tokens;
+  std::optional<double> temperature;
+  std::optional<double> top_p;
+  std::optional<int> seed;
+  std::optional<double> frequency_penalty;
+  std::optional<double> presence_penalty;
+
+  EmbeddingOptions(std::string model_name, nlohmann::json input_)
+      : model(std::move(model_name)),
+        input(std::move(input_)) {}
+
+  EmbeddingOptions(std::string model_name, nlohmann::json input_, int dimensions_)
+      : model(std::move(model_name)),
+        input(std::move(input_)),
+        dimensions(dimensions_) {}
+
+  EmbeddingOptions(std::string model_name, nlohmann::json input_, int dimensions_,
+                   std::string encoding_format_)
+      : model(std::move(model_name)),
+        input(std::move(input_)),
+        dimensions(dimensions_),
+        encoding_format(std::move(encoding_format_)) {}
+
+  EmbeddingOptions() = default;
+
+  bool is_valid() const {
+    return !model.empty() && !input.empty();
+  }
+
+  bool has_input() const { return !input.empty(); }
+};
+
+struct EmbeddingResult {
+  nlohmann::json data;
+  Usage usage;
+
+  /// Additional metadata (like TypeScript SDK)
+  std::optional<std::string> model;
+
+  /// Error handling
+  std::optional<std::string> error;
+  std::optional<bool> is_retryable;
+
+  /// Provider-specific metadata
+  std::optional<std::string> provider_metadata;
+
+  EmbeddingResult() = default;
+
+  // EmbeddingResult(std::string data_, Usage token_usage)
+  //     : data(std::move(data_)), usage(token_usage) {}
+
+  explicit EmbeddingResult(std::optional<std::string> error_message)
+      : error(std::move(error_message)) {}
+
+  bool is_success() const {
+    return !error.has_value();
+  }
+
+  explicit operator bool() const { return is_success(); }
+
+  std::string error_message() const { return error.value_or(""); }
+};
+
+} // namespace ai
diff --git a/include/ai/types/generate_options.h b/include/ai/types/generate_options.h
index 8176d4e..832c06e 100644
--- a/include/ai/types/generate_options.h
+++ b/include/ai/types/generate_options.h
@@ -18,6 +18,7 @@ struct GenerateOptions {
   std::string system;
   std::string prompt;
   Messages messages;
+  std::optional<nlohmann::json> response_format {};
   std::optional<int> max_tokens;
   std::optional<double> temperature;
   std::optional<double> top_p;
@@ -46,6 +47,15 @@ struct GenerateOptions {
         system(std::move(system_prompt)),
         prompt(std::move(user_prompt)) {}
 
+  GenerateOptions(std::string model_name,
+                  std::string system_prompt,
+                  std::string user_prompt,
+                  std::optional<nlohmann::json> response_format_)
+      : model(std::move(model_name)),
+        system(std::move(system_prompt)),
+        prompt(std::move(user_prompt)),
+        response_format(std::move(response_format_)) {}
+
   GenerateOptions(std::string model_name, Messages conversation)
       : model(std::move(model_name)),
         messages(std::move(conversation)) {}
diff --git a/src/providers/anthropic/anthropic_client.cpp b/src/providers/anthropic/anthropic_client.cpp
index 7684717..4f2868d 100644
--- a/src/providers/anthropic/anthropic_client.cpp
+++ b/src/providers/anthropic/anthropic_client.cpp
@@ -53,6 +53,11 @@ StreamResult AnthropicClient::stream_text(const StreamOptions& options) {
   return StreamResult(std::move(impl));
 }
 
+EmbeddingResult AnthropicClient::embedding(const EmbeddingOptions& options) {
+  ai::logger::log_error("Embedding not yet implemented in AnthropicClient");
+  return EmbeddingResult();
+}
+
 std::string AnthropicClient::provider_name() const {
   return "anthropic";
 }
diff --git a/src/providers/anthropic/anthropic_client.h b/src/providers/anthropic/anthropic_client.h
index bb9bb08..f869139 100644
--- a/src/providers/anthropic/anthropic_client.h
+++ b/src/providers/anthropic/anthropic_client.h
@@ -17,6 +17,7 @@ class AnthropicClient : public providers::BaseProviderClient {
   // Override only what's specific to Anthropic
   StreamResult stream_text(const StreamOptions& options) override;
+  EmbeddingResult embedding(const EmbeddingOptions& options) override;
   std::string provider_name() const override;
   std::vector<std::string> supported_models() const override;
   bool supports_model(const std::string& model_name) const override;
diff --git a/src/providers/anthropic/anthropic_request_builder.cpp b/src/providers/anthropic/anthropic_request_builder.cpp
index 8a29c45..47f2d96 100644
--- a/src/providers/anthropic/anthropic_request_builder.cpp
+++ 
b/src/providers/anthropic/anthropic_request_builder.cpp @@ -157,6 +157,10 @@ nlohmann::json AnthropicRequestBuilder::build_request_json( return request; } +nlohmann::json AnthropicRequestBuilder::build_request_json(const EmbeddingOptions&) { + return {}; +} + httplib::Headers AnthropicRequestBuilder::build_headers( const providers::ProviderConfig& config) { httplib::Headers headers = { diff --git a/src/providers/anthropic/anthropic_request_builder.h b/src/providers/anthropic/anthropic_request_builder.h index bab9fb1..c7f6387 100644 --- a/src/providers/anthropic/anthropic_request_builder.h +++ b/src/providers/anthropic/anthropic_request_builder.h @@ -11,6 +11,7 @@ namespace anthropic { class AnthropicRequestBuilder : public providers::RequestBuilder { public: nlohmann::json build_request_json(const GenerateOptions& options) override; + nlohmann::json build_request_json(const EmbeddingOptions& options) override; httplib::Headers build_headers( const providers::ProviderConfig& config) override; }; diff --git a/src/providers/anthropic/anthropic_response_parser.cpp b/src/providers/anthropic/anthropic_response_parser.cpp index 1235713..f719e2d 100644 --- a/src/providers/anthropic/anthropic_response_parser.cpp +++ b/src/providers/anthropic/anthropic_response_parser.cpp @@ -6,7 +6,7 @@ namespace ai { namespace anthropic { -GenerateResult AnthropicResponseParser::parse_success_response( +GenerateResult AnthropicResponseParser::parse_success_completion_response( const nlohmann::json& response) { ai::logger::log_debug("Parsing Anthropic messages response"); @@ -86,12 +86,21 @@ GenerateResult AnthropicResponseParser::parse_success_response( return result; } -GenerateResult AnthropicResponseParser::parse_error_response( +GenerateResult AnthropicResponseParser::parse_error_completion_response( int status_code, const std::string& body) { return utils::parse_standard_error_response("Anthropic", status_code, body); } +EmbeddingResult AnthropicResponseParser::parse_success_embedding_response(const nlohmann::json&) { + return {}; +} + +EmbeddingResult AnthropicResponseParser::parse_error_embedding_response(int, const std::string&) { + return {}; +} + + FinishReason AnthropicResponseParser::parse_stop_reason( const std::string& reason) { if (reason == "end_turn") { diff --git a/src/providers/anthropic/anthropic_response_parser.h b/src/providers/anthropic/anthropic_response_parser.h index 822cd70..e4de998 100644 --- a/src/providers/anthropic/anthropic_response_parser.h +++ b/src/providers/anthropic/anthropic_response_parser.h @@ -10,9 +10,13 @@ namespace anthropic { class AnthropicResponseParser : public providers::ResponseParser { public: - GenerateResult parse_success_response( + GenerateResult parse_success_completion_response( const nlohmann::json& response) override; - GenerateResult parse_error_response(int status_code, + GenerateResult parse_error_completion_response(int status_code, + const std::string& body) override; + EmbeddingResult parse_success_embedding_response( + const nlohmann::json& response) override; + EmbeddingResult parse_error_embedding_response(int status_code, const std::string& body) override; private: diff --git a/src/providers/base_provider_client.cpp b/src/providers/base_provider_client.cpp index e734fae..927811c 100644 --- a/src/providers/base_provider_client.cpp +++ b/src/providers/base_provider_client.cpp @@ -71,7 +71,7 @@ GenerateResult BaseProviderClient::generate_text_single_step( // Parse error response using provider-specific parser if 
(result.provider_metadata.has_value()) {
       int status_code = std::stoi(result.provider_metadata.value());
-      return response_parser_->parse_error_response(
+      return response_parser_->parse_error_completion_response(
           status_code, result.error.value_or(""));
     }
     return result;
@@ -94,7 +94,7 @@ GenerateResult BaseProviderClient::generate_text_single_step(
 
     // Parse using provider-specific parser
     auto parsed_result =
-        response_parser_->parse_success_response(json_response);
+        response_parser_->parse_success_completion_response(json_response);
 
     if (parsed_result.has_tool_calls()) {
       ai::logger::log_debug("Model made {} tool calls",
@@ -144,5 +144,12 @@ StreamResult BaseProviderClient::stream_text(const StreamOptions& options) {
   return StreamResult();
 }
 
+EmbeddingResult BaseProviderClient::embedding(const EmbeddingOptions& options) {
+  // This needs to be implemented with provider-specific embedding implementations
+  // For now, return an error
+  ai::logger::log_error("Embedding not yet implemented in BaseProviderClient");
+  return EmbeddingResult();
+}
+
 } // namespace providers
 } // namespace ai
\ No newline at end of file
diff --git a/src/providers/base_provider_client.h b/src/providers/base_provider_client.h
index 34fc3cd..dd99073 100644
--- a/src/providers/base_provider_client.h
+++ b/src/providers/base_provider_client.h
@@ -32,6 +32,7 @@ class RequestBuilder {
  public:
   virtual ~RequestBuilder() = default;
   virtual nlohmann::json build_request_json(const GenerateOptions& options) = 0;
+  virtual nlohmann::json build_request_json(const EmbeddingOptions& options) = 0;
   virtual httplib::Headers build_headers(const ProviderConfig& config) = 0;
 };
 
@@ -39,9 +40,13 @@ class RequestBuilder {
 class ResponseParser {
  public:
   virtual ~ResponseParser() = default;
-  virtual GenerateResult parse_success_response(
+  virtual GenerateResult parse_success_completion_response(
       const nlohmann::json& response) = 0;
-  virtual GenerateResult parse_error_response(int status_code,
+  virtual GenerateResult parse_error_completion_response(int status_code,
+                                              const std::string& body) = 0;
+  virtual EmbeddingResult parse_success_embedding_response(
+      const nlohmann::json& response) = 0;
+  virtual EmbeddingResult parse_error_embedding_response(int status_code,
                                               const std::string& body) = 0;
 };
 
@@ -55,6 +60,7 @@ class BaseProviderClient : public Client {
   // Implements the common flow using the composed components
   GenerateResult generate_text(const GenerateOptions& options) override;
   StreamResult stream_text(const StreamOptions& options) override;
+  EmbeddingResult embedding(const EmbeddingOptions& options) override;
 
   bool is_valid() const override { return !config_.api_key.empty(); }
diff --git a/src/providers/openai/openai_client.cpp b/src/providers/openai/openai_client.cpp
index 28a4b9b..3469920 100644
--- a/src/providers/openai/openai_client.cpp
+++ b/src/providers/openai/openai_client.cpp
@@ -70,6 +70,55 @@ StreamResult OpenAIClient::stream_text(const StreamOptions& options) {
   return StreamResult(std::move(impl));
 }
 
+EmbeddingResult OpenAIClient::embedding(const EmbeddingOptions& options) {
+  try {
+    // Build request JSON using the provider-specific builder
+    auto request_json = request_builder_->build_request_json(options);
+    std::string json_body = request_json.dump();
+    ai::logger::log_debug("Request JSON built: {}", json_body);
+
+    // Build headers
+    auto headers = request_builder_->build_headers(config_);
+
+    // Make the request
+    auto result =
+        http_handler_->post(models::kEmbeddings, headers, json_body);
+
+    if (!result.is_success()) {
+      // Parse error response using provider-specific parser
+      if (result.provider_metadata.has_value()) {
+        int status_code = std::stoi(result.provider_metadata.value());
+        return response_parser_->parse_error_embedding_response(
+            status_code, result.error.value_or(""));
+      }
+      return EmbeddingResult(result.error);
+    }
+
+    // Parse the response JSON from result.text
+    nlohmann::json json_response;
+    try {
+      json_response = nlohmann::json::parse(result.text);
+    } catch (const nlohmann::json::exception& e) {
+      ai::logger::log_error("Failed to parse response JSON: {}", e.what());
+      ai::logger::log_debug("Raw response text: {}", result.text);
+      return EmbeddingResult("Failed to parse response: " +
+                             std::string(e.what()));
+    }
+
+    ai::logger::log_info(
+        "Text generation successful - model: {}, response_id: {}",
+        options.model, json_response.value("id", "unknown"));
+
+    // Parse using provider-specific parser
+    auto parsed_result =
+        response_parser_->parse_success_embedding_response(json_response);
+    return parsed_result;
+
+  } catch (const std::exception& e) {
+    ai::logger::log_error("Exception during text generation: {}", e.what());
+    return EmbeddingResult(std::string("Exception: ") + e.what());
+  }
+}
+
 std::string OpenAIClient::provider_name() const {
   return "openai";
 }
diff --git a/src/providers/openai/openai_client.h b/src/providers/openai/openai_client.h
index e3ea93e..3f59cc2 100644
--- a/src/providers/openai/openai_client.h
+++ b/src/providers/openai/openai_client.h
@@ -21,6 +21,7 @@ class OpenAIClient : public providers::BaseProviderClient {
   // Override only what's specific to OpenAI
   StreamResult stream_text(const StreamOptions& options) override;
+  EmbeddingResult embedding(const EmbeddingOptions& options) override;
   std::string provider_name() const override;
   std::vector<std::string> supported_models() const override;
   bool supports_model(const std::string& model_name) const override;
diff --git a/src/providers/openai/openai_request_builder.cpp b/src/providers/openai/openai_request_builder.cpp
index f2ed8be..8afac44 100644
--- a/src/providers/openai/openai_request_builder.cpp
+++ b/src/providers/openai/openai_request_builder.cpp
@@ -11,6 +11,8 @@ nlohmann::json OpenAIRequestBuilder::build_request_json(
   nlohmann::json request{{"model", options.model},
                          {"messages", nlohmann::json::array()}};
 
+  if (options.response_format)
+    request["response_format"] = options.response_format.value();
   // Build messages array
   if (!options.messages.empty()) {
     // Use provided messages
@@ -164,6 +166,46 @@ nlohmann::json OpenAIRequestBuilder::build_request_json(
   return request;
 }
 
+nlohmann::json OpenAIRequestBuilder::build_request_json(
+    const EmbeddingOptions& options) {
+  nlohmann::json request{{"model", options.model},
+                         {"input", options.input}};
+
+  if (options.encoding_format) {
+    request["encoding_format"] = options.encoding_format.value();
+  }
+
+  if (options.dimensions && options.dimensions.value()) {
+    request["dimensions"] = options.dimensions.value();
+  }
+
+  // Add optional parameters
+  if (options.temperature) {
+    request["temperature"] = *options.temperature;
+  }
+
+  if (options.max_tokens) {
+    request["max_completion_tokens"] = *options.max_tokens;
+  }
+
+  if (options.top_p) {
+    request["top_p"] = *options.top_p;
+  }
+
+  if (options.frequency_penalty) {
+    request["frequency_penalty"] = *options.frequency_penalty;
+  }
+
+  if (options.presence_penalty) {
+    request["presence_penalty"] = *options.presence_penalty;
+  }
+
+  if (options.seed) {
+    request["seed"] = *options.seed;
+  }
+
+  return request;
+}
+
 httplib::Headers OpenAIRequestBuilder::build_headers(
     const providers::ProviderConfig& config) {
   httplib::Headers headers = {
diff --git a/src/providers/openai/openai_request_builder.h b/src/providers/openai/openai_request_builder.h
index 5317769..3d529c3 100644
--- a/src/providers/openai/openai_request_builder.h
+++ b/src/providers/openai/openai_request_builder.h
@@ -11,6 +11,7 @@ class OpenAIRequestBuilder : public providers::RequestBuilder {
  public:
   nlohmann::json build_request_json(const GenerateOptions& options) override;
+  nlohmann::json build_request_json(const EmbeddingOptions& options) override;
   httplib::Headers build_headers(
       const providers::ProviderConfig& config) override;
 };
diff --git a/src/providers/openai/openai_response_parser.cpp b/src/providers/openai/openai_response_parser.cpp
index 90c520a..1bff129 100644
--- a/src/providers/openai/openai_response_parser.cpp
+++ b/src/providers/openai/openai_response_parser.cpp
@@ -6,7 +6,7 @@
 namespace ai {
 namespace openai {
 
-GenerateResult OpenAIResponseParser::parse_success_response(
+GenerateResult OpenAIResponseParser::parse_success_completion_response(
     const nlohmann::json& response) {
   ai::logger::log_debug("Parsing OpenAI chat completion response");
 
@@ -128,12 +128,48 @@ GenerateResult OpenAIResponseParser::parse_success_response(
   return result;
 }
 
-GenerateResult OpenAIResponseParser::parse_error_response(
+GenerateResult OpenAIResponseParser::parse_error_completion_response(
     int status_code, const std::string& body) {
   return utils::parse_standard_error_response("OpenAI", status_code, body);
 }
 
+EmbeddingResult OpenAIResponseParser::parse_success_embedding_response(const nlohmann::json& response) {
+  ai::logger::log_debug("Parsing OpenAI embedding response");
+
+  EmbeddingResult result;
+
+  // Extract basic fields
+  result.model = response.value("model", "");
+
+  // Extract embedding data
+  if (response.contains("data") && !response["data"].empty()) {
+    result.data = response["data"];
+  }
+
+  // Extract usage
+  if (response.contains("usage")) {
+    auto& usage = response["usage"];
+    result.usage.prompt_tokens = usage.value("prompt_tokens", 0);
+    result.usage.completion_tokens = usage.value("completion_tokens", 0);
+    result.usage.total_tokens = usage.value("total_tokens", 0);
+    ai::logger::log_debug("Token usage - prompt: {}, completion: {}, total: {}",
+                          result.usage.prompt_tokens,
+                          result.usage.completion_tokens,
+                          result.usage.total_tokens);
+  }
+
+  // Store full metadata
+  result.provider_metadata = response.dump();
+
+  return result;
+}
+
+EmbeddingResult OpenAIResponseParser::parse_error_embedding_response(int status_code, const std::string& body) {
+  auto generate_result = utils::parse_standard_error_response("OpenAI", status_code, body);
+  return EmbeddingResult(generate_result.error);
+}
+
 FinishReason OpenAIResponseParser::parse_finish_reason(
     const std::string& reason) {
   if (reason == "stop") {
diff --git a/src/providers/openai/openai_response_parser.h b/src/providers/openai/openai_response_parser.h
index 14d62d9..6c2e564 100644
--- a/src/providers/openai/openai_response_parser.h
+++ b/src/providers/openai/openai_response_parser.h
@@ -10,9 +10,13 @@ namespace openai {
 class OpenAIResponseParser : public providers::ResponseParser {
  public:
-  GenerateResult parse_success_response(
+  GenerateResult parse_success_completion_response(
       const nlohmann::json& response) override;
-  GenerateResult parse_error_response(int status_code,
+  GenerateResult parse_error_completion_response(int status_code,
+                                     const std::string& body) override;
+  EmbeddingResult parse_success_embedding_response(
+      const nlohmann::json& response) override;
+  EmbeddingResult parse_error_embedding_response(int status_code,
                                      const std::string& body) override;
 
  private:
diff --git a/src/types/embedding_options.cpp b/src/types/embedding_options.cpp
new file mode 100644
index 0000000..19fde7e
--- /dev/null
+++ b/src/types/embedding_options.cpp
@@ -0,0 +1,7 @@
+#include "ai/types/embedding_options.h"
+
+namespace ai {
+
+// Implementation details for EmbeddingOptions if needed
+
+} // namespace ai
\ No newline at end of file

From 65a2603daed83e88850ae38362f7d10a954fdeb2 Mon Sep 17 00:00:00 2001
From: fastio
Date: Tue, 19 Aug 2025 10:21:46 +0800
Subject: [PATCH 2/4] fix log

---
 src/providers/openai/openai_client.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/providers/openai/openai_client.cpp b/src/providers/openai/openai_client.cpp
index 3469920..38cf742 100644
--- a/src/providers/openai/openai_client.cpp
+++ b/src/providers/openai/openai_client.cpp
@@ -106,7 +106,7 @@ EmbeddingResult OpenAIClient::embedding(const EmbeddingOptions& options) {
     }
 
     ai::logger::log_info(
-        "Text generation successful - model: {}, response_id: {}",
+        "Embedding successful - model: {}, response_id: {}",
         options.model, json_response.value("id", "unknown"));
 
     // Parse using provider-specific parser
@@ -115,7 +115,7 @@ EmbeddingResult OpenAIClient::embedding(const EmbeddingOptions& options) {
     return parsed_result;
 
   } catch (const std::exception& e) {
-    ai::logger::log_error("Exception during text generation: {}", e.what());
+    ai::logger::log_error("Exception during embedding: {}", e.what());
     return EmbeddingResult(std::string("Exception: ") + e.what());
   }
 }
@@ -143,4 +143,4 @@ std::string OpenAIClient::default_model() const {
 }
 
 } // namespace openai
-} // namespace ai
\ No newline at end of file
+} // namespace ai

From 880771da1b4871da111dae2ad2ce8a22dfb6f114 Mon Sep 17 00:00:00 2001
From: fastio
Date: Tue, 19 Aug 2025 16:24:21 +0800
Subject: [PATCH 3/4] support embeddings

---
 include/ai/types/client.h                     |  4 +-
 src/providers/anthropic/anthropic_client.cpp  | 12 ++--
 src/providers/anthropic/anthropic_client.h    |  2 +-
 .../anthropic/anthropic_request_builder.cpp   |  8 ++-
 .../anthropic/anthropic_response_parser.cpp   | 36 +++++++++--
 src/providers/base_provider_client.cpp        | 61 ++++++++++++++++---
 src/providers/base_provider_client.h          |  5 +-
 src/providers/openai/openai_client.cpp        | 19 +++---
 src/providers/openai/openai_client.h          |  2 +-
 .../openai/openai_response_parser.cpp         |  2 +-
 10 files changed, 117 insertions(+), 34 deletions(-)

diff --git a/include/ai/types/client.h b/include/ai/types/client.h
index 51f383d..c49c570 100644
--- a/include/ai/types/client.h
+++ b/include/ai/types/client.h
@@ -32,9 +32,9 @@ class Client {
     return GenerateResult("Client not initialized");
   }
 
-  virtual EmbeddingResult embedding(const EmbeddingOptions& options) {
+  virtual EmbeddingResult embeddings(const EmbeddingOptions& options) {
     if (pimpl_)
-      return pimpl_->embedding(options);
+      return pimpl_->embeddings(options);
     return EmbeddingResult("Client not initialized");
   }
 
diff --git a/src/providers/anthropic/anthropic_client.cpp b/src/providers/anthropic/anthropic_client.cpp
index 4f2868d..8f6a8f4 100644
--- a/src/providers/anthropic/anthropic_client.cpp
+++ b/src/providers/anthropic/anthropic_client.cpp
@@ -18,7 +18,8 @@ AnthropicClient::AnthropicClient(const std::string& api_key,
           providers::ProviderConfig{
               .api_key = api_key,
               .base_url = base_url,
-              .endpoint_path = "/v1/messages",
+              .completions_endpoint_path = "/v1/messages",
+              .embeddings_endpoint_path = "/v1/embeddings",
               .auth_header_name = "x-api-key",
               .auth_header_prefix = "",
               .extra_headers = {{"anthropic-version", "2023-06-01"}}},
@@ -44,19 +45,20 @@ StreamResult AnthropicClient::stream_text(const StreamOptions& options) {
   // Create stream implementation
   auto impl = std::make_unique();
 
-  impl->start_stream(config_.base_url + config_.endpoint_path, headers,
-                     request_json);
+  impl->start_stream(config_.base_url + config_.completions_endpoint_path,
+                     headers, request_json);
 
   ai::logger::log_info("Text streaming started - model: {}", options.model);
 
   // Return StreamResult with implementation
   return StreamResult(std::move(impl));
 }
-
-EmbeddingResult AnthropicClient::embedding(const EmbeddingOptions& options) {
+#if 0
+EmbeddingResult AnthropicClient::embeddings(const EmbeddingOptions& options) {
   ai::logger::log_error("Embedding not yet implemented in AnthropicClient");
   return EmbeddingResult();
 }
+#endif
 
 std::string AnthropicClient::provider_name() const {
   return "anthropic";
diff --git a/src/providers/anthropic/anthropic_client.h b/src/providers/anthropic/anthropic_client.h
index f869139..5b83bad 100644
--- a/src/providers/anthropic/anthropic_client.h
+++ b/src/providers/anthropic/anthropic_client.h
@@ -17,7 +17,7 @@ class AnthropicClient : public providers::BaseProviderClient {
   // Override only what's specific to Anthropic
   StreamResult stream_text(const StreamOptions& options) override;
-  EmbeddingResult embedding(const EmbeddingOptions& options) override;
+  //EmbeddingResult embeddings(const EmbeddingOptions& options) override;
   std::string provider_name() const override;
   std::vector<std::string> supported_models() const override;
   bool supports_model(const std::string& model_name) const override;
diff --git a/src/providers/anthropic/anthropic_request_builder.cpp b/src/providers/anthropic/anthropic_request_builder.cpp
index 47f2d96..231bebd 100644
--- a/src/providers/anthropic/anthropic_request_builder.cpp
+++ b/src/providers/anthropic/anthropic_request_builder.cpp
@@ -13,6 +13,8 @@ nlohmann::json AnthropicRequestBuilder::build_request_json(
   request["max_tokens"] = options.max_tokens.value_or(4096);
   request["messages"] = nlohmann::json::array();
 
+  if (options.response_format)
+    request["response_format"] = options.response_format.value();
   // Handle system message
   if (!options.system.empty()) {
     request["system"] = options.system;
@@ -157,8 +159,10 @@ nlohmann::json AnthropicRequestBuilder::build_request_json(
   return request;
 }
 
-nlohmann::json AnthropicRequestBuilder::build_request_json(const EmbeddingOptions&) {
-  return {};
+nlohmann::json AnthropicRequestBuilder::build_request_json(const EmbeddingOptions& options) {
+  nlohmann::json request{{"model", options.model},
+                         {"input", options.input}};
+  return request;
 }
 
 httplib::Headers AnthropicRequestBuilder::build_headers(
diff --git a/src/providers/anthropic/anthropic_response_parser.cpp b/src/providers/anthropic/anthropic_response_parser.cpp
index f719e2d..0c2e7f4 100644
--- a/src/providers/anthropic/anthropic_response_parser.cpp
+++ b/src/providers/anthropic/anthropic_response_parser.cpp
@@ -92,12 +92,40 @@ GenerateResult AnthropicResponseParser::parse_error_completion_response(
   return utils::parse_standard_error_response("Anthropic", status_code, body);
 }
 
-EmbeddingResult AnthropicResponseParser::parse_success_embedding_response(const nlohmann::json&) {
-  return {};
+EmbeddingResult AnthropicResponseParser::parse_success_embedding_response(const nlohmann::json& response) {
+  ai::logger::log_debug("Parsing Anthropic embeddings response");
+
+  EmbeddingResult result;
+
+  // Extract basic fields
+  result.model = response.value("model", "");
+
+  // Extract embedding data
+  if (response.contains("data") && !response["data"].empty()) {
+    result.data = response["data"];
+  }
+
+  // Extract usage
+  if (response.contains("usage")) {
+    auto& usage = response["usage"];
+    result.usage.prompt_tokens = usage.value("prompt_tokens", 0);
+    result.usage.completion_tokens = usage.value("completion_tokens", 0);
+    result.usage.total_tokens = usage.value("total_tokens", 0);
+    ai::logger::log_debug("Token usage - prompt: {}, completion: {}, total: {}",
+                          result.usage.prompt_tokens,
+                          result.usage.completion_tokens,
+                          result.usage.total_tokens);
+  }
+
+  // Store full metadata
+  result.provider_metadata = response.dump();
+
+  return result;
 }
 
-EmbeddingResult AnthropicResponseParser::parse_error_embedding_response(int, const std::string&) {
-  return {};
+EmbeddingResult AnthropicResponseParser::parse_error_embedding_response(int status_code, const std::string& body) {
+  auto generate_result = utils::parse_standard_error_response("Anthropic", status_code, body);
+  return EmbeddingResult(generate_result.error);
 }
 
diff --git a/src/providers/base_provider_client.cpp b/src/providers/base_provider_client.cpp
index 927811c..5e47c2d 100644
--- a/src/providers/base_provider_client.cpp
+++ b/src/providers/base_provider_client.cpp
@@ -24,8 +24,10 @@ BaseProviderClient::BaseProviderClient(
   http_handler_ = std::make_unique(http_config);
 
   ai::logger::log_debug(
-      "BaseProviderClient initialized - base_url: {}, endpoint: {}",
-      config.base_url, config.endpoint_path);
+      R"(BaseProviderClient initialized - base_url: {},
+      completions_endpoint: {}, embeddings_endpoint: {})",
+      config.base_url, config.completions_endpoint_path,
+      config.embeddings_endpoint_path);
 }
 
 GenerateResult BaseProviderClient::generate_text(
@@ -65,7 +67,7 @@ GenerateResult BaseProviderClient::generate_text_single_step(
 
     // Make the request
     auto result =
-        http_handler_->post(config_.endpoint_path, headers, json_body);
+        http_handler_->post(config_.completions_endpoint_path, headers, json_body);
 
     if (!result.is_success()) {
       // Parse error response using provider-specific parser
@@ -144,11 +146,54 @@ StreamResult BaseProviderClient::stream_text(const StreamOptions& options) {
   return StreamResult();
 }
 
-EmbeddingResult BaseProviderClient::embedding(const EmbeddingOptions& options) {
-  // This needs to be implemented with provider-specific embedding implementations
-  // For now, return an error
-  ai::logger::log_error("Embedding not yet implemented in BaseProviderClient");
-  return EmbeddingResult();
+EmbeddingResult BaseProviderClient::embeddings(const EmbeddingOptions& options) {
+  try {
+    // Build request JSON using the provider-specific builder
+    auto request_json = request_builder_->build_request_json(options);
+    std::string json_body = request_json.dump();
+    ai::logger::log_debug("Request JSON built: {}", json_body);
+
+    // Build headers
+    auto headers = request_builder_->build_headers(config_);
+
+    // Make the request
+    auto result =
+        http_handler_->post(config_.embeddings_endpoint_path, headers, json_body);
+
+    if (!result.is_success()) {
+      // Parse error response using provider-specific parser
+      if (result.provider_metadata.has_value()) {
+        int status_code = std::stoi(result.provider_metadata.value());
+        return 
response_parser_->parse_error_embedding_response( + status_code, result.error.value_or("")); + } + return EmbeddingResult(result.error); + } + + // Parse the response JSON from result.text + nlohmann::json json_response; + try { + json_response = nlohmann::json::parse(result.text); + } catch (const nlohmann::json::exception& e) { + ai::logger::log_error("Failed to parse response JSON: {}", e.what()); + ai::logger::log_debug("Raw response text: {}", result.text); + return EmbeddingResult("Failed to parse response: " + + std::string(e.what())); + } + + ai::logger::log_info( + "Embeddings successful - model: {}, response_id: {}", + options.model, json_response.value("id", "unknown")); + + // Parse using provider-specific parser + auto parsed_result = + response_parser_->parse_success_embedding_response(json_response); + return parsed_result; + + } catch (const std::exception& e) { + ai::logger::log_error("Exception during embeddings: {}", e.what()); + return EmbeddingResult(std::string("Exception: ") + e.what()); + } } } // namespace providers diff --git a/src/providers/base_provider_client.h b/src/providers/base_provider_client.h index dd99073..6ca19a0 100644 --- a/src/providers/base_provider_client.h +++ b/src/providers/base_provider_client.h @@ -18,7 +18,8 @@ namespace providers { struct ProviderConfig { std::string api_key; std::string base_url; - std::string endpoint_path; // e.g., "/v1/chat/completions" or "/v1/messages" + std::string completions_endpoint_path; // e.g. "/v1/chat/completions" + std::string embeddings_endpoint_path; std::string auth_header_name; // e.g., "Authorization" or "x-api-key" std::string auth_header_prefix; // e.g., "Bearer " or "" httplib::Headers extra_headers; // Additional headers like anthropic-version @@ -60,7 +61,7 @@ class BaseProviderClient : public Client { // Implements the common flow using the composed components GenerateResult generate_text(const GenerateOptions& options) override; StreamResult stream_text(const StreamOptions& options) override; - EmbeddingResult embedding(const EmbeddingOptions& options) override; + EmbeddingResult embeddings(const EmbeddingOptions& options) override; bool is_valid() const override { return !config_.api_key.empty(); } diff --git a/src/providers/openai/openai_client.cpp b/src/providers/openai/openai_client.cpp index 38cf742..9896d83 100644 --- a/src/providers/openai/openai_client.cpp +++ b/src/providers/openai/openai_client.cpp @@ -17,7 +17,8 @@ OpenAIClient::OpenAIClient(const std::string& api_key, : BaseProviderClient( providers::ProviderConfig{.api_key = api_key, .base_url = base_url, - .endpoint_path = "/v1/chat/completions", + .completions_endpoint_path = "/v1/chat/completions", + .embeddings_endpoint_path = "/v1/embeddings", .auth_header_name = "Authorization", .auth_header_prefix = "Bearer ", .extra_headers = {}}, @@ -33,7 +34,8 @@ OpenAIClient::OpenAIClient(const std::string& api_key, : BaseProviderClient( providers::ProviderConfig{.api_key = api_key, .base_url = base_url, - .endpoint_path = "/v1/chat/completions", + .completions_endpoint_path = "/v1/chat/completions", + .embeddings_endpoint_path = "/v1/embeddings", .auth_header_name = "Authorization", .auth_header_prefix = "Bearer ", .extra_headers = {}, @@ -61,7 +63,7 @@ StreamResult OpenAIClient::stream_text(const StreamOptions& options) { // Create stream implementation auto impl = std::make_unique(); - impl->start_stream(config_.base_url + config_.endpoint_path, headers, + impl->start_stream(config_.base_url + config_.completions_endpoint_path, headers, 
request_json);
 
   ai::logger::log_info("Text streaming started - model: {}", options.model);
 
@@ -69,8 +71,8 @@ StreamResult OpenAIClient::stream_text(const StreamOptions& options) {
   // Return StreamResult with implementation
   return StreamResult(std::move(impl));
 }
-
-EmbeddingResult OpenAIClient::embedding(const EmbeddingOptions& options) {
+#if 0
+EmbeddingResult OpenAIClient::embeddings(const EmbeddingOptions& options) {
   try {
     // Build request JSON using the provider-specific builder
     auto request_json = request_builder_->build_request_json(options);
@@ -106,7 +108,7 @@ EmbeddingResult OpenAIClient::embedding(const EmbeddingOptions& options) {
     }
 
     ai::logger::log_info(
-        "Embedding successful - model: {}, response_id: {}",
+        "Text generation successful - model: {}, response_id: {}",
         options.model, json_response.value("id", "unknown"));
 
     // Parse using provider-specific parser
@@ -115,7 +117,7 @@ EmbeddingResult OpenAIClient::embedding(const EmbeddingOptions& options) {
     return parsed_result;
 
   } catch (const std::exception& e) {
-    ai::logger::log_error("Exception during embedding: {}", e.what());
+    ai::logger::log_error("Exception during text generation: {}", e.what());
     return EmbeddingResult(std::string("Exception: ") + e.what());
   }
 }
+#endif
 
 std::string OpenAIClient::provider_name() const {
   return "openai";
 }
@@ -143,4 +146,4 @@ std::string OpenAIClient::default_model() const {
 }
 
 } // namespace openai
-} // namespace ai
+} // namespace ai
\ No newline at end of file
diff --git a/src/providers/openai/openai_client.h b/src/providers/openai/openai_client.h
index 3f59cc2..361e00f 100644
--- a/src/providers/openai/openai_client.h
+++ b/src/providers/openai/openai_client.h
@@ -21,7 +21,7 @@ class OpenAIClient : public providers::BaseProviderClient {
   // Override only what's specific to OpenAI
   StreamResult stream_text(const StreamOptions& options) override;
-  EmbeddingResult embedding(const EmbeddingOptions& options) override;
+  //EmbeddingResult embeddings(const EmbeddingOptions& options) override;
   std::string provider_name() const override;
   std::vector<std::string> supported_models() const override;
   bool supports_model(const std::string& model_name) const override;
diff --git a/src/providers/openai/openai_response_parser.cpp b/src/providers/openai/openai_response_parser.cpp
index 1bff129..b68e3d8 100644
--- a/src/providers/openai/openai_response_parser.cpp
+++ b/src/providers/openai/openai_response_parser.cpp
@@ -135,7 +135,7 @@ GenerateResult OpenAIResponseParser::parse_error_completion_response(
 }
 
 EmbeddingResult OpenAIResponseParser::parse_success_embedding_response(const nlohmann::json& response) {
-  ai::logger::log_debug("Parsing OpenAI embedding response");
+  ai::logger::log_debug("Parsing OpenAI embeddings response");
 
   EmbeddingResult result;
 

From 8e4926bff593a3e70015cc9d9689ebd4e366b565 Mon Sep 17 00:00:00 2001
From: fastio
Date: Tue, 19 Aug 2025 16:29:32 +0800
Subject: [PATCH 4/4] fix style

---
 include/ai/openai.h                           |  4 +-
 src/providers/anthropic/anthropic_client.cpp  |  8 +--
 src/providers/anthropic/anthropic_client.h    |  3 +-
 .../anthropic/anthropic_request_builder.cpp   |  4 +-
 src/providers/openai/openai_client.cpp        | 54 +------------------
 src/providers/openai/openai_client.h          |  3 +-
 6 files changed, 7 insertions(+), 69 deletions(-)

diff --git a/include/ai/openai.h b/include/ai/openai.h
index 35d1404..1277e5d 100644
--- a/include/ai/openai.h
+++ b/include/ai/openai.h
@@ -48,8 +48,6 @@ constexpr const char* kChatGpt4oLatest = "chatgpt-4o-latest";
 /// Default model used when none is specified
 constexpr const char* kDefaultModel = kGpt4o;
 
-constexpr const char* kCompletions = "/v1/chat/completions";
-constexpr const char* kEmbeddings = "/v1/embeddings";
 } // namespace models
 
 /// Create an OpenAI client with default configuration
@@ -85,4 +83,4 @@ Client create_client(const std::string& api_key,
 std::optional<Client> try_create_client();
 
 } // namespace openai
-} // namespace ai
\ No newline at end of file
+} // namespace ai
diff --git a/src/providers/anthropic/anthropic_client.cpp b/src/providers/anthropic/anthropic_client.cpp
index 8f6a8f4..21b8f44 100644
--- a/src/providers/anthropic/anthropic_client.cpp
+++ b/src/providers/anthropic/anthropic_client.cpp
@@ -53,12 +53,6 @@ StreamResult AnthropicClient::stream_text(const StreamOptions& options) {
   // Return StreamResult with implementation
   return StreamResult(std::move(impl));
 }
-#if 0
-EmbeddingResult AnthropicClient::embeddings(const EmbeddingOptions& options) {
-  ai::logger::log_error("Embedding not yet implemented in AnthropicClient");
-  return EmbeddingResult();
-}
-#endif
 
 std::string AnthropicClient::provider_name() const {
   return "anthropic";
@@ -84,4 +78,4 @@ std::string AnthropicClient::default_model() const {
 }
 
 } // namespace anthropic
-} // namespace ai
\ No newline at end of file
+} // namespace ai
diff --git a/src/providers/anthropic/anthropic_client.h b/src/providers/anthropic/anthropic_client.h
index 5b83bad..f66e432 100644
--- a/src/providers/anthropic/anthropic_client.h
+++ b/src/providers/anthropic/anthropic_client.h
@@ -17,7 +17,6 @@ class AnthropicClient : public providers::BaseProviderClient {
   // Override only what's specific to Anthropic
   StreamResult stream_text(const StreamOptions& options) override;
-  //EmbeddingResult embeddings(const EmbeddingOptions& options) override;
   std::string provider_name() const override;
   std::vector<std::string> supported_models() const override;
   bool supports_model(const std::string& model_name) const override;
@@ -30,4 +29,4 @@ class AnthropicClient : public providers::BaseProviderClient {
 };
 
 } // namespace anthropic
-} // namespace ai
\ No newline at end of file
+} // namespace ai
diff --git a/src/providers/anthropic/anthropic_request_builder.cpp b/src/providers/anthropic/anthropic_request_builder.cpp
index 231bebd..a15df27 100644
--- a/src/providers/anthropic/anthropic_request_builder.cpp
+++ b/src/providers/anthropic/anthropic_request_builder.cpp
@@ -13,8 +13,6 @@ nlohmann::json AnthropicRequestBuilder::build_request_json(
   request["max_tokens"] = options.max_tokens.value_or(4096);
   request["messages"] = nlohmann::json::array();
 
-  if (options.response_format)
-    request["response_format"] = options.response_format.value();
   // Handle system message
   if (!options.system.empty()) {
     request["system"] = options.system;
@@ -180,4 +178,4 @@ httplib::Headers AnthropicRequestBuilder::build_headers(
 }
 
 } // namespace anthropic
-} // namespace ai
\ No newline at end of file
+} // namespace ai
diff --git a/src/providers/openai/openai_client.cpp b/src/providers/openai/openai_client.cpp
index 9896d83..6876dfb 100644
--- a/src/providers/openai/openai_client.cpp
+++ b/src/providers/openai/openai_client.cpp
@@ -71,57 +71,7 @@ StreamResult OpenAIClient::stream_text(const StreamOptions& options) {
   // Return StreamResult with implementation
   return StreamResult(std::move(impl));
 }
-#if 0
-EmbeddingResult OpenAIClient::embeddings(const EmbeddingOptions& options) {
-  try {
-    // Build request JSON using the provider-specific builder
-    auto request_json = request_builder_->build_request_json(options);
-    std::string json_body = request_json.dump();
-    ai::logger::log_debug("Request JSON built: {}", json_body);
-
-    // Build headers
-    auto headers = request_builder_->build_headers(config_);
-
-    // Make the request
-    auto result =
-        http_handler_->post(models::kEmbeddings, headers, json_body);
-
-    if (!result.is_success()) {
-      // Parse error response using provider-specific parser
-      if (result.provider_metadata.has_value()) {
-        int status_code = std::stoi(result.provider_metadata.value());
-        return response_parser_->parse_error_embedding_response(
-            status_code, result.error.value_or(""));
-      }
-      return EmbeddingResult(result.error);
-    }
-
-    // Parse the response JSON from result.text
-    nlohmann::json json_response;
-    try {
-      json_response = nlohmann::json::parse(result.text);
-    } catch (const nlohmann::json::exception& e) {
-      ai::logger::log_error("Failed to parse response JSON: {}", e.what());
-      ai::logger::log_debug("Raw response text: {}", result.text);
-      return EmbeddingResult("Failed to parse response: " +
-                             std::string(e.what()));
-    }
-
-    ai::logger::log_info(
-        "Text generation successful - model: {}, response_id: {}",
-        options.model, json_response.value("id", "unknown"));
-
-    // Parse using provider-specific parser
-    auto parsed_result =
-        response_parser_->parse_success_embedding_response(json_response);
-    return parsed_result;
-
-  } catch (const std::exception& e) {
-    ai::logger::log_error("Exception during text generation: {}", e.what());
-    return EmbeddingResult(std::string("Exception: ") + e.what());
-  }
-}
-#endif
+
 std::string OpenAIClient::provider_name() const {
   return "openai";
 }
@@ -146,4 +96,4 @@ std::string OpenAIClient::default_model() const {
 }
 
 } // namespace openai
-} // namespace ai
\ No newline at end of file
+} // namespace ai
diff --git a/src/providers/openai/openai_client.h b/src/providers/openai/openai_client.h
index 361e00f..12df8bc 100644
--- a/src/providers/openai/openai_client.h
+++ b/src/providers/openai/openai_client.h
@@ -21,7 +21,6 @@ class OpenAIClient : public providers::BaseProviderClient {
   // Override only what's specific to OpenAI
   StreamResult stream_text(const StreamOptions& options) override;
-  //EmbeddingResult embeddings(const EmbeddingOptions& options) override;
   std::string provider_name() const override;
   std::vector<std::string> supported_models() const override;
   bool supports_model(const std::string& model_name) const override;
@@ -34,4 +33,4 @@ class OpenAIClient : public providers::BaseProviderClient {
 };
 
 } // namespace openai
-} // namespace ai
\ No newline at end of file
+} // namespace ai
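
Usage sketch (not part of the series): after PATCH 3/4, embeddings go through the shared BaseProviderClient::embeddings() flow and are exposed on ai::Client as Client::embeddings(). The snippet below only illustrates the surface added by these patches; the model name "text-embedding-3-small" is a placeholder, and the exact behavior of try_create_client() (assumed here to build a client from default configuration, e.g. an API key in the environment) is not pinned down by this series.

#include <ai/openai.h>

#include <iostream>
#include <optional>

int main() {
  // Declared in openai.h as std::optional<Client> try_create_client();
  std::optional<ai::Client> client = ai::openai::try_create_client();
  if (!client)
    return 1;

  // EmbeddingOptions(model, input): input is a nlohmann::json value, so a
  // single string or an array of strings both work; the extra constructors
  // also take dimensions and encoding_format.
  ai::EmbeddingOptions options("text-embedding-3-small",  // placeholder model
                               nlohmann::json::array({"hello", "world"}));

  ai::EmbeddingResult result = client->embeddings(options);
  if (!result) {  // operator bool() forwards to is_success()
    std::cerr << "embeddings failed: " << result.error_message() << "\n";
    return 1;
  }

  // result.data mirrors the provider's "data" array of embedding objects;
  // result.usage carries the token counts parsed from the "usage" field.
  std::cout << result.data.dump(2) << std::endl;
  return 0;
}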