Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 155 additions & 50 deletions db/src/queries/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ pub fn clear_project_data(db: &dyn Database) -> Result<(), Box<dyn Error>> {
}

/// Import modules to SurrealDB
pub fn import_modules(
db: &dyn Database,
graph: &CallGraph,
) -> Result<usize, Box<dyn Error>> {
pub fn import_modules(db: &dyn Database, graph: &CallGraph) -> Result<usize, Box<dyn Error>> {
// Collect unique modules from all data sources
let mut modules = std::collections::HashSet::new();
modules.extend(graph.specs.keys().cloned());
Expand All @@ -118,10 +115,7 @@ pub fn import_modules(
/// Functions are created from function_locations, which contains the actual
/// function definitions. Specs are metadata that belong to functions and are
/// linked via name/arity matching, not imported as separate function records.
pub fn import_functions(
db: &dyn Database,
graph: &CallGraph,
) -> Result<usize, Box<dyn Error>> {
pub fn import_functions(db: &dyn Database, graph: &CallGraph) -> Result<usize, Box<dyn Error>> {
use std::collections::HashSet;
let mut count = 0;
let mut seen: HashSet<(String, String, i64)> = HashSet::new();
Expand Down Expand Up @@ -165,10 +159,7 @@ pub fn import_functions(
}

/// Import calls to SurrealDB
pub fn import_calls(
db: &dyn Database,
graph: &CallGraph,
) -> Result<usize, Box<dyn Error>> {
pub fn import_calls(db: &dyn Database, graph: &CallGraph) -> Result<usize, Box<dyn Error>> {
let mut count = 0;

for call in &graph.calls {
Expand Down Expand Up @@ -257,10 +248,7 @@ fn parse_function_ref(func_ref: &str) -> (&str, i64) {
}

/// Import structs to SurrealDB (as fields)
pub fn import_structs(
db: &dyn Database,
graph: &CallGraph,
) -> Result<usize, Box<dyn Error>> {
pub fn import_structs(db: &dyn Database, graph: &CallGraph) -> Result<usize, Box<dyn Error>> {
let mut count = 0;

for (module_name, def) in &graph.structs {
Expand Down Expand Up @@ -344,18 +332,15 @@ pub fn import_function_locations(
}

/// Import specs to SurrealDB with array fields preserved
pub fn import_specs(
db: &dyn Database,
graph: &CallGraph,
) -> Result<usize, Box<dyn Error>> {
pub fn import_specs(db: &dyn Database, graph: &CallGraph) -> Result<usize, Box<dyn Error>> {
let mut count = 0;

for (module_name, specs) in &graph.specs {
for spec in specs {
// Import each clause as a separate spec row with clause_index
for (clause_index, clause) in spec.clauses.iter().enumerate() {
let query = r#"
CREATE specs:[$module_name, $function_name, $arity, $clause_index] SET
UPSERT specs:[$module_name, $function_name, $arity, $clause_index] SET
module_name = $module_name,
function_name = $function_name,
arity = $arity,
Expand Down Expand Up @@ -387,10 +372,7 @@ pub fn import_specs(
}

/// Import types to SurrealDB
pub fn import_types(
db: &dyn Database,
graph: &CallGraph,
) -> Result<usize, Box<dyn Error>> {
pub fn import_types(db: &dyn Database, graph: &CallGraph) -> Result<usize, Box<dyn Error>> {
let mut count = 0;

for (module_name, types) in &graph.types {
Expand Down Expand Up @@ -539,10 +521,7 @@ pub fn create_has_field_relationships(
///
/// Creates schemas and imports all data (modules, functions, calls, structs, locations).
/// This is the core import logic used by both the CLI command and test utilities.
pub fn import_graph(
db: &dyn Database,
graph: &CallGraph,
) -> Result<ImportResult, Box<dyn Error>> {
pub fn import_graph(db: &dyn Database, graph: &CallGraph) -> Result<ImportResult, Box<dyn Error>> {
let mut result = ImportResult::default();

result.schemas = create_schema(db)?;
Expand Down Expand Up @@ -570,10 +549,7 @@ pub fn import_graph(
///
/// Convenience wrapper for tests that parses JSON and calls `import_graph`.
#[cfg(any(test, feature = "test-utils"))]
pub fn import_json_str(
db: &dyn Database,
content: &str,
) -> Result<ImportResult, Box<dyn Error>> {
pub fn import_json_str(db: &dyn Database, content: &str) -> Result<ImportResult, Box<dyn Error>> {
let graph: CallGraph =
serde_json::from_str(content).map_err(|e| ImportError::JsonParseFailed {
message: e.to_string(),
Expand Down Expand Up @@ -744,8 +720,14 @@ mod tests {
let input_arr = row.get(0).and_then(|v| v.as_array());
let return_arr = row.get(1).and_then(|v| v.as_array());

assert!(input_arr.is_some(), "input_strings should be stored as array");
assert!(return_arr.is_some(), "return_strings should be stored as array");
assert!(
input_arr.is_some(),
"input_strings should be stored as array"
);
assert!(
return_arr.is_some(),
"return_strings should be stored as array"
);

// Verify array contents
let inputs = input_arr.unwrap();
Expand All @@ -758,6 +740,80 @@ mod tests {
assert_eq!(returns[0].as_str(), Some(":ok"));
}

/// Importing two spec entries that share the same natural key
/// (module, function name, arity, clause index) must succeed and leave a
/// single `specs` record whose content comes from the later entry.
#[test]
fn test_import_specs_upserts_duplicate_record_ids() {
    // Fixture: a `@callback` and a `@spec` for the same `shorten/3`, each with
    // one clause, so both map to the same record id.
    let payload = r#"{
"specs": {
"MyApp.Integrations.UrlShortener": [
{
"name": "shorten",
"arity": 3,
"line": 10,
"kind": "callback",
"clauses": [
{
"full": "@callback shorten(url(), config(), opts()) :: result()",
"input_strings": ["url()", "config()", "opts()"],
"return_strings": ["result()"]
}
]
},
{
"name": "shorten",
"arity": 3,
"line": 12,
"kind": "spec",
"clauses": [
{
"full": "@spec shorten(url(), config(), opts()) :: result()",
"input_strings": ["url()", "config()", "opts()"],
"return_strings": ["result()"]
}
]
}
]
},
"function_locations": {},
"calls": [],
"structs": {},
"types": {}
}"#;

    // Fresh in-memory database with the schema in place.
    let store = crate::open_mem_db().unwrap();
    crate::queries::schema::create_schema(&*store).unwrap();

    // Modules and functions are imported first, then the specs under test.
    let call_graph = serde_json::from_str::<CallGraph>(payload).unwrap();
    import_modules(&*store, &call_graph).unwrap();
    import_functions(&*store, &call_graph).unwrap();

    let outcome = import_specs(&*store, &call_graph);
    assert!(
        outcome.is_ok(),
        "Import specs should upsert duplicates: {:?}",
        outcome.err()
    );
    // Both entries are counted as processed even though they share an id.
    let processed = outcome.unwrap();
    assert_eq!(
        processed, 2,
        "Should process both entries even with duplicate spec IDs"
    );

    // Only one record may remain in the table.
    let result_set = store
        .execute_query("SELECT full FROM specs", QueryParams::new())
        .unwrap();
    let stored = result_set.rows();
    assert_eq!(
        stored.len(),
        1,
        "Duplicate specs should collapse into one record"
    );

    // The surviving record carries the text of the second (`@spec`) entry.
    let surviving_full = stored
        .first()
        .unwrap()
        .get(0)
        .and_then(|value| value.as_str());
    assert_eq!(
        surviving_full,
        Some("@spec shorten(url(), config(), opts()) :: result()")
    );
}

/// Test import_function_locations creates clauses
#[test]
fn test_import_function_locations_creates_clauses() {
Expand Down Expand Up @@ -1240,7 +1296,8 @@ mod tests {

// Verify call counts were updated during import
// Columns in alphabetical order: incoming_call_count (0), name (1), outgoing_call_count (2)
let query = "SELECT name, incoming_call_count, outgoing_call_count FROM functions ORDER BY name";
let query =
"SELECT name, incoming_call_count, outgoing_call_count FROM functions ORDER BY name";
let rows = db.execute_query(query, QueryParams::new()).unwrap();
let counts: std::collections::HashMap<String, (i64, i64)> = rows
.rows()
Expand Down Expand Up @@ -1332,18 +1389,29 @@ mod tests {
// Before update_call_counts, all counts should be 0
// Note: SurrealDB returns columns in alphabetical order, so:
// incoming_call_count (0), name (1), outgoing_call_count (2)
let query = "SELECT name, incoming_call_count, outgoing_call_count FROM functions ORDER BY name";
let query =
"SELECT name, incoming_call_count, outgoing_call_count FROM functions ORDER BY name";
let rows = db.execute_query(query, QueryParams::new()).unwrap();
for row in rows.rows() {
let incoming = row.get(0).and_then(|v| v.as_i64()).unwrap_or(-1);
let outgoing = row.get(2).and_then(|v| v.as_i64()).unwrap_or(-1);
assert_eq!(incoming, 0, "Before update, incoming_call_count should be 0");
assert_eq!(outgoing, 0, "Before update, outgoing_call_count should be 0");
assert_eq!(
incoming, 0,
"Before update, incoming_call_count should be 0"
);
assert_eq!(
outgoing, 0,
"Before update, outgoing_call_count should be 0"
);
}

// Run update_call_counts
let result = update_call_counts(&*db);
assert!(result.is_ok(), "update_call_counts should succeed: {:?}", result.err());
assert!(
result.is_ok(),
"update_call_counts should succeed: {:?}",
result.err()
);

// Verify counts after update
let rows = db.execute_query(query, QueryParams::new()).unwrap();
Expand All @@ -1360,16 +1428,32 @@ mod tests {
.collect();

// get_user: calls Repo.get, not called by anyone in our graph
assert_eq!(counts.get("get_user"), Some(&(0, 1)), "get_user: incoming=0, outgoing=1");
assert_eq!(
counts.get("get_user"),
Some(&(0, 1)),
"get_user: incoming=0, outgoing=1"
);

// Repo.get: called by get_user, doesn't call anything
assert_eq!(counts.get("get"), Some(&(1, 0)), "get: incoming=1, outgoing=0");
assert_eq!(
counts.get("get"),
Some(&(1, 0)),
"get: incoming=1, outgoing=0"
);

// index: calls list_users, not called by anyone
assert_eq!(counts.get("index"), Some(&(0, 1)), "index: incoming=0, outgoing=1");
assert_eq!(
counts.get("index"),
Some(&(0, 1)),
"index: incoming=0, outgoing=1"
);

// list_users: called by index, doesn't call anything
assert_eq!(counts.get("list_users"), Some(&(1, 0)), "list_users: incoming=1, outgoing=0");
assert_eq!(
counts.get("list_users"),
Some(&(1, 0)),
"list_users: incoming=1, outgoing=0"
);
}

/// Test update_call_counts handles functions with multiple incoming/outgoing calls
Expand Down Expand Up @@ -1433,7 +1517,8 @@ mod tests {

// Query counts
// Columns in alphabetical order: incoming_call_count (0), name (1), outgoing_call_count (2)
let query = "SELECT name, incoming_call_count, outgoing_call_count FROM functions ORDER BY name";
let query =
"SELECT name, incoming_call_count, outgoing_call_count FROM functions ORDER BY name";
let rows = db.execute_query(query, QueryParams::new()).unwrap();
let counts: std::collections::HashMap<String, (i64, i64)> = rows
.rows()
Expand All @@ -1447,16 +1532,32 @@ mod tests {
.collect();

// Repo.get: called twice (by get_user and update_user)
assert_eq!(counts.get("get"), Some(&(2, 0)), "get: incoming=2, outgoing=0");
assert_eq!(
counts.get("get"),
Some(&(2, 0)),
"get: incoming=2, outgoing=0"
);

// Repo.update: called once (by update_user)
assert_eq!(counts.get("update"), Some(&(1, 0)), "update: incoming=1, outgoing=0");
assert_eq!(
counts.get("update"),
Some(&(1, 0)),
"update: incoming=1, outgoing=0"
);

// get_user: makes 1 call (to Repo.get)
assert_eq!(counts.get("get_user"), Some(&(0, 1)), "get_user: incoming=0, outgoing=1");
assert_eq!(
counts.get("get_user"),
Some(&(0, 1)),
"get_user: incoming=0, outgoing=1"
);

// update_user: makes 2 calls (to Repo.get and Repo.update)
assert_eq!(counts.get("update_user"), Some(&(0, 2)), "update_user: incoming=0, outgoing=2");
assert_eq!(
counts.get("update_user"),
Some(&(0, 2)),
"update_user: incoming=0, outgoing=2"
);
}

/// Test update_call_counts handles empty calls table (no calls)
Expand Down Expand Up @@ -1489,7 +1590,11 @@ mod tests {

// Run update_call_counts - should not error even with no calls
let result = update_call_counts(&*db);
assert!(result.is_ok(), "update_call_counts should succeed with no calls: {:?}", result.err());
assert!(
result.is_ok(),
"update_call_counts should succeed with no calls: {:?}",
result.err()
);

// Verify counts are 0
// Columns in alphabetical order: incoming_call_count (0), name (1), outgoing_call_count (2)
Expand Down