diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3bba290..7839dcd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,17 @@ jobs: matrix: build_type: [debug, release] + services: + clickhouse: + image: clickhouse/clickhouse-server + ports: + - 18123:8123 + - 19000:9000 + env: + CLICKHOUSE_PASSWORD: changeme + options: >- + --ulimit nofile=262144:262144 + steps: - name: Checkout code uses: actions/checkout@v4 diff --git a/.gitmodules b/.gitmodules index c9c2c42..b1ac7d7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "third_party/googletest"] path = third_party/googletest url = https://github.com/google/googletest.git +[submodule "third_party/clickhouse-cpp"] + path = third_party/clickhouse-cpp + url = https://github.com/ClickHouse/clickhouse-cpp.git diff --git a/include/ai/openai.h b/include/ai/openai.h index 0f13d91..6a9bd92 100644 --- a/include/ai/openai.h +++ b/include/ai/openai.h @@ -16,12 +16,35 @@ namespace openai { namespace models { /// Common OpenAI model identifiers + +// O-series reasoning models +constexpr const char* kO1 = "o1"; +constexpr const char* kO1Mini = "o1-mini"; +constexpr const char* kO1Preview = "o1-preview"; +constexpr const char* kO3 = "o3"; +constexpr const char* kO3Mini = "o3-mini"; +constexpr const char* kO4Mini = "o4-mini"; + +// GPT-4.1 series +constexpr const char* kGpt41 = "gpt-4.1"; +constexpr const char* kGpt41Mini = "gpt-4.1-mini"; +constexpr const char* kGpt41Nano = "gpt-4.1-nano"; + +// GPT-4o series constexpr const char* kGpt4o = "gpt-4o"; constexpr const char* kGpt4oMini = "gpt-4o-mini"; +constexpr const char* kGpt4oAudioPreview = "gpt-4o-audio-preview"; + +// GPT-4 series constexpr const char* kGpt4Turbo = "gpt-4-turbo"; -constexpr const char* kGpt35Turbo = "gpt-3.5-turbo"; constexpr const char* kGpt4 = "gpt-4"; +// GPT-3.5 series +constexpr const char* kGpt35Turbo = "gpt-3.5-turbo"; + +// Special models +constexpr const char* kChatGpt4oLatest = 
"chatgpt-4o-latest"; + /// Default model used when none is specified constexpr const char* kDefaultModel = kGpt4o; } // namespace models diff --git a/include/ai/tools.h b/include/ai/tools.h index 8c7f51d..1e2f0fe 100644 --- a/include/ai/tools.h +++ b/include/ai/tools.h @@ -80,7 +80,8 @@ class MultiStepCoordinator { /// @return Final generation result with all steps static GenerateResult execute_multi_step( const GenerateOptions& initial_options, - std::function generate_func); + const std::function& + generate_func); private: /// Create the next generation options based on previous step diff --git a/include/ai/types/message.h b/include/ai/types/message.h index fac77aa..f42ba4b 100644 --- a/include/ai/types/message.h +++ b/include/ai/types/message.h @@ -3,29 +3,148 @@ #include "enums.h" #include +#include #include +#include + namespace ai { +using JsonValue = nlohmann::json; + +// Base content part types +struct TextContentPart { + std::string text; + + explicit TextContentPart(std::string t) : text(std::move(t)) {} +}; + +struct ToolCallContentPart { + std::string id; + std::string tool_name; + JsonValue arguments; + + ToolCallContentPart(std::string i, std::string n, JsonValue a) + : id(std::move(i)), tool_name(std::move(n)), arguments(std::move(a)) {} +}; + +struct ToolResultContentPart { + std::string tool_call_id; + JsonValue result; + bool is_error = false; + + ToolResultContentPart(std::string id, JsonValue r, bool err = false) + : tool_call_id(std::move(id)), result(std::move(r)), is_error(err) {} +}; + +// Content part variant +using ContentPart = + std::variant; + +// Message content is now a vector of content parts +using MessageContent = std::vector; + struct Message { MessageRole role; - std::string content; + MessageContent content; - Message(MessageRole r, std::string c) : role(r), content(std::move(c)) {} + Message(MessageRole r, MessageContent c) : role(r), content(std::move(c)) {} - static Message system(const std::string& content) { - return 
Message(kMessageRoleSystem, content); + // Factory methods for convenience + static Message system(const std::string& text) { + return Message(kMessageRoleSystem, {TextContentPart{text}}); } - static Message user(const std::string& content) { - return Message(kMessageRoleUser, content); + static Message user(const std::string& text) { + return Message(kMessageRoleUser, {TextContentPart{text}}); } - static Message assistant(const std::string& content) { - return Message(kMessageRoleAssistant, content); + static Message assistant(const std::string& text) { + return Message(kMessageRoleAssistant, {TextContentPart{text}}); } - bool empty() const { return content.empty(); } + static Message assistant_with_tools( + const std::string& text, + const std::vector& tools) { + MessageContent content_parts; + + // Add text content if not empty + if (!text.empty()) { + content_parts.emplace_back(TextContentPart{text}); + } + + // Add tool calls + for (const auto& tool : tools) { + content_parts.emplace_back( + ToolCallContentPart{tool.id, tool.tool_name, tool.arguments}); + } + + return Message(kMessageRoleAssistant, std::move(content_parts)); + } + + static Message tool_results( + const std::vector& results) { + MessageContent content_parts; + for (const auto& result : results) { + content_parts.emplace_back(ToolResultContentPart{ + result.tool_call_id, result.result, result.is_error}); + } + return Message(kMessageRoleUser, std::move(content_parts)); + } + + // Helper methods + bool has_text() const { + return std::any_of(content.begin(), content.end(), + [](const ContentPart& part) { + return std::holds_alternative(part); + }); + } + + bool has_tool_calls() const { + return std::any_of( + content.begin(), content.end(), [](const ContentPart& part) { + return std::holds_alternative(part); + }); + } + + bool has_tool_results() const { + return std::any_of( + content.begin(), content.end(), [](const ContentPart& part) { + return std::holds_alternative(part); + }); + } + + 
std::string get_text() const { + std::string result; + for (const auto& part : content) { + if (const auto* text_part = std::get_if(&part)) { + result += text_part->text; + } + } + return result; + } + + std::vector get_tool_calls() const { + std::vector result; + for (const auto& part : content) { + if (const auto* tool_part = std::get_if(&part)) { + result.emplace_back(tool_part->id, tool_part->tool_name, + tool_part->arguments); + } + } + return result; + } + + std::vector get_tool_results() const { + std::vector result; + for (const auto& part : content) { + if (const auto* result_part = std::get_if(&part)) { + result.emplace_back(result_part->tool_call_id, result_part->result, + result_part->is_error); + } + } + return result; + } std::string roleToString() const { switch (role) { diff --git a/include/ai/types/tool.h b/include/ai/types/tool.h index 418e38e..4e0b1e4 100644 --- a/include/ai/types/tool.h +++ b/include/ai/types/tool.h @@ -15,13 +15,15 @@ namespace ai { -// Forward declarations -struct ToolCall; -struct ToolExecutionContext; - -/// JSON value type for tool parameters and results using JsonValue = nlohmann::json; +/// Context provided to tool execution functions +struct ToolExecutionContext { + std::string tool_call_id; + Messages messages; + std::optional> abort_signal; +}; + /// Tool execution function signature /// Parameters: (args, context) -> result using ToolExecuteFunction = @@ -33,13 +35,6 @@ using AsyncToolExecuteFunction = std::function(const JsonValue&, const ToolExecutionContext&)>; -/// Context provided to tool execution functions -struct ToolExecutionContext { - std::string tool_call_id; - Messages messages; - std::optional> abort_signal; -}; - struct Tool { std::string description; JsonValue parameters_schema; diff --git a/src/providers/anthropic/anthropic_request_builder.cpp b/src/providers/anthropic/anthropic_request_builder.cpp index d7c401f..8a29c45 100644 --- a/src/providers/anthropic/anthropic_request_builder.cpp +++ 
b/src/providers/anthropic/anthropic_request_builder.cpp @@ -23,8 +23,62 @@ nlohmann::json AnthropicRequestBuilder::build_request_json( // Use provided messages for (const auto& msg : options.messages) { nlohmann::json message; - message["role"] = utils::message_role_to_string(msg.role); - message["content"] = msg.content; + + // Handle different content types + if (msg.has_tool_results()) { + // Anthropic expects tool results as content arrays in user messages + message["role"] = "user"; + message["content"] = nlohmann::json::array(); + + for (const auto& result : msg.get_tool_results()) { + nlohmann::json tool_result_content; + tool_result_content["type"] = "tool_result"; + tool_result_content["tool_use_id"] = result.tool_call_id; + + if (!result.is_error) { + tool_result_content["content"] = result.result.dump(); + } else { + tool_result_content["content"] = result.result.dump(); + tool_result_content["is_error"] = true; + } + + message["content"].push_back(tool_result_content); + } + } else { + // Handle messages with text and/or tool calls + message["role"] = utils::message_role_to_string(msg.role); + + // Get text content and tool calls + std::string text_content = msg.get_text(); + auto tool_calls = msg.get_tool_calls(); + + // Anthropic expects content as array for mixed content or tool calls + if (!tool_calls.empty() || + (msg.role == kMessageRoleAssistant && !text_content.empty())) { + message["content"] = nlohmann::json::array(); + + // Add text content if present + if (!text_content.empty()) { + message["content"].push_back( + {{"type", "text"}, {"text", text_content}}); + } + + // Add tool use content + for (const auto& tool_call : tool_calls) { + message["content"].push_back({{"type", "tool_use"}, + {"id", tool_call.id}, + {"name", tool_call.tool_name}, + {"input", tool_call.arguments}}); + } + } else if (!text_content.empty()) { + // Simple text message (non-assistant or assistant with text only) + message["content"] = text_content; + } else { + // 
Empty message, skip + continue; + } + } + request["messages"].push_back(message); } } else { diff --git a/src/providers/anthropic/anthropic_response_parser.cpp b/src/providers/anthropic/anthropic_response_parser.cpp index acfd399..1235713 100644 --- a/src/providers/anthropic/anthropic_response_parser.cpp +++ b/src/providers/anthropic/anthropic_response_parser.cpp @@ -56,7 +56,7 @@ GenerateResult AnthropicResponseParser::parse_success_response( // Add assistant response to messages if (!result.text.empty()) { - result.response_messages.push_back({kMessageRoleAssistant, result.text}); + result.response_messages.push_back(Message::assistant(result.text)); } } diff --git a/src/providers/openai/openai_request_builder.cpp b/src/providers/openai/openai_request_builder.cpp index e270523..f2ed8be 100644 --- a/src/providers/openai/openai_request_builder.cpp +++ b/src/providers/openai/openai_request_builder.cpp @@ -15,9 +15,62 @@ nlohmann::json OpenAIRequestBuilder::build_request_json( if (!options.messages.empty()) { // Use provided messages for (const auto& msg : options.messages) { - request["messages"].push_back( - {{"role", utils::message_role_to_string(msg.role)}, - {"content", msg.content}}); + nlohmann::json message; + + // Handle different content types + if (msg.has_tool_results()) { + // OpenAI expects each tool result as a separate message with role + // "tool" + for (const auto& result : msg.get_tool_results()) { + nlohmann::json tool_message; + tool_message["role"] = "tool"; + tool_message["tool_call_id"] = result.tool_call_id; + + if (!result.is_error) { + tool_message["content"] = result.result.dump(); + } else { + tool_message["content"] = "Error: " + result.result.dump(); + } + + request["messages"].push_back(tool_message); + } + continue; // Skip adding the main message + } + + // Handle messages with text and/or tool calls + message["role"] = utils::message_role_to_string(msg.role); + + // Get text content (accumulate all text parts) + std::string 
text_content = msg.get_text(); + + // Get tool calls + auto tool_calls = msg.get_tool_calls(); + + // Set content - OpenAI expects both text and tool calls in the same + // message + if (!text_content.empty()) { + message["content"] = text_content; + } + + if (!tool_calls.empty()) { + nlohmann::json tool_calls_array = nlohmann::json::array(); + for (const auto& tool_call : tool_calls) { + tool_calls_array.push_back( + {{"id", tool_call.id}, + {"type", "function"}, + {"function", + {{"name", tool_call.tool_name}, + {"arguments", tool_call.arguments.dump()}}}}); + } + message["tool_calls"] = tool_calls_array; + } + + // Skip empty messages + if (text_content.empty() && tool_calls.empty()) { + continue; + } + + request["messages"].push_back(message); } } else { // Build from system + prompt diff --git a/src/providers/openai/openai_response_parser.cpp b/src/providers/openai/openai_response_parser.cpp index 9800c04..90c520a 100644 --- a/src/providers/openai/openai_response_parser.cpp +++ b/src/providers/openai/openai_response_parser.cpp @@ -92,8 +92,7 @@ GenerateResult OpenAIResponseParser::parse_success_response( // Add assistant response to messages if (!result.text.empty()) { - result.response_messages.push_back( - {kMessageRoleAssistant, result.text}); + result.response_messages.push_back(Message::assistant(result.text)); } } diff --git a/src/tools/multi_step_coordinator.cpp b/src/tools/multi_step_coordinator.cpp index c22fc0c..ed7a4b8 100644 --- a/src/tools/multi_step_coordinator.cpp +++ b/src/tools/multi_step_coordinator.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -7,7 +9,8 @@ namespace ai { GenerateResult MultiStepCoordinator::execute_multi_step( const GenerateOptions& initial_options, - std::function generate_func) { + const std::function& + generate_func) { if (initial_options.max_steps <= 1) { // Single step - just execute normally return generate_func(initial_options); @@ -19,10 +22,21 @@ GenerateResult 
MultiStepCoordinator::execute_multi_step( for (int step = 0; step < initial_options.max_steps; ++step) { ai::logger::log_debug("Executing step {} of {}", step + 1, initial_options.max_steps); + ai::logger::log_debug("Current messages count: {}", + current_options.messages.size()); + ai::logger::log_debug("System prompt: {}", + current_options.system.empty() + ? "empty" + : current_options.system.substr(0, 100) + "..."); // Execute the current step GenerateResult step_result = generate_func(current_options); + ai::logger::log_debug( + "Step {} result - text: '{}', tool_calls: {}, finish_reason: {}", + step + 1, step_result.text, step_result.tool_calls.size(), + static_cast(step_result.finish_reason)); + // Check for errors if (!step_result.is_success()) { // If this is the first step, return the error @@ -113,6 +127,8 @@ GenerateResult MultiStepCoordinator::execute_multi_step( } // Create next step options with tool results (including errors) + ai::logger::log_debug("Creating next step options with {} tool results", + tool_results.size()); current_options = create_next_step_options(initial_options, step_result, tool_results); } else { @@ -136,6 +152,10 @@ GenerateOptions MultiStepCoordinator::create_next_step_options( const GenerateOptions& base_options, const GenerateResult& previous_result, const std::vector& tool_results) { + ai::logger::log_debug( + "create_next_step_options: base messages count={}, tool_results count={}", + base_options.messages.size(), tool_results.size()); + GenerateOptions next_options = base_options; // Build the messages for the next step @@ -144,22 +164,30 @@ GenerateOptions MultiStepCoordinator::create_next_step_options( // If we started with a simple prompt, convert to messages if (!base_options.prompt.empty() && next_messages.empty()) { if (!base_options.system.empty()) { - next_messages.push_back(Message::system(base_options.system)); + next_messages.push_back(Message::user(base_options.system)); } + 
next_messages.push_back(Message::user(base_options.prompt)); } // Add assistant's response with tool calls if (previous_result.has_tool_calls()) { - // Create assistant message with tool calls (this would need proper - // formatting) For now, we'll add the text response - if (!previous_result.text.empty()) { - next_messages.push_back(Message::assistant(previous_result.text)); + // Convert ToolCall to ToolCallContent + std::vector tool_call_contents; + tool_call_contents.reserve(previous_result.tool_calls.size()); + for (const auto& tc : previous_result.tool_calls) { + tool_call_contents.emplace_back(tc.id, tc.tool_name, tc.arguments); } + // Create assistant message with tool calls + next_messages.push_back(Message::assistant_with_tools(previous_result.text, + tool_call_contents)); + // Add tool results as messages Messages tool_messages = tool_results_to_messages(previous_result.tool_calls, tool_results); + ai::logger::log_debug("Adding {} tool result messages", + tool_messages.size()); next_messages.insert(next_messages.end(), tool_messages.begin(), tool_messages.end()); } @@ -167,6 +195,10 @@ GenerateOptions MultiStepCoordinator::create_next_step_options( next_options.messages = next_messages; next_options.prompt = ""; // Clear prompt since we're using messages + ai::logger::log_debug( + "Final next_options: messages count={}, system prompt length={}", + next_options.messages.size(), next_options.system.length()); + return next_options; } @@ -181,30 +213,28 @@ Messages MultiStepCoordinator::tool_results_to_messages( results_by_id[result.tool_call_id] = result; } - // Add messages for each tool call and result + // Convert tool results to ToolResultContent + std::vector tool_result_contents; for (const auto& tool_call : tool_calls) { auto result_it = results_by_id.find(tool_call.id); if (result_it != results_by_id.end()) { const ToolResult& result = result_it->second; - // Create a message with the tool result - // In a real implementation, this would use proper 
tool message formatting - std::string content; if (result.is_success()) { - content = "Tool '" + tool_call.tool_name + - "' returned: " + result.result.dump(); + tool_result_contents.emplace_back(tool_call.id, result.result, false); } else { - content = "Tool '" + tool_call.tool_name + - "' failed: " + result.error_message(); + // For errors, create a simple error object + JsonValue error_json = {{"error", result.error_message()}}; + tool_result_contents.emplace_back(tool_call.id, error_json, true); } - - // For now, we'll use user messages to represent tool results - // In a complete implementation, this would be a proper "tool" role - // message - messages.push_back(Message::user(content)); } } + // Create a single user message with all tool results + if (!tool_result_contents.empty()) { + messages.push_back(Message::tool_results(tool_result_contents)); + } + return messages; } diff --git a/test-services/clickhouse/docker-compose.yaml b/test-services/clickhouse/docker-compose.yaml new file mode 100644 index 0000000..682283c --- /dev/null +++ b/test-services/clickhouse/docker-compose.yaml @@ -0,0 +1,13 @@ +services: + ai-sdk-cpp-clickhouse-server: + image: clickhouse/clickhouse-server + container_name: ai-sdk-cpp-clickhouse-server + ports: + - "18123:8123" + - "19000:9000" + environment: + CLICKHOUSE_PASSWORD: changeme + ulimits: + nofile: + soft: 262144 + hard: 262144 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7240ef1..b51ed53 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(ai_tests integration/openai_integration_test.cpp integration/anthropic_integration_test.cpp integration/tool_calling_integration_test.cpp + integration/clickhouse_integration_test.cpp # Utility classes utils/mock_openai_client.cpp @@ -27,6 +28,7 @@ target_link_libraries(ai_tests GTest::gtest GTest::gtest_main GTest::gmock + ClickHouse::Client # ClickHouse C++ client for integration tests ) # Include directories for tests diff 
--git a/tests/integration/anthropic_integration_test.cpp b/tests/integration/anthropic_integration_test.cpp index bb5ffb8..d3a12fc 100644 --- a/tests/integration/anthropic_integration_test.cpp +++ b/tests/integration/anthropic_integration_test.cpp @@ -1,11 +1,9 @@ #include "../utils/test_fixtures.h" #include "ai/anthropic.h" -#include "ai/logger.h" #include "ai/types/generate_options.h" #include "ai/types/stream_options.h" #include -#include #include #include @@ -109,10 +107,9 @@ TEST_F(AnthropicIntegrationTest, ConversationWithMessages) { } Messages conversation = { - Message(kMessageRoleUser, "Hello!"), - Message(kMessageRoleAssistant, - "Hello! I can help you with weather information."), - Message(kMessageRoleUser, "What's the weather like today?")}; + Message::user("Hello!"), + Message::assistant("Hello! I can help you with weather information."), + Message::user("What's the weather like today?")}; GenerateOptions options(ai::anthropic::models::kDefaultModel, std::move(conversation)); diff --git a/tests/integration/clickhouse_integration_test.cpp b/tests/integration/clickhouse_integration_test.cpp new file mode 100644 index 0000000..a9dbfde --- /dev/null +++ b/tests/integration/clickhouse_integration_test.cpp @@ -0,0 +1,424 @@ +#include "ai/anthropic.h" +#include "ai/openai.h" +#include "ai/tools.h" +#include "ai/types/generate_options.h" +#include "ai/types/tool.h" + +#include +#include +#include + +#include +#include + +namespace ai { +namespace test { + +// Utility function to generate random suffix for table names +std::string generateRandomSuffix(size_t length = 8) { + static const char alphabet[] = "abcdefghijklmnopqrstuvwxyz"; + static std::random_device rd; + static std::mt19937 gen(rd()); + static std::uniform_int_distribution<> dis(0, sizeof(alphabet) - 2); + + std::string suffix; + suffix.reserve(length); + for (size_t i = 0; i < length; ++i) { + suffix += alphabet[dis(gen)]; + } + return suffix; +} + +// ClickHouse connection parameters +const 
std::string kClickhouseHost = "localhost"; +const int kClickhousePort = 19000; // Native protocol port from docker-compose +const std::string kClickhouseUser = "default"; +const std::string kClickhousePassword = "changeme"; + +// System prompt for SQL generation +const std::string kSystemPrompt = + R"(You are a ClickHouse SQL code generator. Your ONLY job is to output SQL statements wrapped in tags. + +TOOLS AVAILABLE: +- list_databases(): Lists all databases +- list_tables_in_database(database): Lists tables in a specific database +- get_schema_for_table(database, table): Gets schema for existing tables only + +CRITICAL RULES: +1. ALWAYS output SQL wrapped in tags, no matter what +2. NEVER ask questions or request clarification +3. NEVER explain your SQL or add any other text +4. NEVER use markdown code blocks (```sql) +5. For CREATE TABLE requests, ALWAYS generate a reasonable schema based on the table name + +RESPONSE FORMAT - Must be EXACTLY: + +[SQL STATEMENT] + + +TASK-SPECIFIC INSTRUCTIONS: + +For "create a table named X for github events": +- Generate CREATE TABLE with columns: id String, type String, actor_id UInt64, actor_login String, repo_id UInt64, repo_name String, created_at DateTime, payload String +- Use MergeTree() engine with ORDER BY (created_at, repo_id) + +For "insert 3 rows into users table": +- If table exists, check schema with tools +- Generate INSERT with columns: id, name, age +- Use sample data like (1, 'Alice', 28), (2, 'Bob', 35), (3, 'Charlie', 42) + +For "show all users from X" or "find all Y from X": +- Generate appropriate SELECT statement + +IMPORTANT: Even if you use tools and find the database/table doesn't exist, still generate the SQL as requested. 
The test will handle any errors.)"; + +// Tool implementations +class ClickHouseTools { + public: + explicit ClickHouseTools(std::shared_ptr client) + : client_(std::move(client)) {} + + Tool createListDatabasesTool() { + return create_simple_tool( + "list_databases", + "List all available databases in the ClickHouse instance", {}, + [this](const JsonValue& args, const ToolExecutionContext& context) { + try { + std::vector databases; + client_->Select( + "SELECT name FROM system.databases ORDER BY name", + [&databases](const clickhouse::Block& block) { + for (size_t i = 0; i < block.GetRowCount(); ++i) { + databases.push_back(std::string( + block[0]->As()->At(i))); + } + }); + + std::stringstream result; + result << "Found " << databases.size() << " databases:\n"; + for (const auto& db : databases) { + result << "- " << db << "\n"; + } + + return JsonValue{{"result", result.str()}, + {"databases", databases}}; + } catch (const std::exception& e) { + throw std::runtime_error(std::string("Error: ") + e.what()); + } + }); + } + + Tool createListTablesInDatabaseTool() { + return create_simple_tool( + "list_tables_in_database", "List all tables in a specific database", + {{"database", "string"}}, + [this](const JsonValue& args, const ToolExecutionContext& context) { + try { + std::string database = args["database"].get(); + std::vector tables; + + std::string query = + "SELECT name FROM system.tables WHERE database = '" + database + + "' ORDER BY name"; + client_->Select(query, [&tables](const clickhouse::Block& block) { + for (size_t i = 0; i < block.GetRowCount(); ++i) { + tables.push_back(std::string( + block[0]->As()->At(i))); + } + }); + + std::stringstream result; + result << "Found " << tables.size() << " tables in database '" + << database << "':\n"; + for (const auto& table : tables) { + result << "- " << table << "\n"; + } + if (tables.empty()) { + result << "(No tables found in this database)\n"; + } + + return JsonValue{{"result", result.str()}, + 
{"database", database}, + {"tables", tables}}; + } catch (const std::exception& e) { + throw std::runtime_error(std::string("Error: ") + e.what()); + } + }); + } + + Tool createGetSchemaForTableTool() { + return create_simple_tool( + "get_schema_for_table", + "Get the CREATE TABLE statement (schema) for a specific table", + {{"database", "string"}, {"table", "string"}}, + [this](const JsonValue& args, const ToolExecutionContext& context) { + try { + std::string database = args["database"].get(); + std::string table = args["table"].get(); + + std::string query = + "SHOW CREATE TABLE `" + database + "`.`" + table + "`"; + std::string schema; + + client_->Select(query, [&schema](const clickhouse::Block& block) { + if (block.GetRowCount() > 0) { + schema = block[0]->As()->At(0); + } + }); + + if (schema.empty()) { + throw std::runtime_error("Could not retrieve schema for " + + database + "." + table); + } + + return JsonValue{{"result", "Schema for " + database + "." + table + + ":\n" + schema}, + {"database", database}, + {"table", table}, + {"schema", schema}}; + } catch (const std::exception& e) { + throw std::runtime_error(std::string("Error: ") + e.what()); + } + }); + } + + private: + std::shared_ptr client_; +}; + +// Base test fixture +class ClickHouseIntegrationTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + // Generate random suffix for table names to allow parallel test execution + table_suffix_ = generateRandomSuffix(); + // Use unique database name for each test to allow parallel execution + test_db_name_ = "test_db_" + generateRandomSuffix(); + + // Initialize ClickHouse client + clickhouse::ClientOptions options; + options.SetHost(kClickhouseHost); + options.SetPort(kClickhousePort); + options.SetUser(kClickhouseUser); + options.SetPassword(kClickhousePassword); + + clickhouse_client_ = std::make_shared(options); + + // Create test database with unique name + clickhouse_client_->Execute("CREATE DATABASE IF NOT EXISTS " + 
+ test_db_name_); + + // Initialize tools + tools_helper_ = std::make_unique(clickhouse_client_); + tools_ = { + {"list_databases", tools_helper_->createListDatabasesTool()}, + {"list_tables_in_database", + tools_helper_->createListTablesInDatabaseTool()}, + {"get_schema_for_table", tools_helper_->createGetSchemaForTableTool()}}; + + // Initialize AI client based on provider + provider_type_ = GetParam(); + if (provider_type_ == "openai") { + const char* api_key = std::getenv("OPENAI_API_KEY"); + if (!api_key) { + use_real_api_ = false; + return; + } + client_ = std::make_shared(openai::create_client()); + model_ = openai::models::kGpt4o; + } else if (provider_type_ == "anthropic") { + const char* api_key = std::getenv("ANTHROPIC_API_KEY"); + if (!api_key) { + use_real_api_ = false; + return; + } + client_ = std::make_shared(anthropic::create_client()); + model_ = anthropic::models::kClaudeSonnet4; + } + use_real_api_ = true; + } + + void TearDown() override { + // Clean up test database + if (clickhouse_client_) { + clickhouse_client_->Execute("DROP DATABASE IF EXISTS " + test_db_name_); + } + } + + std::string executeSQLGeneration(const std::string& prompt) { + GenerateOptions options(model_, prompt); + options.system = kSystemPrompt; + options.tools = tools_; + options.max_tokens = 1000; + options.max_steps = 5; + options.temperature = 1; + + auto result = client_->generate_text(options); + + if (!result.is_success()) { + throw std::runtime_error("Failed to generate SQL: " + + result.error_message()); + } + + // Extract SQL from the final response text + std::string response_text = result.text; + if (!result.steps.empty()) { + // Get the last non-empty step's text + for (auto it = result.steps.rbegin(); it != result.steps.rend(); ++it) { + if (!it->text.empty()) { + response_text = it->text; + break; + } + } + } + + return extractSQL(response_text); + } + + // Extract SQL from tags + std::string extractSQL(const std::string& input) { + size_t start = 
input.find("<sql>"); + if (start == std::string::npos) { + throw std::runtime_error("No <sql> tag found in response: " + input); + } + start += 5; // Length of "<sql>" + + size_t end = input.find("</sql>", start); + if (end == std::string::npos) { + throw std::runtime_error("No closing </sql> tag found in response: " + + input); + } + + std::string sql = input.substr(start, end - start); + + // Trim whitespace + size_t first = sql.find_first_not_of(" \t\n\r"); + size_t last = sql.find_last_not_of(" \t\n\r"); + if (first != std::string::npos && last != std::string::npos) { + sql = sql.substr(first, last - first + 1); + } + + return sql; + } + + std::shared_ptr<clickhouse::Client> clickhouse_client_; + std::unique_ptr<ClickHouseTools> tools_helper_; + ToolSet tools_; + std::shared_ptr<Client> client_; + std::string model_; + std::string provider_type_; + std::string table_suffix_; + std::string test_db_name_; + bool use_real_api_ = false; +}; + +// Tests +TEST_P(ClickHouseIntegrationTest, CreateTableForGithubEvents) { + if (!use_real_api_) { + GTEST_SKIP() << "No API key set for " << GetParam(); + } + + std::string table_name = "github_events_" + table_suffix_; + + // Clean up any existing table + clickhouse_client_->Execute("DROP TABLE IF EXISTS " + test_db_name_ + "." + + table_name); + clickhouse_client_->Execute("DROP TABLE IF EXISTS default." + table_name); + + std::string sql = executeSQLGeneration("create a table named " + table_name + + " for github events in " + + test_db_name_ + " database"); + + // Execute the generated SQL + ASSERT_NO_THROW(clickhouse_client_->Execute("USE " + test_db_name_)); + ASSERT_NO_THROW(clickhouse_client_->Execute(sql)); + + // Verify table was created + bool table_exists = false; + clickhouse_client_->Select( + "EXISTS TABLE " + test_db_name_ + "." 
+ table_name, + [&table_exists](const clickhouse::Block& block) { + if (block.GetRowCount() > 0) { + table_exists = block[0]->As()->At(0); + } + }); + EXPECT_TRUE(table_exists); +} + +TEST_P(ClickHouseIntegrationTest, InsertAndQueryData) { + if (!use_real_api_) { + GTEST_SKIP() << "No API key set for " << GetParam(); + } + + std::string table_name = "users_" + table_suffix_; + + // First create a users table + clickhouse_client_->Execute("USE " + test_db_name_); + clickhouse_client_->Execute("DROP TABLE IF EXISTS " + table_name); + clickhouse_client_->Execute( + "CREATE TABLE " + table_name + + " (id UInt64, name String, age UInt8) ENGINE = MergeTree() ORDER BY id"); + + // Generate INSERT SQL + std::string insert_sql = + executeSQLGeneration("insert 3 rows with random values into " + + table_name + " table in " + test_db_name_); + ASSERT_NO_THROW(clickhouse_client_->Execute(insert_sql)); + + // Generate SELECT SQL + std::string select_sql = executeSQLGeneration( + "show all users from " + table_name + " in " + test_db_name_); + + // Verify data was inserted + size_t row_count = 0; + clickhouse_client_->Select(select_sql, + [&row_count](const clickhouse::Block& block) { + row_count += block.GetRowCount(); + }); + EXPECT_EQ(row_count, 3); +} + +TEST_P(ClickHouseIntegrationTest, ExploreExistingSchema) { + if (!use_real_api_) { + GTEST_SKIP() << "No API key set for " << GetParam(); + } + + std::string orders_table = "orders_" + table_suffix_; + std::string products_table = "products_" + table_suffix_; + + // Create some test tables + clickhouse_client_->Execute("USE " + test_db_name_); + clickhouse_client_->Execute("DROP TABLE IF EXISTS " + orders_table); + clickhouse_client_->Execute("DROP TABLE IF EXISTS " + products_table); + clickhouse_client_->Execute( + "CREATE TABLE " + orders_table + + " (id UInt64, user_id UInt64, amount Decimal(10,2), status String) " + "ENGINE = MergeTree() ORDER BY id"); + clickhouse_client_->Execute("CREATE TABLE " + products_table + + " 
(id UInt64, name String, price Decimal(10,2), " + "stock UInt32) ENGINE = MergeTree() ORDER BY id"); + + // Add some data to orders table + clickhouse_client_->Execute( + "INSERT INTO " + orders_table + + " VALUES (1, 100, 50.25, 'completed'), (2, 101, 150.00, 'pending'), (3, " + "100, 200.50, 'completed')"); + + // Test schema exploration + std::string sql = executeSQLGeneration( + "find all orders with amount greater than 100 from " + orders_table + + " in " + test_db_name_); + + // The generated SQL should reference the correct table and columns + EXPECT_TRUE(sql.find(orders_table) != std::string::npos); + EXPECT_TRUE(sql.find("amount") != std::string::npos); + EXPECT_TRUE(sql.find("100") != std::string::npos); +} + +// Instantiate tests for both providers +INSTANTIATE_TEST_SUITE_P(Providers, + ClickHouseIntegrationTest, + ::testing::Values("openai", "anthropic")); + +} // namespace test +} // namespace ai diff --git a/tests/integration/openai_integration_test.cpp b/tests/integration/openai_integration_test.cpp index 89896a4..3f7fae7 100644 --- a/tests/integration/openai_integration_test.cpp +++ b/tests/integration/openai_integration_test.cpp @@ -1,11 +1,9 @@ #include "../utils/test_fixtures.h" -#include "ai/logger.h" #include "ai/openai.h" #include "ai/types/generate_options.h" #include "ai/types/stream_options.h" #include -#include #include #include @@ -109,11 +107,10 @@ TEST_F(OpenAIIntegrationTest, ConversationWithMessages) { } Messages conversation = { - Message(kMessageRoleSystem, "You are a helpful weather assistant."), - Message(kMessageRoleUser, "Hello!"), - Message(kMessageRoleAssistant, - "Hello! I can help you with weather information."), - Message(kMessageRoleUser, "What's the weather like today?")}; + Message::system("You are a helpful weather assistant."), + Message::user("Hello!"), + Message::assistant("Hello! 
I can help you with weather information."), + Message::user("What's the weather like today?")}; GenerateOptions options(ai::openai::models::kGpt4oMini, std::move(conversation)); diff --git a/tests/unit/types_test.cpp b/tests/unit/types_test.cpp index 4ccca83..455ec19 100644 --- a/tests/unit/types_test.cpp +++ b/tests/unit/types_test.cpp @@ -46,8 +46,7 @@ TEST_F(GenerateOptionsTest, ConstructorWithSystemPrompt) { } TEST_F(GenerateOptionsTest, ConstructorWithMessages) { - Messages messages = {Message(kMessageRoleUser, "Hello"), - Message(kMessageRoleAssistant, "Hi there!")}; + Messages messages = {Message::user("Hello"), Message::assistant("Hi there!")}; GenerateOptions options("gpt-4o", std::move(messages)); EXPECT_EQ(options.model, "gpt-4o"); @@ -70,7 +69,7 @@ TEST_F(GenerateOptionsTest, ValidationEmptyPromptAndMessages) { } TEST_F(GenerateOptionsTest, ValidationWithValidMessages) { - Messages messages = {Message(kMessageRoleUser, "Hello")}; + Messages messages = {Message::user("Hello")}; GenerateOptions options("gpt-4o", std::move(messages)); EXPECT_TRUE(options.is_valid()); @@ -168,43 +167,42 @@ TEST_F(GenerateResultTest, MetadataFields) { TEST_F(GenerateResultTest, ResponseMessages) { GenerateResult result("Response", kFinishReasonStop, Usage{}); - result.response_messages.push_back( - Message(kMessageRoleAssistant, "Response")); + result.response_messages.push_back(Message::assistant("Response")); EXPECT_EQ(result.response_messages.size(), 1); EXPECT_EQ(result.response_messages[0].role, kMessageRoleAssistant); - EXPECT_EQ(result.response_messages[0].content, "Response"); + EXPECT_EQ(result.response_messages[0].get_text(), "Response"); } // Message Tests class MessageTest : public AITestFixture {}; TEST_F(MessageTest, Constructor) { - Message msg(kMessageRoleUser, "Hello, world!"); + Message msg = Message::user("Hello, world!"); EXPECT_EQ(msg.role, kMessageRoleUser); - EXPECT_EQ(msg.content, "Hello, world!"); + EXPECT_EQ(msg.get_text(), "Hello, world!"); } 
TEST_F(MessageTest, SystemMessage) { - Message msg(kMessageRoleSystem, "You are a helpful assistant."); + Message msg = Message::system("You are a helpful assistant."); EXPECT_EQ(msg.role, kMessageRoleSystem); - EXPECT_EQ(msg.content, "You are a helpful assistant."); + EXPECT_EQ(msg.get_text(), "You are a helpful assistant."); } TEST_F(MessageTest, AssistantMessage) { - Message msg(kMessageRoleAssistant, "How can I help you?"); + Message msg = Message::assistant("How can I help you?"); EXPECT_EQ(msg.role, kMessageRoleAssistant); - EXPECT_EQ(msg.content, "How can I help you?"); + EXPECT_EQ(msg.get_text(), "How can I help you?"); } TEST_F(MessageTest, EmptyContent) { - Message msg(kMessageRoleUser, ""); + Message msg = Message::user(""); EXPECT_EQ(msg.role, kMessageRoleUser); - EXPECT_TRUE(msg.content.empty()); + EXPECT_TRUE(msg.get_text().empty()); } // Usage Tests @@ -382,20 +380,20 @@ TEST_F(TypeIntegrationTest, MessageConversationFlow) { Messages conversation; // Start conversation - conversation.emplace_back(kMessageRoleSystem, "You are helpful"); - conversation.emplace_back(kMessageRoleUser, "Hello"); + conversation.push_back(Message::system("You are helpful")); + conversation.push_back(Message::user("Hello")); EXPECT_EQ(conversation.size(), 2); // Add assistant response - conversation.emplace_back(kMessageRoleAssistant, "Hi there!"); + conversation.push_back(Message::assistant("Hi there!")); // Continue conversation - conversation.emplace_back(kMessageRoleUser, "How are you?"); + conversation.push_back(Message::user("How are you?")); EXPECT_EQ(conversation.size(), 4); EXPECT_EQ(conversation.back().role, kMessageRoleUser); - EXPECT_EQ(conversation.back().content, "How are you?"); + EXPECT_EQ(conversation.back().get_text(), "How are you?"); } } // namespace test diff --git a/tests/utils/test_fixtures.cpp b/tests/utils/test_fixtures.cpp index 001fe99..d45814e 100644 --- a/tests/utils/test_fixtures.cpp +++ b/tests/utils/test_fixtures.cpp @@ -84,15 +84,15 @@ 
Messages OpenAITestFixture::createSampleConversation() { } Message OpenAITestFixture::createUserMessage(const std::string& content) { - return Message(kMessageRoleUser, content); + return Message::user(content); } Message OpenAITestFixture::createAssistantMessage(const std::string& content) { - return Message(kMessageRoleAssistant, content); + return Message::assistant(content); } Message OpenAITestFixture::createSystemMessage(const std::string& content) { - return Message(kMessageRoleSystem, content); + return Message::system(content); } // AnthropicTestFixture implementation @@ -152,17 +152,17 @@ Messages AnthropicTestFixture::createSampleAnthropicConversation() { Message AnthropicTestFixture::createAnthropicUserMessage( const std::string& content) { - return Message(kMessageRoleUser, content); + return Message::user(content); } Message AnthropicTestFixture::createAnthropicAssistantMessage( const std::string& content) { - return Message(kMessageRoleAssistant, content); + return Message::assistant(content); } Message AnthropicTestFixture::createAnthropicSystemMessage( const std::string& content) { - return Message(kMessageRoleSystem, content); + return Message::system(content); } // TestDataGenerator implementation @@ -176,7 +176,7 @@ std::vector TestDataGenerator::generateOptionsVariations() { variations.emplace_back("gpt-4o", "You are helpful", "User question"); // With messages - Messages msgs = {Message(kMessageRoleUser, "Hello")}; + Messages msgs = {Message::user("Hello")}; variations.emplace_back("gpt-4o", std::move(msgs)); // With all parameters diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 9f405dc..85c4427 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -23,4 +23,5 @@ add_subdirectory(concurrentqueue-cmake) # Add googletest (only if BUILD_TESTS is ON) if(BUILD_TESTS) add_subdirectory(googletest-cmake) + add_subdirectory(clickhouse-cmake) endif() \ No newline at end of file diff --git 
a/third_party/clickhouse-cmake/CMakeLists.txt b/third_party/clickhouse-cmake/CMakeLists.txt new file mode 100644 index 0000000..1b42e06 --- /dev/null +++ b/third_party/clickhouse-cmake/CMakeLists.txt @@ -0,0 +1,35 @@ +# ClickHouse C++ Client CMake wrapper +# This wrapper provides a consistent interface for the ClickHouse C++ client library + +# Only build ClickHouse client for tests +if(NOT BUILD_TESTS) + return() +endif() + +# Add ClickHouse client as subdirectory +add_subdirectory(../clickhouse-cpp clickhouse-cpp EXCLUDE_FROM_ALL) + +# Create an interface target that properly exposes ClickHouse client +add_library(clickhouse-cpp-client INTERFACE) + +# Link to the actual ClickHouse library +target_link_libraries(clickhouse-cpp-client + INTERFACE + clickhouse-cpp-lib +) + +# Set properties for consistency with other dependencies +set_target_properties(clickhouse-cpp-lib PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON +) + +# Disable warnings for third-party code +if(MSVC) + target_compile_options(clickhouse-cpp-lib PRIVATE /W0) +else() + target_compile_options(clickhouse-cpp-lib PRIVATE -w) +endif() + +# Create alias for consistent naming +add_library(ClickHouse::Client ALIAS clickhouse-cpp-client) \ No newline at end of file diff --git a/third_party/clickhouse-cpp b/third_party/clickhouse-cpp new file mode 160000 index 0000000..cae657a --- /dev/null +++ b/third_party/clickhouse-cpp @@ -0,0 +1 @@ +Subproject commit cae657a672ff09b715d7127b13eb25d63bea01d4