diff --git a/.gitignore b/.gitignore index 1bcc7a0..82e4c59 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ cozo.sqlite /phoenix_call_graph.json INPUT_FORMAT.md docs/tickets +/tickets/ .claude scratch /AGENTS.md diff --git a/CLAUDE.md b/CLAUDE.md index 1fdd3a9..30eef68 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -60,7 +60,7 @@ This is a Rust CLI tool for querying call graph data stored in a CozoDB SQLite d - `commands/mod.rs` - `Command` enum, `Execute` trait, `CommonArgs`, dispatch via enum_dispatch - `commands//` - Individual command modules (27 commands, directory structure) - `output.rs` - `OutputFormat` enum, `Outputable` and `TableFormatter` traits -- `dedup.rs` - Deduplication utilities (`sort_and_deduplicate`, `DeduplicationFilter`) +- `dedup.rs` - Deduplication utilities (`sort_and_deduplicate`, `deduplicate_retain`) - `utils.rs` - Presentation helpers (`group_by_module`, `convert_to_module_groups`, `format_type_definition`) - `test_macros.rs` - Declarative test macros for CLI, execute, and output tests diff --git a/cli/src/commands/god_modules/execute.rs b/cli/src/commands/god_modules/execute.rs index f3ca9e7..ae4a5b5 100644 --- a/cli/src/commands/god_modules/execute.rs +++ b/cli/src/commands/god_modules/execute.rs @@ -4,7 +4,7 @@ use serde::Serialize; use super::GodModulesCmd; use crate::commands::Execute; -use db::queries::hotspots::{find_hotspots, get_function_counts, get_module_loc, HotspotKind}; +use db::queries::hotspots::{get_function_counts, get_module_connectivity, get_module_loc}; use db::types::{ModuleCollectionResult, ModuleGroup}; /// A single god module entry @@ -37,30 +37,14 @@ impl Execute for GodModulesCmd { self.common.regex, )?; - // Get hotspot data (incoming/outgoing calls per function) - let hotspots = find_hotspots( + // Get module-level connectivity (aggregated at database level) + let module_connectivity = get_module_connectivity( db, - HotspotKind::Total, - self.module.as_deref(), &self.common.project, + self.module.as_deref(), self.common.regex, - u32::MAX, // Get all hotspots to aggregate connectivity - false, // Don't exclude generated functions - false, // Don't require outgoing calls )?; - // Aggregate connectivity (incoming/outgoing) per module - let mut module_connectivity: std::collections::HashMap = - std::collections::HashMap::new(); - - for hotspot in hotspots { - let entry = module_connectivity - .entry(hotspot.module) - .or_insert((0, 0)); - entry.0 += hotspot.incoming; - entry.1 += hotspot.outgoing; - } - // Build god modules: filter by thresholds and sort by total connectivity // Tuple: (module_name, func_count, loc, incoming, outgoing) let mut god_modules: Vec<(String, i64, i64, i64, i64)> = Vec::new(); diff --git a/cli/src/commands/god_modules/execute_tests.rs b/cli/src/commands/god_modules/execute_tests.rs new file mode 100644 index 0000000..289929e --- /dev/null +++ b/cli/src/commands/god_modules/execute_tests.rs @@ -0,0 +1,334 @@ +//! Execute tests for god_modules command. + +#[cfg(test)] +mod tests { + use super::super::GodModulesCmd; + use crate::commands::CommonArgs; + use crate::commands::Execute; + use rstest::{fixture, rstest}; + + crate::shared_fixture! { + fixture_name: populated_db, + fixture_type: call_graph, + project: "test_project", + } + + // ========================================================================= + // Core functionality tests + // ========================================================================= + + #[rstest] + fn test_god_modules_basic(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + assert_eq!(result.kind_filter, Some("god".to_string())); + // Should have some modules that meet the criteria + assert!(result.total_items > 0); + } + + #[rstest] + fn test_god_modules_respects_function_count_threshold(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 100, // Very high threshold + min_loc: 1, + min_total: 1, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + // With high threshold, might have no results + for item in &result.items { + let entry = &item.entries[0]; + assert!(entry.function_count >= 100, "Module {} has {} functions, expected >= 100", item.name, entry.function_count); + } + } + + #[rstest] + fn test_god_modules_respects_loc_threshold(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1000, // High LoC threshold + min_total: 1, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + for item in &result.items { + let entry = &item.entries[0]; + assert!(entry.loc >= 1000, "Module {} has {} LoC, expected >= 1000", item.name, entry.loc); + } + } + + #[rstest] + fn test_god_modules_respects_total_threshold(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 10, // Require at least 10 total calls + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + for item in &result.items { + let entry = &item.entries[0]; + assert!(entry.total >= 10, "Module {} has {} total calls, expected >= 10", item.name, entry.total); + assert_eq!(entry.total, entry.incoming + entry.outgoing, "Total should equal incoming + outgoing"); + } + } + + #[rstest] + fn test_god_modules_sorted_by_connectivity(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + if result.items.len() > 1 { + // Check that results are sorted by total connectivity (descending) + for i in 0..result.items.len() - 1 { + let current_total = result.items[i].entries[0].total; + let next_total = result.items[i + 1].entries[0].total; + assert!( + current_total >= next_total, + "Results not sorted: {} (total={}) should be >= {} (total={})", + result.items[i].name, current_total, + result.items[i + 1].name, next_total + ); + } + } + } + + #[rstest] + fn test_god_modules_with_module_filter(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: Some("Accounts".to_string()), + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + // All results should contain "Accounts" + for item in &result.items { + assert!(item.name.contains("Accounts"), "Module {} doesn't contain 'Accounts'", item.name); + } + } + + #[rstest] + fn test_god_modules_respects_limit(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 2, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + assert!(result.items.len() <= 2, "Expected at most 2 results, got {}", result.items.len()); + } + + #[rstest] + fn test_god_modules_entry_structure(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + for item in &result.items { + // Each module should have exactly one entry + assert_eq!(item.entries.len(), 1, "Module {} should have exactly one entry", item.name); + + let entry = &item.entries[0]; + // All counts should be non-negative + assert!(entry.function_count >= 0); + assert!(entry.loc >= 0); + assert!(entry.incoming >= 0); + assert!(entry.outgoing >= 0); + assert!(entry.total >= 0); + + // Total should equal incoming + outgoing + assert_eq!(entry.total, entry.incoming + entry.outgoing); + + // function_count should be populated + assert_eq!(item.function_count, Some(entry.function_count)); + } + } + + #[rstest] + fn test_god_modules_all_thresholds_filter_everything(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 999999, // Impossible threshold + min_loc: 999999, + min_total: 999999, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + // Should return empty results, not error + assert_eq!(result.total_items, 0); + assert!(result.items.is_empty()); + } + + #[rstest] + fn test_god_modules_module_pattern_no_match(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: Some("NonExistentModule".to_string()), + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + // Should return empty results + assert_eq!(result.total_items, 0); + assert!(result.items.is_empty()); + assert_eq!(result.module_pattern, "NonExistentModule"); + } + + #[rstest] + fn test_god_modules_wrong_project(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: None, + common: CommonArgs { + project: "wrong_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + // Should return empty results for non-existent project + assert_eq!(result.total_items, 0); + assert!(result.items.is_empty()); + } + + #[rstest] + fn test_god_modules_result_metadata(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: Some("Accounts".to_string()), + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + // Verify result metadata is correct + assert_eq!(result.module_pattern, "Accounts"); + assert_eq!(result.function_pattern, None); + assert_eq!(result.kind_filter, Some("god".to_string())); + assert_eq!(result.name_filter, None); + } + + #[rstest] + fn test_god_modules_combined_thresholds(populated_db: db::DbInstance) { + let cmd = GodModulesCmd { + min_functions: 2, // Multiple filters + min_loc: 10, + min_total: 2, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }; + let result = cmd.execute(&populated_db).expect("Execute should succeed"); + + // All results must satisfy ALL three criteria + for item in &result.items { + let entry = &item.entries[0]; + assert!(entry.function_count >= 2, "Module {} has {} functions, expected >= 2", item.name, entry.function_count); + assert!(entry.loc >= 10, "Module {} has {} LoC, expected >= 10", item.name, entry.loc); + assert!(entry.total >= 2, "Module {} has {} total, expected >= 2", item.name, entry.total); + } + } + + // ========================================================================= + // Error handling tests + // ========================================================================= + + crate::execute_empty_db_test! { + cmd_type: GodModulesCmd, + cmd: GodModulesCmd { + min_functions: 1, + min_loc: 1, + min_total: 1, + module: None, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 20, + }, + }, + } +} diff --git a/cli/src/commands/god_modules/mod.rs b/cli/src/commands/god_modules/mod.rs index b62914e..b081695 100644 --- a/cli/src/commands/god_modules/mod.rs +++ b/cli/src/commands/god_modules/mod.rs @@ -1,4 +1,5 @@ mod execute; +mod execute_tests; mod output; use std::error::Error; diff --git a/cli/src/commands/reverse_trace/execute.rs b/cli/src/commands/reverse_trace/execute.rs index aa96d56..21bec8d 100644 --- a/cli/src/commands/reverse_trace/execute.rs +++ b/cli/src/commands/reverse_trace/execute.rs @@ -28,8 +28,6 @@ fn build_reverse_trace_result( // Process depth 1 (direct callers of target function) if let Some(depth1_steps) = by_depth.get(&1) { - let mut filter = crate::dedup::DeduplicationFilter::new(); - for step in depth1_steps { let caller_key = ( step.caller_module.clone(), @@ -38,13 +36,16 @@ fn build_reverse_trace_result( 1i64, ); - // Add caller as root entry if not already added - if filter.should_process(caller_key.clone()) { + // Add caller as root entry if not already added (use HashMap for dedup check) + if !entry_index_map.contains_key(&caller_key) { let entry_idx = entries.len(); + // Insert into HashMap before pushing (reuse caller_key) + entry_index_map.insert(caller_key.clone(), entry_idx); + entries.push(TraceEntry { - module: step.caller_module.clone(), - function: step.caller_function.clone(), - arity: step.caller_arity, + module: caller_key.0, + function: caller_key.1, + arity: caller_key.2, kind: step.caller_kind.clone(), start_line: step.caller_start_line, end_line: step.caller_end_line, @@ -53,7 +54,6 @@ fn build_reverse_trace_result( line: step.line, parent_index: None, }); - entry_index_map.insert(caller_key, entry_idx); } } } @@ -61,8 +61,6 @@ fn build_reverse_trace_result( // Process deeper levels (additional callers) for depth in 2..=max_depth as i64 { if let Some(depth_steps) = by_depth.get(&depth) { - let mut filter = crate::dedup::DeduplicationFilter::new(); - for step in depth_steps { let caller_key = ( step.caller_module.clone(), @@ -71,31 +69,35 @@ fn build_reverse_trace_result( depth, ); - // Find parent index (the callee at previous depth, which is what called this caller) - let parent_key = ( - step.callee_module.clone(), - step.callee_function.clone(), - step.callee_arity, - depth - 1, - ); - - let parent_index = entry_index_map.get(&parent_key).copied(); - - if filter.should_process(caller_key.clone()) && parent_index.is_some() { - let entry_idx = entries.len(); - entries.push(TraceEntry { - module: step.caller_module.clone(), - function: step.caller_function.clone(), - arity: step.caller_arity, - kind: step.caller_kind.clone(), - start_line: step.caller_start_line, - end_line: step.caller_end_line, - file: step.file.clone(), - depth, - line: step.line, - parent_index, - }); - entry_index_map.insert(caller_key, entry_idx); + // Check if we already have this caller at this depth using HashMap + if !entry_index_map.contains_key(&caller_key) { + // Find parent index using HashMap (O(1) lookup) + let parent_key = ( + step.callee_module.clone(), + step.callee_function.clone(), + step.callee_arity, + depth - 1, + ); + let parent_index = entry_index_map.get(&parent_key).copied(); + + if parent_index.is_some() { + let entry_idx = entries.len(); + // Insert into HashMap before pushing (reuse caller_key) + entry_index_map.insert(caller_key.clone(), entry_idx); + + entries.push(TraceEntry { + module: caller_key.0, + function: caller_key.1, + arity: caller_key.2, + kind: step.caller_kind.clone(), + start_line: step.caller_start_line, + end_line: step.caller_end_line, + file: step.file.clone(), + depth, + line: step.line, + parent_index, + }); + } } } } diff --git a/cli/src/commands/search/execute_tests.rs b/cli/src/commands/search/execute_tests.rs index 2bb5361..abf6391 100644 --- a/cli/src/commands/search/execute_tests.rs +++ b/cli/src/commands/search/execute_tests.rs @@ -21,11 +21,11 @@ mod tests { test_name: test_search_modules_all, fixture: populated_db, cmd: SearchCmd { - pattern: "MyApp".to_string(), + pattern: ".*MyApp.*".to_string(), // Use regex for substring matching kind: SearchKind::Modules, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 100, }, }, @@ -40,11 +40,11 @@ mod tests { test_name: test_search_functions_all, fixture: populated_db, cmd: SearchCmd { - pattern: "user".to_string(), + pattern: ".*user.*".to_string(), // Use regex for substring matching kind: SearchKind::Functions, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 100, }, }, @@ -59,11 +59,11 @@ mod tests { test_name: test_search_functions_specific, fixture: populated_db, cmd: SearchCmd { - pattern: "get".to_string(), + pattern: ".*get.*".to_string(), // Use regex for substring matching kind: SearchKind::Functions, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 100, }, }, @@ -113,6 +113,65 @@ mod tests { }, } + // Exact module match + crate::execute_test! { + test_name: test_search_modules_exact_match, + fixture: populated_db, + cmd: SearchCmd { + pattern: "MyApp.Accounts".to_string(), + kind: SearchKind::Modules, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 100, + }, + }, + assertions: |result| { + assert_eq!(result.modules.len(), 1); + assert_eq!(result.modules[0].name, "MyApp.Accounts"); + }, + } + + // Exact function match + crate::execute_test! { + test_name: test_search_functions_exact_match, + fixture: populated_db, + cmd: SearchCmd { + pattern: "get_user".to_string(), + kind: SearchKind::Functions, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 100, + }, + }, + assertions: |result| { + assert_eq!(result.total_functions, Some(2)); + // All functions should be exactly named get_user + for module in &result.function_modules { + for f in &module.functions { + assert_eq!(f.name, "get_user"); + } + } + }, + } + + // Exact match doesn't find partial matches + crate::execute_no_match_test! { + test_name: test_search_functions_exact_no_partial, + fixture: populated_db, + cmd: SearchCmd { + pattern: "user".to_string(), // Won't match get_user, list_users, etc. + kind: SearchKind::Functions, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 100, + }, + }, + empty_field: function_modules, + } + // ========================================================================= // No match / empty result tests // ========================================================================= @@ -171,11 +230,11 @@ mod tests { test_name: test_search_with_limit, fixture: populated_db, cmd: SearchCmd { - pattern: "user".to_string(), + pattern: ".*user.*".to_string(), // Use regex for substring matching kind: SearchKind::Functions, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 1, }, }, @@ -201,4 +260,69 @@ mod tests { }, }, } + + #[rstest] + fn test_search_modules_invalid_regex(populated_db: db::DbInstance) { + use crate::commands::Execute; + + let cmd = SearchCmd { + pattern: "[invalid".to_string(), // Unclosed bracket + kind: SearchKind::Modules, + common: CommonArgs { + project: "test_project".to_string(), + regex: true, + limit: 100, + }, + }; + + let result = cmd.execute(&populated_db); + assert!(result.is_err(), "Should reject invalid regex pattern"); + + let err = result.unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("Invalid regex pattern"), "Error should mention 'Invalid regex pattern': {}", msg); + assert!(msg.contains("[invalid"), "Error should show the pattern: {}", msg); + } + + #[rstest] + fn test_search_functions_invalid_regex(populated_db: db::DbInstance) { + use crate::commands::Execute; + + let cmd = SearchCmd { + pattern: "*invalid".to_string(), // Invalid repetition + kind: SearchKind::Functions, + common: CommonArgs { + project: "test_project".to_string(), + regex: true, + limit: 100, + }, + }; + + let result = cmd.execute(&populated_db); + assert!(result.is_err(), "Should reject invalid regex pattern"); + + let err = result.unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("Invalid regex pattern"), "Error should mention 'Invalid regex pattern': {}", msg); + assert!(msg.contains("*invalid"), "Error should show the pattern: {}", msg); + } + + #[rstest] + fn test_search_invalid_regex_non_regex_mode_works(populated_db: db::DbInstance) { + use crate::commands::Execute; + + // Even invalid regex patterns should work in non-regex mode (treated as literals) + let cmd = SearchCmd { + pattern: "[invalid".to_string(), + kind: SearchKind::Modules, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, // Not using regex mode + limit: 100, + }, + }; + + let result = cmd.execute(&populated_db); + assert!(result.is_ok(), "Should accept any pattern in non-regex mode: {:?}", result.err()); + } } diff --git a/cli/src/commands/struct_usage/execute_tests.rs b/cli/src/commands/struct_usage/execute_tests.rs index 9874cc1..9e7b3e3 100644 --- a/cli/src/commands/struct_usage/execute_tests.rs +++ b/cli/src/commands/struct_usage/execute_tests.rs @@ -24,12 +24,12 @@ mod tests { test_name: test_struct_usage_finds_user_type, fixture: populated_db, cmd: StructUsageCmd { - pattern: "User.t".to_string(), + pattern: ".*User\\.t.*".to_string(), // Use regex for substring matching module: None, by_module: false, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 100, }, }, @@ -49,12 +49,12 @@ mod tests { test_name: test_struct_usage_with_module_filter, fixture: populated_db, cmd: StructUsageCmd { - pattern: "User.t".to_string(), + pattern: ".*User\\.t.*".to_string(), // Use regex for substring matching module: Some("MyApp.Accounts".to_string()), by_module: false, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 100, }, }, @@ -80,12 +80,12 @@ mod tests { test_name: test_struct_usage_by_module, fixture: populated_db, cmd: StructUsageCmd { - pattern: "User.t".to_string(), + pattern: ".*User\\.t.*".to_string(), // Use regex for substring matching module: None, by_module: true, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 100, }, }, @@ -165,12 +165,12 @@ mod tests { test_name: test_struct_usage_with_limit, fixture: populated_db, cmd: StructUsageCmd { - pattern: "User.t".to_string(), + pattern: ".*User\\.t.*".to_string(), // Use regex for substring matching module: None, by_module: false, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 1, }, }, @@ -208,6 +208,57 @@ mod tests { }, } + // Exact type match - search for integer() in inputs + crate::execute_test! { + test_name: test_struct_usage_exact_match, + fixture: populated_db, + cmd: StructUsageCmd { + pattern: "integer()".to_string(), + module: None, + by_module: false, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 100, + }, + }, + assertions: |result| { + match result { + StructUsageOutput::Detailed(ref detail) => { + assert!(detail.total_items > 0, "Should find exact match for integer()"); + // Verify we found functions using integer() + assert!(detail.items.len() >= 1, "Should find integer() in at least one module"); + } + _ => panic!("Expected Detailed output"), + } + }, + } + + // Exact match doesn't find partial matches + crate::execute_test! { + test_name: test_struct_usage_exact_no_partial, + fixture: populated_db, + cmd: StructUsageCmd { + pattern: "integer".to_string(), // Won't match "integer()" - missing parens + module: None, + by_module: false, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 100, + }, + }, + assertions: |result| { + match result { + StructUsageOutput::Detailed(ref detail) => { + assert_eq!(detail.total_items, 0, "Exact match should not find partial matches"); + assert!(detail.items.is_empty()); + } + _ => panic!("Expected Detailed output"), + } + }, + } + // ========================================================================= // Error handling tests // ========================================================================= diff --git a/cli/src/commands/trace/execute.rs b/cli/src/commands/trace/execute.rs index a4a28d3..a86f34e 100644 --- a/cli/src/commands/trace/execute.rs +++ b/cli/src/commands/trace/execute.rs @@ -84,34 +84,34 @@ fn build_trace_result( for depth in 2..=max_depth as i64 { if let Some(depth_calls) = by_depth.remove(&depth) { for call in depth_calls { - // Check if we already have this callee at this depth - let existing = entries.iter().position(|e| { - e.depth == depth - && e.module == call.callee.module.as_ref() - && e.function == call.callee.name.as_ref() - && e.arity == call.callee.arity - }); - - if existing.is_none() { - // Find parent index using references (no cloning) - let parent_index = entries.iter().position(|e| { - e.depth == depth - 1 - && e.module == call.caller.module.as_ref() - && e.function == call.caller.name.as_ref() - && e.arity == call.caller.arity - }); + // Check if we already have this callee at this depth using HashMap + let callee_key = ( + call.callee.module.to_string(), + call.callee.name.to_string(), + call.callee.arity, + depth, + ); + + if !entry_index_map.contains_key(&callee_key) { + // Find parent index using HashMap (O(1) lookup) + let parent_key = ( + call.caller.module.to_string(), + call.caller.name.to_string(), + call.caller.arity, + depth - 1, + ); + let parent_index = entry_index_map.get(&parent_key).copied(); if parent_index.is_some() { let entry_idx = entries.len(); + // Insert into HashMap before pushing (reuse callee_key) + entry_index_map.insert(callee_key.clone(), entry_idx); + // Convert from Rc to String for storage - let module = call.callee.module.to_string(); - let function = call.callee.name.to_string(); - let arity = call.callee.arity; - entry_index_map.insert((module.clone(), function.clone(), arity, depth), entry_idx); entries.push(TraceEntry { - module, - function, - arity, + module: callee_key.0, + function: callee_key.1, + arity: callee_key.2, kind: call.callee.kind.as_deref().unwrap_or("").to_string(), start_line: call.callee.start_line.unwrap_or(0), end_line: call.callee.end_line.unwrap_or(0), diff --git a/cli/src/commands/unused/execute_tests.rs b/cli/src/commands/unused/execute_tests.rs index 6a04534..c3211cd 100644 --- a/cli/src/commands/unused/execute_tests.rs +++ b/cli/src/commands/unused/execute_tests.rs @@ -47,13 +47,13 @@ mod tests { test_name: test_unused_with_module_filter, fixture: populated_db, cmd: UnusedCmd { - module: Some("Accounts".to_string()), + module: Some(".*Accounts.*".to_string()), // Use regex for substring matching private_only: false, public_only: false, exclude_generated: false, common: CommonArgs { project: "test_project".to_string(), - regex: false, + regex: true, limit: 100, }, }, @@ -82,6 +82,48 @@ mod tests { }, } + // Exact module match - MyApp.Accounts has 2 uncalled functions + crate::execute_test! { + test_name: test_unused_exact_module_match, + fixture: populated_db, + cmd: UnusedCmd { + module: Some("MyApp.Accounts".to_string()), + private_only: false, + public_only: false, + exclude_generated: false, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 100, + }, + }, + assertions: |result| { + assert_eq!(result.total_items, 2); + // Verify all results are from MyApp.Accounts + for module_group in &result.items { + assert_eq!(module_group.name, "MyApp.Accounts"); + } + }, + } + + // Exact match doesn't find partial matches + crate::execute_no_match_test! { + test_name: test_unused_exact_no_partial, + fixture: populated_db, + cmd: UnusedCmd { + module: Some("Accounts".to_string()), // Won't match "MyApp.Accounts" + private_only: false, + public_only: false, + exclude_generated: false, + common: CommonArgs { + project: "test_project".to_string(), + regex: false, + limit: 100, + }, + }, + empty_field: items, + } + // ========================================================================= // No match / empty result tests // ========================================================================= diff --git a/cli/src/dedup.rs b/cli/src/dedup.rs index 6fd3e55..160ea54 100644 --- a/cli/src/dedup.rs +++ b/cli/src/dedup.rs @@ -1,8 +1,8 @@ //! Deduplication utilities for reducing code duplication across commands. //! -//! This module provides reusable patterns for deduplicating collections using different strategies: -//! - Strategy A: HashSet retain pattern (deduplicate_retain) - for in-place deduplication after sorting -//! - Strategy B: HashSet prevention pattern (DeduplicationFilter) - for preventing duplicates during collection +//! This module provides reusable patterns for deduplicating collections: +//! - HashSet retain pattern (deduplicate_retain) - for in-place deduplication after sorting +//! - Combined sort and deduplicate operation (sort_and_deduplicate) use std::collections::HashSet; use std::hash::Hash; @@ -69,44 +69,3 @@ where items.sort_by(sort_cmp); deduplicate_retain(items, dedup_key); } - -/// Strategy B: HashSet prevention pattern - check before adding -/// -/// Use this when collecting items and you want to prevent duplicates from being added -/// in the first place, without needing to sort or post-process. -/// -/// # Example -/// ```ignore -/// let mut filter = DeduplicationFilter::new(); -/// for entry in entries { -/// if filter.should_process(entry_key) { -/// // Add entry to result -/// } -/// } -/// ``` -#[derive(Debug)] -pub struct DeduplicationFilter { - processed: HashSet, -} - -impl DeduplicationFilter { - /// Create a new empty deduplication filter - pub fn new() -> Self { - Self { - processed: HashSet::new(), - } - } - - /// Check if a key should be processed (inserted into the set) - /// - /// Returns true if the key is new and was successfully inserted, false if it was already present. - pub fn should_process(&mut self, key: K) -> bool { - self.processed.insert(key) - } -} - -impl Default for DeduplicationFilter { - fn default() -> Self { - Self::new() - } -} diff --git a/cli/src/utils.rs b/cli/src/utils.rs index b156fef..a699e86 100644 --- a/cli/src/utils.rs +++ b/cli/src/utils.rs @@ -1,5 +1,6 @@ //! Utility functions for code search CLI output and presentation. +use std::borrow::Cow; use std::collections::BTreeMap; use regex::Regex; use db::types::{ModuleGroup, Call}; @@ -215,15 +216,15 @@ where /// * `definition` - The raw type definition string from the database /// /// # Returns -/// The formatted type definition string -pub fn format_type_definition(definition: &str) -> String { +/// The formatted type definition (borrowed if unchanged, owned if formatted) +pub fn format_type_definition(definition: &str) -> Cow { // Check if this is a struct type definition if let Some(formatted) = try_format_struct_type(definition) { - return formatted; + return Cow::Owned(formatted); } // Return as-is if no transformation needed - definition.to_string() + Cow::Borrowed(definition) } /// Attempts to format a struct type definition. diff --git a/db/src/db.rs b/db/src/db.rs index 354acf3..72e9f59 100644 --- a/db/src/db.rs +++ b/db/src/db.rs @@ -50,7 +50,7 @@ pub enum DbError { MissingColumn { name: String }, } -pub type Params = BTreeMap; +pub type Params = BTreeMap<&'static str, DataValue>; pub fn open_db(path: &Path) -> Result> { DbInstance::new("sqlite", path, "").map_err(|e| { @@ -75,7 +75,13 @@ pub fn run_query( script: &str, params: Params, ) -> Result> { - db.run_script(script, params, ScriptMutability::Mutable) + // Convert &'static str keys to String for CozoDB + let params_owned: BTreeMap = params + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect(); + + db.run_script(script, params_owned, ScriptMutability::Mutable) .map_err(|e| { Box::new(DbError::QueryFailed { message: format!("{:?}", e), diff --git a/db/src/lib.rs b/db/src/lib.rs index 491f1e4..1b3346c 100644 --- a/db/src/lib.rs +++ b/db/src/lib.rs @@ -24,4 +24,4 @@ pub use types::{ TraceDirection, SharedStr }; -pub use query_builders::{ConditionBuilder, OptionalConditionBuilder}; +pub use query_builders::{ConditionBuilder, OptionalConditionBuilder, validate_regex_pattern, validate_regex_patterns}; diff --git a/db/src/queries/accepts.rs b/db/src/queries/accepts.rs index 4ffcf01..a321c57 100644 --- a/db/src/queries/accepts.rs +++ b/db/src/queries/accepts.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum AcceptsError { @@ -32,27 +33,22 @@ pub fn find_accepts( module_pattern: Option<&str>, limit: u32, ) -> Result, Box> { - // Build inputs string filter - let match_fn = if use_regex { - "regex_matches(inputs_string, $pattern)" - } else { - "str_includes(inputs_string, $pattern)" - }; + validate_regex_patterns(use_regex, &[Some(pattern), module_pattern])?; - // Build module filter - let module_filter = match module_pattern { - Some(_) if use_regex => "regex_matches(module, $module_pattern)", - Some(_) => "str_includes(module, $module_pattern)", - None => "true", - }; + // Build conditions using query builders + let pattern_cond = ConditionBuilder::new("inputs_string", "pattern").build(use_regex); + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); let script = format!( r#" ?[project, module, name, arity, inputs_string, return_string, line] := *specs{{project, module, name, arity, inputs_string, return_string, line}}, project == $project, - {match_fn}, - {module_filter} + {pattern_cond} + {module_cond} :order module, name, arity :limit {limit} @@ -60,12 +56,12 @@ pub fn find_accepts( ); let mut params = Params::new(); - params.insert("pattern".to_string(), DataValue::Str(pattern.into())); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("pattern", DataValue::Str(pattern.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(mod_pat) = module_pattern { params.insert( - "module_pattern".to_string(), + "module_pattern", DataValue::Str(mod_pat.into()), ); } diff --git a/db/src/queries/calls.rs b/db/src/queries/calls.rs index bb668cf..c8642b2 100644 --- a/db/src/queries/calls.rs +++ b/db/src/queries/calls.rs @@ -11,7 +11,7 @@ use thiserror::Error; use crate::db::{extract_call_from_row, run_query, CallRowLayout, Params}; use crate::types::Call; -use crate::query_builders::{ConditionBuilder, OptionalConditionBuilder}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum CallsError { @@ -64,6 +64,8 @@ pub fn find_calls( use_regex: bool, limit: u32, ) -> Result, Box> { + validate_regex_patterns(use_regex, &[Some(module_pattern), function_pattern])?; + let (module_field, function_field, arity_field) = direction.filter_fields(); let order_clause = direction.order_clause(); @@ -103,19 +105,19 @@ pub fn find_calls( let mut params = Params::new(); params.insert( - "module_pattern".to_string(), + "module_pattern", DataValue::Str(module_pattern.into()), ); if let Some(fn_pat) = function_pattern { params.insert( - "function_pattern".to_string(), + "function_pattern", DataValue::Str(fn_pat.into()), ); } if let Some(a) = arity { - params.insert("arity".to_string(), DataValue::from(a)); + params.insert("arity", DataValue::from(a)); } - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| CallsError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/clusters.rs b/db/src/queries/clusters.rs index 892503c..186add8 100644 --- a/db/src/queries/clusters.rs +++ b/db/src/queries/clusters.rs @@ -29,7 +29,7 @@ pub fn get_module_calls(db: &cozo::DbInstance, project: &str) -> Result Result, Box> { - // Build optional module filter - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); // Build optional generated filter let generated_filter = if exclude_generated { @@ -59,7 +61,7 @@ pub fn find_complexity_metrics( complexity >= $min_complexity, max_nesting_depth >= $min_depth, lines = end_line - start_line + 1 - {module_filter} + {module_cond} {generated_filter} :order -complexity, module, name @@ -68,11 +70,11 @@ pub fn find_complexity_metrics( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); - params.insert("min_complexity".to_string(), DataValue::from(min_complexity)); - params.insert("min_depth".to_string(), DataValue::from(min_depth)); + params.insert("project", DataValue::Str(project.into())); + params.insert("min_complexity", DataValue::from(min_complexity)); + params.insert("min_depth", DataValue::from(min_depth)); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| ComplexityError::QueryFailed { diff --git a/db/src/queries/cycles.rs b/db/src/queries/cycles.rs index 62568c8..e49d5ad 100644 --- a/db/src/queries/cycles.rs +++ b/db/src/queries/cycles.rs @@ -53,7 +53,7 @@ pub fn find_cycle_edges( "#.to_string(); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params)?; diff --git a/db/src/queries/dependencies.rs b/db/src/queries/dependencies.rs index 7266197..246ff07 100644 --- a/db/src/queries/dependencies.rs +++ b/db/src/queries/dependencies.rs @@ -94,10 +94,10 @@ pub fn find_dependencies( let mut params = Params::new(); params.insert( - "module_pattern".to_string(), + "module_pattern", DataValue::Str(module_pattern.into()), ); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| DependencyError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/duplicates.rs b/db/src/queries/duplicates.rs index 7d3adbd..67c73bf 100644 --- a/db/src/queries/duplicates.rs +++ b/db/src/queries/duplicates.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum DuplicatesError { @@ -31,15 +32,16 @@ pub fn find_duplicates( use_exact: bool, exclude_generated: bool, ) -> Result, Box> { + validate_regex_patterns(use_regex, &[module_pattern])?; + // Choose hash field based on exact flag let hash_field = if use_exact { "source_sha" } else { "ast_sha" }; - // Build optional module filter - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); // Build optional generated filter let generated_filter = if exclude_generated { @@ -64,7 +66,7 @@ pub fn find_duplicates( hash_counts[{hash_field}, cnt], cnt > 1, project == $project - {module_filter} + {module_cond} {generated_filter} :order {hash_field}, module, name, arity @@ -72,9 +74,9 @@ pub fn find_duplicates( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| DuplicatesError::QueryFailed { diff --git a/db/src/queries/file.rs b/db/src/queries/file.rs index 30f7dff..d26ef51 100644 --- a/db/src/queries/file.rs +++ b/db/src/queries/file.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder}; #[derive(Error, Debug)] pub enum FileError { @@ -37,12 +38,10 @@ pub fn find_functions_in_module( use_regex: bool, limit: u32, ) -> Result, Box> { - // Build module filter - let module_filter = if use_regex { - "regex_matches(module, $module_pattern)" - } else { - "module == $module_pattern" - }; + validate_regex_patterns(use_regex, &[Some(module_pattern)])?; + + // Build module filter using query builder + let module_filter = ConditionBuilder::new("module", "module_pattern").build(use_regex); // Query to find all functions in matching modules let script = format!( @@ -58,8 +57,8 @@ pub fn find_functions_in_module( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); - params.insert("module_pattern".to_string(), DataValue::Str(module_pattern.into())); + params.insert("project", DataValue::Str(project.into())); + params.insert("module_pattern", DataValue::Str(module_pattern.into())); let rows = run_query(db, &script, params).map_err(|e| FileError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/function.rs b/db/src/queries/function.rs index 2ff9f94..3cdae44 100644 --- a/db/src/queries/function.rs +++ b/db/src/queries/function.rs @@ -5,7 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, extract_string_or, run_query, Params}; -use crate::query_builders::{ConditionBuilder, OptionalConditionBuilder}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum FunctionError { @@ -33,6 +33,8 @@ pub fn find_functions( use_regex: bool, limit: u32, ) -> Result, Box> { + validate_regex_patterns(use_regex, &[Some(module_pattern), Some(function_pattern)])?; + // Build query conditions using helpers let module_cond = ConditionBuilder::new("module", "module_pattern").build(use_regex); let function_cond = ConditionBuilder::new("name", "function_pattern") @@ -57,12 +59,12 @@ pub fn find_functions( ); let mut params = Params::new(); - params.insert("module_pattern".to_string(), DataValue::Str(module_pattern.into())); - params.insert("function_pattern".to_string(), DataValue::Str(function_pattern.into())); + params.insert("module_pattern", DataValue::Str(module_pattern.into())); + params.insert("function_pattern", DataValue::Str(function_pattern.into())); if let Some(a) = arity { - params.insert("arity".to_string(), DataValue::from(a)); + params.insert("arity", DataValue::from(a)); } - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| FunctionError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/hotspots.rs b/db/src/queries/hotspots.rs index 71c73ac..3993ee1 100644 --- a/db/src/queries/hotspots.rs +++ b/db/src/queries/hotspots.rs @@ -6,6 +6,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_f64, extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, OptionalConditionBuilder}; /// What type of hotspots to find #[derive(Debug, Clone, Copy, Default, ValueEnum)] @@ -45,11 +46,13 @@ pub fn get_module_loc( module_pattern: Option<&str>, use_regex: bool, ) -> Result, Box> { - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); let script = format!( r#" @@ -58,7 +61,7 @@ pub fn get_module_loc( *function_locations{{project, module, start_line, end_line}}, project == $project, lines = end_line - start_line + 1 - {module_filter} + {module_cond} ?[module, loc] := module_loc[module, loc] @@ -68,9 +71,9 @@ pub fn get_module_loc( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| HotspotsError::QueryFailed { @@ -96,19 +99,20 @@ pub fn get_function_counts( module_pattern: Option<&str>, use_regex: bool, ) -> Result, Box> { - // Build optional module filter - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); let script = format!( r#" func_counts[module, count(name)] := *function_locations{{project, module, name}}, project == $project - {module_filter} + {module_cond} ?[module, func_count] := func_counts[module, func_count] @@ -118,9 +122,9 @@ pub fn get_function_counts( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| HotspotsError::QueryFailed { @@ -139,6 +143,114 @@ pub fn get_function_counts( Ok(counts) } +/// Get module-level connectivity (aggregated incoming/outgoing calls) +/// +/// Returns a HashMap of module name -> (incoming, outgoing) call counts. +/// This aggregates function-level hotspots to module level at the database layer, +/// avoiding the need to fetch all function hotspots. +pub fn get_module_connectivity( + db: &cozo::DbInstance, + project: &str, + module_pattern: Option<&str>, + use_regex: bool, +) -> Result, Box> { + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); + + // Aggregate incoming/outgoing calls at module level + let script = format!( + r#" + # Get canonical function names (no generated functions) + canonical[module, function] := + *calls{{project, callee_module, callee_function}}, + *function_locations{{project, module: callee_module, name: callee_function, generated_by}}, + project == $project, + module = callee_module, + function = callee_function, + generated_by == "" + + # Distinct outgoing calls per function + distinct_outgoing[caller_module, canonical_name, callee_module, callee_function] := + *calls{{project, caller_module, caller_function, callee_module, callee_function}}, + canonical[caller_module, canonical_name], + project == $project, + (caller_function == canonical_name or starts_with(caller_function, concat(canonical_name, "/"))) + + # Count outgoing calls per function + outgoing_counts[module, function, count(callee_function)] := + distinct_outgoing[module, function, callee_module, callee_function] + + # Distinct incoming calls per function + distinct_incoming[callee_module, callee_function, caller_module, caller_function] := + *calls{{project, caller_module, caller_function, callee_module, callee_function}}, + canonical[callee_module, callee_function], + project == $project + + # Count incoming calls per function + incoming_counts[module, function, count(caller_function)] := + distinct_incoming[module, function, caller_module, caller_function] + + # Function stats with defaults for missing counts + # Functions with both counts + func_stats[module, function, incoming, outgoing] := + canonical[module, function], + incoming_counts[module, function, incoming], + outgoing_counts[module, function, outgoing] + + # Functions with only incoming (no outgoing) + func_stats[module, function, incoming, outgoing] := + canonical[module, function], + incoming_counts[module, function, incoming], + not outgoing_counts[module, function, _], + outgoing = 0 + + # Functions with only outgoing (no incoming) + func_stats[module, function, incoming, outgoing] := + canonical[module, function], + not incoming_counts[module, function, _], + outgoing_counts[module, function, outgoing], + incoming = 0 + + # Aggregate to module level + module_connectivity[module, sum(incoming), sum(outgoing)] := + func_stats[module, function, incoming, outgoing] + {module_cond} + + ?[module, incoming, outgoing] := + module_connectivity[module, incoming, outgoing] + + :order -incoming + "#, + ); + + let mut params = Params::new(); + params.insert("project", DataValue::Str(project.into())); + if let Some(pattern) = module_pattern { + params.insert("module_pattern", DataValue::Str(pattern.into())); + } + + let rows = run_query(db, &script, params).map_err(|e| HotspotsError::QueryFailed { + message: e.to_string(), + })?; + + let mut connectivity = std::collections::HashMap::new(); + for row in rows.rows { + if row.len() >= 3 + && let Some(module) = extract_string(&row[0]) { + let incoming = extract_i64(&row[1], 0); + let outgoing = extract_i64(&row[2], 0); + connectivity.insert(module, (incoming, outgoing)); + } + } + + Ok(connectivity) +} + pub fn find_hotspots( db: &cozo::DbInstance, kind: HotspotKind, @@ -149,12 +261,13 @@ pub fn find_hotspots( exclude_generated: bool, require_outgoing: bool, ) -> Result, Box> { - // Build optional module filter - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); // Build optional generated filter let generated_filter = if exclude_generated { @@ -227,7 +340,7 @@ pub fn find_hotspots( outgoing_counts[module, function, outgoing], total = incoming + outgoing, ratio = if(outgoing == 0, 9999.0, incoming / outgoing) - {module_filter} + {module_cond} {outgoing_filter} # Functions with only incoming (no outgoing) - leaf nodes @@ -238,7 +351,7 @@ pub fn find_hotspots( outgoing = 0, total = incoming, ratio = 9999.0 - {module_filter} + {module_cond} {outgoing_filter} # Functions with only outgoing (no incoming) @@ -248,7 +361,7 @@ pub fn find_hotspots( incoming = 0, total = outgoing, ratio = 0.0 - {module_filter} + {module_cond} :order -{order_by}, module, function :limit {limit} @@ -256,9 +369,9 @@ pub fn find_hotspots( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| HotspotsError::QueryFailed { @@ -288,3 +401,256 @@ pub fn find_hotspots( Ok(results) } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::{fixture, rstest}; + + #[fixture] + fn populated_db() -> cozo::DbInstance { + crate::test_utils::call_graph_db("default") + } + + #[rstest] + fn test_get_module_connectivity_returns_results(populated_db: cozo::DbInstance) { + let result = get_module_connectivity( + &populated_db, + "default", + None, + false, + ); + + if let Err(ref e) = result { + eprintln!("Error: {}", e); + } + assert!(result.is_ok()); + let connectivity = result.unwrap(); + assert!(!connectivity.is_empty()); + } + + #[rstest] + fn test_get_module_connectivity_has_valid_counts(populated_db: cozo::DbInstance) { + let connectivity = get_module_connectivity( + &populated_db, + "default", + None, + false, + ).unwrap(); + + // All modules should have non-negative counts + for (module, (incoming, outgoing)) in &connectivity { + assert!(*incoming >= 0, "Module {} has negative incoming: {}", module, incoming); + assert!(*outgoing >= 0, "Module {} has negative outgoing: {}", module, outgoing); + } + } + + #[rstest] + fn test_get_module_connectivity_with_module_filter(populated_db: cozo::DbInstance) { + let connectivity = get_module_connectivity( + &populated_db, + "default", + Some("Accounts"), + false, + ).unwrap(); + + // All modules should contain "Accounts" + for module in connectivity.keys() { + assert!(module.contains("Accounts"), "Module {} doesn't contain 'Accounts'", module); + } + } + + #[rstest] + fn test_get_module_connectivity_aggregates_correctly(populated_db: cozo::DbInstance) { + // Get module-level connectivity + let module_conn = get_module_connectivity( + &populated_db, + "default", + None, + false, + ).unwrap(); + + // Get function-level hotspots + let function_hotspots = find_hotspots( + &populated_db, + HotspotKind::Total, + None, + "default", + false, + u32::MAX, + false, + false, + ).unwrap(); + + // Manually aggregate function hotspots by module + let mut manual_agg: std::collections::HashMap = std::collections::HashMap::new(); + for hotspot in function_hotspots { + let entry = manual_agg.entry(hotspot.module).or_insert((0, 0)); + entry.0 += hotspot.incoming; + entry.1 += hotspot.outgoing; + } + + // The two approaches should produce the same results + assert_eq!(module_conn.len(), manual_agg.len(), "Different number of modules"); + + for (module, (conn_in, conn_out)) in &module_conn { + let (manual_in, manual_out) = manual_agg.get(module) + .expect(&format!("Module {} not found in manual aggregation", module)); + assert_eq!(conn_in, manual_in, "Module {} has different incoming: {} vs {}", module, conn_in, manual_in); + assert_eq!(conn_out, manual_out, "Module {} has different outgoing: {} vs {}", module, conn_out, manual_out); + } + } + + #[rstest] + fn test_get_module_loc_returns_results(populated_db: cozo::DbInstance) { + let result = get_module_loc( + &populated_db, + "default", + None, + false, + ); + + assert!(result.is_ok()); + let loc_map = result.unwrap(); + assert!(!loc_map.is_empty()); + } + + #[rstest] + fn test_get_function_counts_returns_results(populated_db: cozo::DbInstance) { + let result = get_function_counts( + &populated_db, + "default", + None, + false, + ); + + assert!(result.is_ok()); + let counts = result.unwrap(); + assert!(!counts.is_empty()); + } + + #[rstest] + fn test_module_connectivity_returns_fewer_rows(populated_db: cozo::DbInstance) { + // Get module-level connectivity (NEW approach) + let module_conn = get_module_connectivity( + &populated_db, + "default", + None, + false, + ).unwrap(); + + // Get function-level hotspots (OLD approach) + let function_hotspots = find_hotspots( + &populated_db, + HotspotKind::Total, + None, + "default", + false, + u32::MAX, + false, + false, + ).unwrap(); + + // The new approach should return FAR fewer rows + println!("Module connectivity rows: {}", module_conn.len()); + println!("Function hotspots rows: {}", function_hotspots.len()); + + // For any non-trivial codebase, there are more functions than modules + assert!( + module_conn.len() <= function_hotspots.len(), + "Module connectivity ({} rows) should return same or fewer rows than function hotspots ({} rows)", + module_conn.len(), + function_hotspots.len() + ); + + // Calculate reduction percentage + if function_hotspots.len() > 0 { + let reduction = 100.0 * (1.0 - (module_conn.len() as f64 / function_hotspots.len() as f64)); + println!("Row reduction: {:.1}%", reduction); + + // In a typical codebase, we expect significant reduction + // (unless every module has exactly 1 function, which is unlikely) + } + } + + #[rstest] + fn test_get_module_connectivity_nonexistent_project(populated_db: cozo::DbInstance) { + let connectivity = get_module_connectivity( + &populated_db, + "nonexistent_project", + None, + false, + ).unwrap(); + + // Should return empty for non-existent project + assert!(connectivity.is_empty()); + } + + #[rstest] + fn test_get_module_connectivity_nonexistent_module(populated_db: cozo::DbInstance) { + let connectivity = get_module_connectivity( + &populated_db, + "default", + Some("NonExistentModule"), + false, + ).unwrap(); + + // Should return empty when module pattern matches nothing + assert!(connectivity.is_empty()); + } + + #[rstest] + fn test_get_module_connectivity_with_regex(populated_db: cozo::DbInstance) { + let connectivity = get_module_connectivity( + &populated_db, + "default", + Some(".*Accounts.*"), + true, // use regex + ).unwrap(); + + // Should return results matching the regex + for module in connectivity.keys() { + assert!(module.contains("Accounts"), "Module {} doesn't match regex pattern", module); + } + } + + #[rstest] + fn test_get_module_loc_nonexistent_project(populated_db: cozo::DbInstance) { + let loc_map = get_module_loc( + &populated_db, + "nonexistent_project", + None, + false, + ).unwrap(); + + assert!(loc_map.is_empty()); + } + + #[rstest] + fn test_get_function_counts_nonexistent_project(populated_db: cozo::DbInstance) { + let counts = get_function_counts( + &populated_db, + "nonexistent_project", + None, + false, + ).unwrap(); + + assert!(counts.is_empty()); + } + + #[rstest] + fn test_get_module_connectivity_all_values_positive(populated_db: cozo::DbInstance) { + let connectivity = get_module_connectivity( + &populated_db, + "default", + None, + false, + ).unwrap(); + + // Verify all counts are non-negative (sanity check) + for (module, (incoming, outgoing)) in &connectivity { + assert!(*incoming >= 0, "Module {} has negative incoming", module); + assert!(*outgoing >= 0, "Module {} has negative outgoing", module); + } + } +} diff --git a/db/src/queries/import.rs b/db/src/queries/import.rs index dcc7ce0..afa04a8 100644 --- a/db/src/queries/import.rs +++ b/db/src/queries/import.rs @@ -91,7 +91,7 @@ pub fn clear_project_data(db: &DbInstance, project: &str) -> Result<(), Box Result, Box> { - // Build optional module filter - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); // Build optional generated filter let generated_filter = if include_generated { @@ -55,7 +57,7 @@ pub fn find_large_functions( project == $project, lines = end_line - start_line + 1, lines >= $min_lines - {module_filter} + {module_cond} {generated_filter} :order -lines, module, name @@ -64,10 +66,10 @@ pub fn find_large_functions( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); - params.insert("min_lines".to_string(), DataValue::from(min_lines)); + params.insert("project", DataValue::Str(project.into())); + params.insert("min_lines", DataValue::from(min_lines)); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| LargeFunctionsError::QueryFailed { diff --git a/db/src/queries/location.rs b/db/src/queries/location.rs index 16d3c08..627e20b 100644 --- a/db/src/queries/location.rs +++ b/db/src/queries/location.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, extract_string_or, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum LocationError { @@ -37,18 +38,14 @@ pub fn find_locations( use_regex: bool, limit: u32, ) -> Result, Box> { - // Build the query based on whether we're using regex or exact match - let fn_cond = if use_regex { - "regex_matches(name, $function_pattern)".to_string() - } else { - "name == $function_pattern".to_string() - }; + validate_regex_patterns(use_regex, &[module_pattern, Some(function_pattern)])?; - let module_cond = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", module == $module_pattern".to_string(), - None => String::new(), - }; + // Build conditions using query builders + let fn_cond = ConditionBuilder::new("name", "function_pattern").build(use_regex); + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); let arity_cond = if arity.is_some() { ", arity == $arity" @@ -72,14 +69,14 @@ pub fn find_locations( ); let mut params = Params::new(); - params.insert("function_pattern".to_string(), DataValue::Str(function_pattern.into())); + params.insert("function_pattern", DataValue::Str(function_pattern.into())); if let Some(mod_pat) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(mod_pat.into())); + params.insert("module_pattern", DataValue::Str(mod_pat.into())); } if let Some(a) = arity { - params.insert("arity".to_string(), DataValue::Num(Num::Int(a))); + params.insert("arity", DataValue::Num(Num::Int(a))); } - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| LocationError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/many_clauses.rs b/db/src/queries/many_clauses.rs index 10408da..498c654 100644 --- a/db/src/queries/many_clauses.rs +++ b/db/src/queries/many_clauses.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum ManyClausesError { @@ -34,12 +35,13 @@ pub fn find_many_clauses( include_generated: bool, limit: u32, ) -> Result, Box> { - // Build optional module filter - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); // Build optional generated filter let generated_filter = if include_generated { @@ -53,7 +55,7 @@ pub fn find_many_clauses( clause_counts[module, name, arity, count(line), min(start_line), max(end_line), file, generated_by] := *function_locations{{project, module, name, arity, line, start_line, end_line, file, generated_by}}, project == $project - {module_filter} + {module_cond} {generated_filter} ?[module, name, arity, clauses, first_line, last_line, file, generated_by] := @@ -66,10 +68,10 @@ pub fn find_many_clauses( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); - params.insert("min_clauses".to_string(), DataValue::from(min_clauses)); + params.insert("project", DataValue::Str(project.into())); + params.insert("min_clauses", DataValue::from(min_clauses)); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| ManyClausesError::QueryFailed { diff --git a/db/src/queries/path.rs b/db/src/queries/path.rs index b602991..aca289e 100644 --- a/db/src/queries/path.rs +++ b/db/src/queries/path.rs @@ -103,17 +103,17 @@ pub fn find_paths( ); let mut params = Params::new(); - params.insert("from_module".to_string(), DataValue::Str(from_module.into())); - params.insert("from_function".to_string(), DataValue::Str(from_function.into())); - params.insert("to_module".to_string(), DataValue::Str(to_module.into())); - params.insert("to_function".to_string(), DataValue::Str(to_function.into())); + params.insert("from_module", DataValue::Str(from_module.into())); + params.insert("from_function", DataValue::Str(from_function.into())); + params.insert("to_module", DataValue::Str(to_module.into())); + params.insert("to_function", DataValue::Str(to_function.into())); if let Some(a) = from_arity { - params.insert("from_arity".to_string(), DataValue::from(a)); + params.insert("from_arity", DataValue::from(a)); } if let Some(a) = to_arity { - params.insert("to_arity".to_string(), DataValue::from(a)); + params.insert("to_arity", DataValue::from(a)); } - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| PathError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/returns.rs b/db/src/queries/returns.rs index e4248c7..83324a6 100644 --- a/db/src/queries/returns.rs +++ b/db/src/queries/returns.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum ReturnsError { @@ -31,27 +32,22 @@ pub fn find_returns( module_pattern: Option<&str>, limit: u32, ) -> Result, Box> { - // Build return string filter - let match_fn = if use_regex { - "regex_matches(return_string, $pattern)" - } else { - "str_includes(return_string, $pattern)" - }; + validate_regex_patterns(use_regex, &[Some(pattern), module_pattern])?; - // Build module filter - let module_filter = match module_pattern { - Some(_) if use_regex => "regex_matches(module, $module_pattern)", - Some(_) => "str_includes(module, $module_pattern)", - None => "true", - }; + // Build conditions using query builders + let pattern_cond = ConditionBuilder::new("return_string", "pattern").build(use_regex); + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); let script = format!( r#" ?[project, module, name, arity, return_string, line] := *specs{{project, module, name, arity, return_string, line}}, project == $project, - {match_fn}, - {module_filter} + {pattern_cond} + {module_cond} :order module, name, arity :limit {limit} @@ -59,12 +55,12 @@ pub fn find_returns( ); let mut params = Params::new(); - params.insert("pattern".to_string(), DataValue::Str(pattern.into())); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("pattern", DataValue::Str(pattern.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(mod_pat) = module_pattern { params.insert( - "module_pattern".to_string(), + "module_pattern", DataValue::Str(mod_pat.into()), ); } diff --git a/db/src/queries/reverse_trace.rs b/db/src/queries/reverse_trace.rs index 1c65ea4..287edf6 100644 --- a/db/src/queries/reverse_trace.rs +++ b/db/src/queries/reverse_trace.rs @@ -92,12 +92,12 @@ pub fn reverse_trace_calls( ); let mut params = Params::new(); - params.insert("module_pattern".to_string(), DataValue::Str(module_pattern.into())); - params.insert("function_pattern".to_string(), DataValue::Str(function_pattern.into())); + params.insert("module_pattern", DataValue::Str(module_pattern.into())); + params.insert("function_pattern", DataValue::Str(function_pattern.into())); if let Some(a) = arity { - params.insert("arity".to_string(), DataValue::from(a)); + params.insert("arity", DataValue::from(a)); } - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| ReverseTraceError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/search.rs b/db/src/queries/search.rs index 13f12e0..e38b816 100644 --- a/db/src/queries/search.rs +++ b/db/src/queries/search.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, extract_string_or, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder}; #[derive(Error, Debug)] pub enum SearchError { @@ -37,20 +38,22 @@ pub fn search_modules( limit: u32, use_regex: bool, ) -> Result, Box> { - let match_fn = if use_regex { "regex_matches" } else { "str_includes" }; + validate_regex_patterns(use_regex, &[Some(pattern)])?; + + let match_cond = ConditionBuilder::new("name", "pattern").build(use_regex); let script = format!( r#" ?[project, name, source] := *modules{{project, name, source}}, project = $project, - {match_fn}(name, $pattern) + {match_cond} :limit {limit} :order name "#, ); let mut params = Params::new(); - params.insert("pattern".to_string(), DataValue::Str(pattern.into())); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("pattern", DataValue::Str(pattern.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| SearchError::QueryFailed { message: e.to_string(), @@ -76,20 +79,22 @@ pub fn search_functions( limit: u32, use_regex: bool, ) -> Result, Box> { - let match_fn = if use_regex { "regex_matches" } else { "str_includes" }; + validate_regex_patterns(use_regex, &[Some(pattern)])?; + + let match_cond = ConditionBuilder::new("name", "pattern").build(use_regex); let script = format!( r#" ?[project, module, name, arity, return_type] := *functions{{project, module, name, arity, return_type}}, project = $project, - {match_fn}(name, $pattern) + {match_cond} :limit {limit} :order module, name, arity "#, ); let mut params = Params::new(); - params.insert("pattern".to_string(), DataValue::Str(pattern.into())); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("pattern", DataValue::Str(pattern.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| SearchError::QueryFailed { message: e.to_string(), @@ -115,3 +120,80 @@ pub fn search_functions( Ok(results) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_search_modules_invalid_regex() { + let db = crate::test_utils::call_graph_db("default"); + + // Invalid regex pattern: unclosed bracket + let result = search_modules(&db, "[invalid", "test_project", 10, true); + + assert!(result.is_err(), "Should reject invalid regex"); + let err = result.unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("Invalid regex pattern"), "Error should mention invalid regex: {}", msg); + assert!(msg.contains("[invalid"), "Error should show the pattern: {}", msg); + } + + #[test] + fn test_search_functions_invalid_regex() { + let db = crate::test_utils::call_graph_db("default"); + + // Invalid regex pattern: invalid repetition + let result = search_functions(&db, "*invalid", "test_project", 10, true); + + assert!(result.is_err(), "Should reject invalid regex"); + let err = result.unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("Invalid regex pattern"), "Error should mention invalid regex: {}", msg); + assert!(msg.contains("*invalid"), "Error should show the pattern: {}", msg); + } + + #[test] + fn test_search_modules_valid_regex() { + let db = crate::test_utils::call_graph_db("default"); + + // Valid regex pattern should not error on validation (may or may not find results) + let result = search_modules(&db, "^test.*$", "test_project", 10, true); + + // Should not fail on validation (may return empty results, that's fine) + assert!(result.is_ok(), "Should accept valid regex: {:?}", result.err()); + } + + #[test] + fn test_search_functions_valid_regex() { + let db = crate::test_utils::call_graph_db("default"); + + // Valid regex pattern should not error on validation + let result = search_functions(&db, "^get_.*$", "test_project", 10, true); + + // Should not fail on validation + assert!(result.is_ok(), "Should accept valid regex: {:?}", result.err()); + } + + #[test] + fn test_search_modules_non_regex_mode() { + let db = crate::test_utils::call_graph_db("default"); + + // Even invalid regex should work in non-regex mode (treated as literal string) + let result = search_modules(&db, "[invalid", "test_project", 10, false); + + // Should succeed (no regex validation in non-regex mode) + assert!(result.is_ok(), "Should accept any pattern in non-regex mode: {:?}", result.err()); + } + + #[test] + fn test_search_functions_non_regex_mode() { + let db = crate::test_utils::call_graph_db("default"); + + // Even invalid regex should work in non-regex mode + let result = search_functions(&db, "*invalid", "test_project", 10, false); + + // Should succeed (no regex validation in non-regex mode) + assert!(result.is_ok(), "Should accept any pattern in non-regex mode: {:?}", result.err()); + } +} diff --git a/db/src/queries/specs.rs b/db/src/queries/specs.rs index fa64ee6..3292cb5 100644 --- a/db/src/queries/specs.rs +++ b/db/src/queries/specs.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum SpecsError { @@ -35,34 +36,26 @@ pub fn find_specs( use_regex: bool, limit: u32, ) -> Result, Box> { - // Build module filter - let module_filter = if use_regex { - "regex_matches(module, $module_pattern)" - } else { - "module == $module_pattern" - }; - - // Build function filter - let function_filter = match function_pattern { - Some(_) if use_regex => ", regex_matches(name, $function_pattern)", - Some(_) => ", str_includes(name, $function_pattern)", - None => "", - }; - - // Build kind filter - let kind_filter_sql = match kind_filter { - Some(_) => ", kind == $kind", - None => "", - }; + validate_regex_patterns(use_regex, &[Some(module_pattern), function_pattern])?; + + // Build conditions using query builders + let module_cond = ConditionBuilder::new("module", "module_pattern").build(use_regex); + let function_cond = OptionalConditionBuilder::new("name", "function_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(function_pattern.is_some(), use_regex); + let kind_cond = OptionalConditionBuilder::new("kind", "kind") + .with_leading_comma() + .build(kind_filter.is_some()); let script = format!( r#" ?[project, module, name, arity, kind, line, inputs_string, return_string, full] := *specs{{project, module, name, arity, kind, line, inputs_string, return_string, full}}, project == $project, - {module_filter} - {function_filter} - {kind_filter_sql} + {module_cond} + {function_cond} + {kind_cond} :order module, name, arity :limit {limit} @@ -70,18 +63,18 @@ pub fn find_specs( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); params.insert( - "module_pattern".to_string(), + "module_pattern", DataValue::Str(module_pattern.into()), ); if let Some(func) = function_pattern { - params.insert("function_pattern".to_string(), DataValue::Str(func.into())); + params.insert("function_pattern", DataValue::Str(func.into())); } if let Some(kind) = kind_filter { - params.insert("kind".to_string(), DataValue::Str(kind.into())); + params.insert("kind", DataValue::Str(kind.into())); } let rows = run_query(db, &script, params).map_err(|e| SpecsError::QueryFailed { diff --git a/db/src/queries/struct_usage.rs b/db/src/queries/struct_usage.rs index 07b2b8e..0a2560c 100644 --- a/db/src/queries/struct_usage.rs +++ b/db/src/queries/struct_usage.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum StructUsageError { @@ -32,27 +33,28 @@ pub fn find_struct_usage( module_pattern: Option<&str>, limit: u32, ) -> Result, Box> { - // Build pattern matching function for both inputs and return - let match_fn = if use_regex { + validate_regex_patterns(use_regex, &[Some(pattern), module_pattern])?; + + // Build pattern matching function for both inputs and return (manual OR condition) + let match_cond = if use_regex { "regex_matches(inputs_string, $pattern) or regex_matches(return_string, $pattern)" } else { - "str_includes(inputs_string, $pattern) or str_includes(return_string, $pattern)" + "inputs_string == $pattern or return_string == $pattern" }; - // Build module filter - let module_filter = match module_pattern { - Some(_) if use_regex => "regex_matches(module, $module_pattern)", - Some(_) => "str_includes(module, $module_pattern)", - None => "true", - }; + // Build module filter using OptionalConditionBuilder + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); let script = format!( r#" ?[project, module, name, arity, inputs_string, return_string, line] := *specs{{project, module, name, arity, inputs_string, return_string, line}}, project == $project, - {match_fn}, - {module_filter} + {match_cond} + {module_cond} :order module, name, arity :limit {limit} @@ -60,12 +62,12 @@ pub fn find_struct_usage( ); let mut params = Params::new(); - params.insert("pattern".to_string(), DataValue::Str(pattern.into())); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("pattern", DataValue::Str(pattern.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(mod_pat) = module_pattern { params.insert( - "module_pattern".to_string(), + "module_pattern", DataValue::Str(mod_pat.into()), ); } diff --git a/db/src/queries/structs.rs b/db/src/queries/structs.rs index 10686a5..0774601 100644 --- a/db/src/queries/structs.rs +++ b/db/src/queries/structs.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_bool, extract_string, extract_string_or, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder}; #[derive(Error, Debug)] pub enum StructError { @@ -47,11 +48,9 @@ pub fn find_struct_fields( use_regex: bool, limit: u32, ) -> Result, Box> { - let module_cond = if use_regex { - "regex_matches(module, $module_pattern)".to_string() - } else { - "module == $module_pattern".to_string() - }; + validate_regex_patterns(use_regex, &[Some(module_pattern)])?; + + let module_cond = ConditionBuilder::new("module", "module_pattern").build(use_regex); let project_cond = ", project == $project"; @@ -67,8 +66,8 @@ pub fn find_struct_fields( ); let mut params = Params::new(); - params.insert("module_pattern".to_string(), DataValue::Str(module_pattern.into())); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("module_pattern", DataValue::Str(module_pattern.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| StructError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/trace.rs b/db/src/queries/trace.rs index 83b6e7f..482d5ae 100644 --- a/db/src/queries/trace.rs +++ b/db/src/queries/trace.rs @@ -6,7 +6,7 @@ use thiserror::Error; use crate::db::{extract_i64, extract_string, extract_string_or, run_query, Params}; use crate::types::{Call, FunctionRef}; -use crate::query_builders::{ConditionBuilder, OptionalConditionBuilder}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum TraceError { @@ -24,6 +24,8 @@ pub fn trace_calls( max_depth: u32, limit: u32, ) -> Result, Box> { + validate_regex_patterns(use_regex, &[Some(module_pattern), Some(function_pattern)])?; + // Build the starting conditions for the recursive query using helpers let module_cond = ConditionBuilder::new("caller_module", "module_pattern").build(use_regex); let function_cond = ConditionBuilder::new("caller_name", "function_pattern").build(use_regex); @@ -75,12 +77,12 @@ pub fn trace_calls( ); let mut params = Params::new(); - params.insert("module_pattern".to_string(), DataValue::Str(module_pattern.into())); - params.insert("function_pattern".to_string(), DataValue::Str(function_pattern.into())); + params.insert("module_pattern", DataValue::Str(module_pattern.into())); + params.insert("function_pattern", DataValue::Str(function_pattern.into())); if let Some(a) = arity { - params.insert("arity".to_string(), DataValue::from(a)); + params.insert("arity", DataValue::from(a)); } - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); let rows = run_query(db, &script, params).map_err(|e| TraceError::QueryFailed { message: e.to_string(), diff --git a/db/src/queries/types.rs b/db/src/queries/types.rs index aa08274..97aa7ff 100644 --- a/db/src/queries/types.rs +++ b/db/src/queries/types.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, ConditionBuilder, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum TypesError { @@ -33,34 +34,26 @@ pub fn find_types( use_regex: bool, limit: u32, ) -> Result, Box> { - // Build module filter - let module_filter = if use_regex { - "regex_matches(module, $module_pattern)" - } else { - "module == $module_pattern" - }; - - // Build name filter - let name_filter_sql = match name_filter { - Some(_) if use_regex => ", regex_matches(name, $name_pattern)", - Some(_) => ", str_includes(name, $name_pattern)", - None => "", - }; - - // Build kind filter - let kind_filter_sql = match kind_filter { - Some(_) => ", kind == $kind", - None => "", - }; + validate_regex_patterns(use_regex, &[Some(module_pattern), name_filter])?; + + // Build conditions using query builders + let module_cond = ConditionBuilder::new("module", "module_pattern").build(use_regex); + let name_cond = OptionalConditionBuilder::new("name", "name_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(name_filter.is_some(), use_regex); + let kind_cond = OptionalConditionBuilder::new("kind", "kind") + .with_leading_comma() + .build(kind_filter.is_some()); let script = format!( r#" ?[project, module, name, kind, params, line, definition] := *types{{project, module, name, kind, params, line, definition}}, project == $project, - {module_filter} - {name_filter_sql} - {kind_filter_sql} + {module_cond} + {name_cond} + {kind_cond} :order module, name :limit {limit} @@ -68,18 +61,18 @@ pub fn find_types( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); params.insert( - "module_pattern".to_string(), + "module_pattern", DataValue::Str(module_pattern.into()), ); if let Some(name) = name_filter { - params.insert("name_pattern".to_string(), DataValue::Str(name.into())); + params.insert("name_pattern", DataValue::Str(name.into())); } if let Some(kind) = kind_filter { - params.insert("kind".to_string(), DataValue::Str(kind.into())); + params.insert("kind", DataValue::Str(kind.into())); } let rows = run_query(db, &script, params).map_err(|e| TypesError::QueryFailed { diff --git a/db/src/queries/unused.rs b/db/src/queries/unused.rs index 070f534..4193de5 100644 --- a/db/src/queries/unused.rs +++ b/db/src/queries/unused.rs @@ -5,6 +5,7 @@ use serde::Serialize; use thiserror::Error; use crate::db::{extract_i64, extract_string, run_query, Params}; +use crate::query_builders::{validate_regex_patterns, OptionalConditionBuilder}; #[derive(Error, Debug)] pub enum UnusedError { @@ -49,12 +50,13 @@ pub fn find_unused_functions( exclude_generated: bool, limit: u32, ) -> Result, Box> { - // Build optional module filter - let module_filter = match module_pattern { - Some(_) if use_regex => ", regex_matches(module, $module_pattern)".to_string(), - Some(_) => ", str_includes(module, $module_pattern)".to_string(), - None => String::new(), - }; + validate_regex_patterns(use_regex, &[module_pattern])?; + + // Build conditions using query builders + let module_cond = OptionalConditionBuilder::new("module", "module_pattern") + .with_leading_comma() + .with_regex() + .build_with_regex(module_pattern.is_some(), use_regex); // Build kind filter for private_only/public_only let kind_filter = if private_only { @@ -74,7 +76,7 @@ pub fn find_unused_functions( defined[module, name, arity, kind, file, start_line] := *function_locations{{project, module, name, arity, kind, file, start_line}}, project == $project - {module_filter} + {module_cond} {kind_filter} # All functions that are called (as callees) @@ -96,9 +98,9 @@ pub fn find_unused_functions( ); let mut params = Params::new(); - params.insert("project".to_string(), DataValue::Str(project.into())); + params.insert("project", DataValue::Str(project.into())); if let Some(pattern) = module_pattern { - params.insert("module_pattern".to_string(), DataValue::Str(pattern.into())); + params.insert("module_pattern", DataValue::Str(pattern.into())); } let rows = run_query(db, &script, params).map_err(|e| UnusedError::QueryFailed { diff --git a/db/src/query_builders.rs b/db/src/query_builders.rs index 6ff2609..22ac9cc 100644 --- a/db/src/query_builders.rs +++ b/db/src/query_builders.rs @@ -1,4 +1,99 @@ //! Query condition builders for CozoScript +//! +//! # Regex Validation Strategy +//! +//! This module validates regex patterns using the standard Rust `regex` crate before +//! passing them to CozoDB. While this means patterns are compiled twice (once during +//! validation, once by CozoDB during query execution), this is an intentional design +//! decision that provides significant benefits: +//! +//! - **Same Engine**: CozoDB uses `regex = "1.10.4"` (the same crate we use), so +//! validation results perfectly match CozoDB's behavior. There are no false positives +//! or negatives due to engine differences. +//! +//! - **Better UX**: Early validation at the CLI boundary provides clear, actionable error +//! messages. Without this, users would get cryptic CozoDB query errors that are harder +//! to understand and debug. +//! +//! - **Acceptable Cost**: Regex compilation is fast (~1ms per pattern), making the +//! performance overhead negligible compared to the UX improvement. +//! +//! See: https://github.com/cozodb/cozo/blob/main/cozo-core/Cargo.toml for CozoDB's +//! regex dependency version. + +use std::error::Error; + +/// Validates a regex pattern string +/// +/// # Arguments +/// * `pattern` - The regex pattern to validate +/// +/// # Returns +/// * `Ok(())` if the pattern is valid +/// * `Err` with a user-friendly error message if the pattern is invalid +/// +/// # Examples +/// ``` +/// use db::query_builders::validate_regex_pattern; +/// +/// assert!(validate_regex_pattern("^hello.*world$").is_ok()); +/// assert!(validate_regex_pattern("[invalid").is_err()); +/// ``` +pub fn validate_regex_pattern(pattern: &str) -> Result<(), Box> { + regex::Regex::new(pattern).map_err(|e| -> Box { + format!( + "Invalid regex pattern '{}': {}", + pattern, + e.to_string() + ) + .into() + })?; + Ok(()) +} + +/// Validates multiple regex patterns at once (only if regex mode is enabled) +/// +/// This is a convenience helper for query functions that accept multiple optional +/// patterns. It validates all patterns only when `use_regex` is true. +/// +/// # Arguments +/// * `use_regex` - Whether regex mode is enabled +/// * `patterns` - Slice of optional pattern strings to validate +/// +/// # Returns +/// * `Ok(())` if all patterns are valid (or if `use_regex` is false) +/// * `Err` with a user-friendly error message if any pattern is invalid +/// +/// # Examples +/// ``` +/// use db::query_builders::validate_regex_patterns; +/// +/// // Non-regex mode: accepts any patterns (no validation) +/// assert!(validate_regex_patterns(false, &[Some("[invalid")]).is_ok()); +/// +/// // Regex mode: validates all patterns +/// assert!(validate_regex_patterns(true, &[Some("^hello$"), Some("world.*")]).is_ok()); +/// assert!(validate_regex_patterns(true, &[Some("[invalid")]).is_err()); +/// +/// // None patterns are skipped +/// assert!(validate_regex_patterns(true, &[Some("valid"), None, Some("also.*valid")]).is_ok()); +/// ``` +pub fn validate_regex_patterns( + use_regex: bool, + patterns: &[Option<&str>], +) -> Result<(), Box> { + if !use_regex { + return Ok(()); + } + + for pattern_opt in patterns { + if let Some(pattern) = pattern_opt { + validate_regex_pattern(pattern)?; + } + } + + Ok(()) +} /// Builds SQL WHERE clause conditions for query patterns (exact or regex matching) /// @@ -15,8 +110,8 @@ /// let cond = builder.build(true); // "regex_matches(module, $module_pattern)" /// ``` pub struct ConditionBuilder { - field_name: String, - param_name: String, + field_name: &'static str, + param_name: &'static str, with_leading_comma: bool, } @@ -26,10 +121,10 @@ impl ConditionBuilder { /// # Arguments /// * `field_name` - The SQL field name (e.g., "module", "caller_module") /// * `param_name` - The parameter name (e.g., "module_pattern", "function_pattern") - pub fn new(field_name: &str, param_name: &str) -> Self { + pub fn new(field_name: &'static str, param_name: &'static str) -> Self { Self { - field_name: field_name.to_string(), - param_name: param_name.to_string(), + field_name, + param_name, with_leading_comma: false, } } @@ -69,10 +164,10 @@ impl ConditionBuilder { /// Handles the pattern of generating conditions only when values are present. /// For function-matching conditions, supports both exact and regex matching. pub struct OptionalConditionBuilder { - field_name: String, - param_name: String, + field_name: &'static str, + param_name: &'static str, with_leading_comma: bool, - when_none: Option, // Alternative condition when value is None + when_none: Option<&'static str>, // Alternative condition when value is None supports_regex: bool, // Whether to use regex_matches when value is present } @@ -82,10 +177,10 @@ impl OptionalConditionBuilder { /// # Arguments /// * `field_name` - The SQL field name /// * `param_name` - The parameter name - pub fn new(field_name: &str, param_name: &str) -> Self { + pub fn new(field_name: &'static str, param_name: &'static str) -> Self { Self { - field_name: field_name.to_string(), - param_name: param_name.to_string(), + field_name, + param_name, with_leading_comma: false, when_none: None, supports_regex: false, @@ -105,8 +200,8 @@ impl OptionalConditionBuilder { } /// Sets an alternative condition when the value is None (e.g., "true" for no-op) - pub fn when_none(mut self, condition: &str) -> Self { - self.when_none = Some(condition.to_string()); + pub fn when_none(mut self, condition: &'static str) -> Self { + self.when_none = Some(condition); self } @@ -199,4 +294,140 @@ mod tests { assert_eq!(builder.build(true), ", arity == $arity"); assert_eq!(builder.build(false), ", true"); } + + // ========================================================================= + // Regex validation tests + // ========================================================================= + + #[test] + fn test_validate_regex_pattern_valid() { + // Simple patterns + assert!(validate_regex_pattern("hello").is_ok()); + assert!(validate_regex_pattern("^start").is_ok()); + assert!(validate_regex_pattern("end$").is_ok()); + + // Character classes + assert!(validate_regex_pattern("[abc]").is_ok()); + assert!(validate_regex_pattern("[a-z]").is_ok()); + assert!(validate_regex_pattern("[^abc]").is_ok()); + + // Quantifiers + assert!(validate_regex_pattern("a*").is_ok()); + assert!(validate_regex_pattern("a+").is_ok()); + assert!(validate_regex_pattern("a?").is_ok()); + assert!(validate_regex_pattern("a{2,4}").is_ok()); + + // Groups and alternation + assert!(validate_regex_pattern("(foo|bar)").is_ok()); + assert!(validate_regex_pattern("(?:non-capturing)").is_ok()); + + // Common real-world patterns + assert!(validate_regex_pattern(r"^get_user$").is_ok()); + assert!(validate_regex_pattern(r"\.(Accounts|Users)$").is_ok()); + assert!(validate_regex_pattern(r"MyApp\..*\.Service$").is_ok()); + } + + #[test] + fn test_validate_regex_pattern_invalid() { + // Unclosed brackets + let err = validate_regex_pattern("[invalid").unwrap_err(); + assert!(err.to_string().contains("Invalid regex pattern")); + assert!(err.to_string().contains("[invalid")); + + // Unclosed parenthesis + let err = validate_regex_pattern("(unclosed").unwrap_err(); + assert!(err.to_string().contains("Invalid regex pattern")); + + // Invalid repetition + let err = validate_regex_pattern("*invalid").unwrap_err(); + assert!(err.to_string().contains("Invalid regex pattern")); + + // Invalid escape + let err = validate_regex_pattern(r"\k").unwrap_err(); + assert!(err.to_string().contains("Invalid regex pattern")); + + // Invalid quantifier + let err = validate_regex_pattern("a{,}").unwrap_err(); + assert!(err.to_string().contains("Invalid regex pattern")); + } + + #[test] + fn test_validate_regex_pattern_empty() { + // Empty pattern is valid (matches everything) + assert!(validate_regex_pattern("").is_ok()); + } + + #[test] + fn test_validate_regex_pattern_error_message_format() { + let err = validate_regex_pattern("[unclosed").unwrap_err(); + let msg = err.to_string(); + + // Should contain the pattern itself + assert!(msg.contains("[unclosed"), "Error should show the pattern: {}", msg); + + // Should contain "Invalid regex pattern" + assert!(msg.contains("Invalid regex pattern"), "Error should say 'Invalid regex pattern': {}", msg); + } + + // ========================================================================= + // Regex patterns validation helper tests + // ========================================================================= + + #[test] + fn test_validate_regex_patterns_non_regex_mode() { + // Non-regex mode should accept any patterns without validation + assert!(validate_regex_patterns(false, &[Some("[invalid")]).is_ok()); + assert!(validate_regex_patterns(false, &[Some("*bad"), Some("(unclosed")]).is_ok()); + } + + #[test] + fn test_validate_regex_patterns_all_valid() { + // All valid patterns should succeed + assert!(validate_regex_patterns(true, &[Some("^hello$"), Some("world.*")]).is_ok()); + assert!(validate_regex_patterns(true, &[Some("test")]).is_ok()); + } + + #[test] + fn test_validate_regex_patterns_with_invalid() { + // Should fail on first invalid pattern + let result = validate_regex_patterns(true, &[Some("valid"), Some("[invalid")]); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("[invalid")); + } + + #[test] + fn test_validate_regex_patterns_with_none() { + // None patterns should be skipped + assert!(validate_regex_patterns(true, &[Some("valid"), None, Some("also.*valid")]).is_ok()); + assert!(validate_regex_patterns(true, &[None, None]).is_ok()); + } + + #[test] + fn test_validate_regex_patterns_empty() { + // Empty slice should succeed + assert!(validate_regex_patterns(true, &[]).is_ok()); + } + + #[test] + fn test_validate_regex_patterns_fails_on_first_invalid() { + // Should return error from first invalid pattern, not continue + let result = validate_regex_patterns(true, &[Some("[first_bad"), Some("(second_bad")]); + assert!(result.is_err()); + let err = result.unwrap_err(); + // Should show first invalid pattern + assert!(err.to_string().contains("[first_bad")); + } + + #[test] + fn test_validate_regex_patterns_mixed_valid_and_none() { + // Mix of valid patterns and None should work + assert!(validate_regex_patterns(true, &[ + Some("^test.*$"), + None, + Some("[a-z]+"), + None, + Some("\\d{3}") + ]).is_ok()); + } } diff --git a/docs/NEW_COMMANDS.md b/docs/NEW_COMMANDS.md index 7be4312..69cc3f6 100644 --- a/docs/NEW_COMMANDS.md +++ b/docs/NEW_COMMANDS.md @@ -269,14 +269,15 @@ crate::dedup::deduplicate_retain(&mut calls, |c| { ``` Use when: Items are already sorted and you want to remove duplicates while preserving order -**DeduplicationFilter** - For prevention during collection +**HashMap for deduplication** - For prevention during collection ```rust -let mut filter = crate::dedup::DeduplicationFilter::new(); -if filter.should_process(key) { +let mut seen: HashMap = HashMap::new(); +if !seen.contains_key(&key) { + seen.insert(key, index); // Add entry to result } ``` -Use when: Building results incrementally and preventing duplicates before adding +Use when: Building results incrementally and need both deduplication and O(1) lookups (e.g., parent index lookups in trace commands) **Benefits:** - No more HashSet boilerplate scattered across commands @@ -471,7 +472,7 @@ impl TableFormatter for ModuleGroupResult { - [ ] Use `crate::dedup::*` utilities for deduplication (if needed): - [ ] `sort_and_deduplicate()` for combined sort + dedup - [ ] `deduplicate_retain()` for post-sort dedup - - [ ] `DeduplicationFilter` for incremental collection + - [ ] Use HashMap/HashSet directly for incremental collection with O(1) lookups - [ ] Implement `TableFormatter` trait in `output.rs` (NOT `Outputable`) - [ ] Document `file` field population decision with inline comments - [ ] Test `to_table()` output against expected string constants