From b7d8112bc2182af7105e1c7d3792ad96a3d6a9e2 Mon Sep 17 00:00:00 2001 From: Rex Kerr Date: Sat, 21 Dec 2024 17:38:23 -0800 Subject: [PATCH] Changed Rust code to use more idiomatic iterator pattern. The previous code used indexing from ranges e.g. `for i in 1..=m`. While this works, it's usually slower and more error-prone than iterator-based traversal. This patch switches to the more idiomatic .iter().enumerate() pattern since this is how you would usually do this in Rust. This does seem also to improve performance (~20% on my machine). Furthermore, the working space for the two rows was initialized in an atypical way: instead of creating the working space with the content and size desired, it created empty Vecs and then put content into it. This did not alter performance on my machine, but it makes the code easier to read (because the pattern is more expected). I also switched the argument reading to `iter().enumerate()` since it also is more expected. (No performance implication.) None of this changes the algorithm from what is used in the C reference; rather, they express the same algorithm in a more idiomatic Rust style. --- levenshtein/rust/code.rs | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/levenshtein/rust/code.rs b/levenshtein/rust/code.rs index 78eac1e8..b516253b 100644 --- a/levenshtein/rust/code.rs +++ b/levenshtein/rust/code.rs @@ -27,30 +27,25 @@ fn levenshtein_distance(s1: &str, s2: &str) -> usize { }; let m = s1_bytes.len(); - let n = s2_bytes.len(); // Use two rows instead of full matrix for space optimization - let mut prev_row = Vec::with_capacity(m + 1); - let mut curr_row = Vec::with_capacity(m + 1); - - // Initialize first row - prev_row.extend(0..=m); - curr_row.resize(m + 1, 0); - + let mut prev_row: Vec = (0..=m).collect(); + let mut curr_row = vec![0; m+1]; + // Main computation loop - for j in 1..=n { - curr_row[0] = j; + for (j, b2) in s2_bytes.iter().enumerate() { + curr_row[0] = j+1; - for i in 1..=m { - let cost = if s1_bytes[i - 1] == s2_bytes[j - 1] { 0 } else { 1 }; + for (i, b1) in s1_bytes.iter().enumerate() { + let cost = if b1 == b2 { 0 } else { 1 }; // Calculate minimum of three operations - curr_row[i] = std::cmp::min( + curr_row[i+1] = std::cmp::min( std::cmp::min( - prev_row[i] + 1, // deletion - curr_row[i - 1] + 1, // insertion + prev_row[i+1] + 1, // deletion + curr_row[i] + 1, // insertion ), - prev_row[i - 1] + cost // substitution + prev_row[i] + cost // substitution ); } @@ -73,10 +68,10 @@ fn main() { let mut times = 0; // Compare all pairs of strings - for i in 0..args.len() { - for j in 0..args.len() { + for (i, arg1) in args.iter().enumerate() { + for (j, arg2) in args.iter().enumerate() { if i != j { - let distance = levenshtein_distance(&args[i], &args[j]); + let distance = levenshtein_distance(&arg1, &arg2); if let Some(current_min) = min_distance { if distance < current_min { min_distance = Some(distance);