From 1bbc7a7cfc2b2e215e0d22cbf18c6d303bbf4a04 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Mon, 19 Dec 2016 01:35:47 +0900 Subject: [PATCH 1/2] ENH: implement Agglomerative clustering --- README.md | 3 +- src/learning/agglomerative.rs | 415 ++++++++++++++++++++++++++++++++ src/lib.rs | 2 + tests/learning/agglomerative.rs | 54 +++++ tests/lib.rs | 1 + 5 files changed, 474 insertions(+), 1 deletion(-) create mode 100644 src/learning/agglomerative.rs create mode 100644 tests/learning/agglomerative.rs diff --git a/README.md b/README.md index 59590f07..7e3c9299 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ It aims to combine speed and ease of use - without requiring a huge number of ex This project began as a way for me to learn Rust and brush up on some less familiar machine learning algorithms and techniques. Now the project aims to provide a complete, easy to use, machine learning library for Rust. -This library is still very much in early stages of development. Although there are a good number of algorithms many other +This library is still very much in early stages of development. Although there are a good number of algorithms many other things are missing. Rusty-machine is probably not the best choice for any serious projects - but hopefully that can change in the near future! #### Contributing @@ -50,6 +50,7 @@ This is fairly complete but there is still lots of room for optimization and we - Logistic Regression - Generalized Linear Models - K-Means Clustering +- Agglomerative Clustering - Neural Networks - Gaussian Process Regression - Support Vector Machines diff --git a/src/learning/agglomerative.rs b/src/learning/agglomerative.rs new file mode 100644 index 00000000..8093ff9c --- /dev/null +++ b/src/learning/agglomerative.rs @@ -0,0 +1,415 @@ +//! Agglomerative (Hierarchical) Clustering Module +//! +//! Contains implementation of Agglomerative Clustering. +//! +//! # Usage +//! +//! ``` +//! 
use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics};
+//! use rusty_machine::learning::SupModel;
+//! use rusty_machine::linalg::{Matrix, Vector};
+//!
+//! let inputs = Matrix::new(4, 2, vec![1., 3., 2., 3., 4., 3., 5., 3.]);
+//! let mut agg = AgglomerativeClustering::new(2, Metrics::Single);
+//!
+//! // Train the model and get the clustering result
+//! let res = agg.train(&inputs).unwrap();
+//!
+//! assert_eq!(res, Vector::new(vec![0, 0, 1, 1]));
+//! ```
+//!
+
+use std::collections::{BTreeMap, HashMap};
+use std::f64;
+
+use linalg::{Matrix, BaseMatrix, Vector};
+use learning::{LearningResult};
+
+/// Agglomerative clustering distances
+#[derive(Debug)]
+pub enum Metrics {
+    /// Single linkage clustering
+    Single,
+    /// Complete linkage clustering
+    Complete,
+    /// Average linkage clustering
+    Average,
+    /// Centroid linkage clustering
+    Centroid,
+    /// Median linkage clustering
+    Median,
+
+    /// Ward criterion (uses Ward II)
+    Ward,
+    /// Ward I,
+    /// See "Ward’s Hierarchical Agglomerative Clustering Method:
+    /// Which Algorithms Implement Ward’s Criterion? (Murtagh, 2014)"
+    Ward1,
+    /// Ward II
+    Ward2,
+}
+
+impl Metrics {
+
+    // calculate distance using Lance-Williams algorithm
+    fn dist(&self, ci: &Cluster, cj: &Cluster, ck: &Cluster, dmat: &DistanceMatrix) -> f64 {
+
+        let dik = dmat.get(ck.id, ci.id);
+        let djk = dmat.get(ck.id, cj.id);
+
+        match self {
+            &Metrics::Single => {
+                // 0.5 * dik + 0.5 * djk + 0. * dij - 0.5 * (dik - djk).abs()
+                dik.min(djk)
+            },
+            &Metrics::Complete => {
+                // 0.5 * dik + 0.5 * djk + 0.
* dij + 0.5 * (dik - djk).abs()
+                dik.max(djk)
+            },
+            &Metrics::Average => {
+                let s = ci.size + cj.size;
+                ci.size / s * dik + cj.size / s * djk
+            },
+            &Metrics::Centroid => {
+                let s = ci.size + cj.size;
+                let ai = ci.size / s;
+                let aj = cj.size / s;
+                let dij = dmat.get(ci.id, cj.id);
+                ai * dik + aj * djk - ai * aj * dij
+            },
+            &Metrics::Median => {
+                let dij = dmat.get(ci.id, cj.id);
+                0.5 * dik + 0.5 * djk - 0.25 * dij
+            },
+            &Metrics::Ward1 => {
+                let s = ci.size + cj.size + ck.size;
+                let dij = dmat.get(ci.id, cj.id);
+                (ci.size + ck.size) / s * dik + (cj.size + ck.size) / s * djk - ck.size / s * dij
+            },
+            &Metrics::Ward | &Metrics::Ward2 => {
+                let s = ci.size + cj.size + ck.size;
+                let dij = dmat.get(ci.id, cj.id);
+                ((ci.size + ck.size) / s * dik.powf(2.) + (cj.size + ck.size) / s * djk.powf(2.) - ck.size / s * dij.powf(2.)).sqrt()
+            }
+        }
+    }
+}
+
+struct Cluster {
+    /// Cluster id
+    id: usize,
+    /// Number of nodes (rows) which belongs to cluster
+    /// to avoid cast in the algorithm, store it as f64
+    size: f64,
+    /// Row ids belong to the cluster
+    nodes: Vec<usize>,
+}
+
+impl Cluster {
+
+    /// Create new cluster
+    fn new(id: usize, nodes: Vec<usize>) -> Cluster {
+        Cluster {
+            id: id,
+            size: 1.,
+            nodes: nodes,
+        }
+    }
+
+    /// Create new cluster merging left and right
+    fn from_clusters(id: usize, left: Cluster, mut right: Cluster) -> Cluster {
+        let mut new_nodes = left.nodes;
+        new_nodes.append(&mut right.nodes);
+        Cluster {
+            id: id,
+            size: left.size + right.size,
+            nodes: new_nodes
+        }
+    }
+}
+
+/// Distance Matrix
+#[derive(Debug)]
+struct DistanceMatrix {
+    // Distance is symmetric, no need to hold all pairs
+    // use HashMap to easier update
+    data: HashMap<(usize, usize), f64>
+}
+
+impl DistanceMatrix {
+
+    /// Create distance matrix from input matrix
+    fn from_mat(inputs: &Matrix<f64>) -> Self {
+        assert!(inputs.rows() > 0, "input is empty");
+
+        let n = inputs.rows() - 1;
+        let mut data: HashMap<(usize, usize), f64> = HashMap::with_capacity(n * n);
+
+        
unsafe { + for i in 0..n { + for j in i..inputs.rows() { + let mut val = 0.; + for k in 0..inputs.cols() { + val += (inputs.get_unchecked([i, k]) - + inputs.get_unchecked([j, k])).abs().powf(2.); + } + val = val.sqrt(); + data.insert((i, j), val); + } + } + } + DistanceMatrix { + data: data + } + } + + /// Get distance between i-th and j-th item + fn get(&self, i: usize, j: usize) -> f64 { + if i == j { + 0. + } else if i > j { + *self.data.get(&(j, i)).unwrap() + } else { + *self.data.get(&(i, j)).unwrap() + } + } + + /// Add distance between i-th and j-th item + /// i must be smaller than j + fn insert(&mut self, i: usize, j: usize, dist: f64) { + assert!(i < j, "i must be smaller than j"); + self.data.insert((i, j), dist); + } + + /// Delete distance between i-th and j-th item + fn delete(&mut self, i: usize, j: usize) { + assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0"); + if i > j { + self.data.remove(&(j, i)); + } else { + self.data.remove(&(i, j)); + } + } +} + +/// Agglomerative clustering +#[derive(Debug)] +pub struct AgglomerativeClustering { + n: usize, + method: Metrics, + + // internally stores distances / merged history (currently for testing) + distances: Option>, + merged: Option> +} + +impl AgglomerativeClustering { + + /// Constructs an untrained Decision Tree with specified + /// + /// - `n` - Number of clusters + /// - `method` - Distance metrics + /// + /// # Examples + /// + /// ``` + /// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics}; + /// + /// let _ = AgglomerativeClustering::new(3, Metrics::Single); + /// ``` + pub fn new(n: usize, method: Metrics) -> Self { + AgglomerativeClustering { + n: n, + method: method, + + distances: None, + merged: None + } + } + + /// train the data and predict the cluster + pub fn train(&mut self, inputs: &Matrix) -> LearningResult> { + let mut dmat = DistanceMatrix::from_mat(&inputs); + + // initialize cluster + let mut clusters: Vec = 
(0..inputs.rows()).map(|i| Cluster::new(i, vec![i]))
+            .collect();
+        // vec to store merged cluster distances
+        let mut distances: Vec<f64> = Vec::with_capacity(inputs.rows() - self.n);
+        let mut merged: Vec<(usize, usize)> = Vec::with_capacity(inputs.rows() - self.n);
+
+        let mut id = inputs.rows();
+        while clusters.len() > self.n {
+            let mut tmp_i = 0;
+            let mut tmp_j = 0;
+            let mut current_dist = f64::MAX;
+
+            // loop with index to remember the position to be removed
+            for i in 0..(clusters.len() - 1) {
+                for j in (i + 1)..clusters.len() {
+                    let ci = unsafe { clusters.get_unchecked(i) };
+                    let cj = unsafe { clusters.get_unchecked(j) };
+
+                    let d = dmat.get(ci.id, cj.id);
+                    if d < current_dist {
+                        current_dist = d;
+                        tmp_i = i;
+                        tmp_j = j;
+                    }
+                }
+            }
+
+            distances.push(current_dist);
+
+            // update cluster
+            // cj must be first because j > i
+            let cj = clusters.swap_remove(tmp_j);
+            let ci = clusters.swap_remove(tmp_i);
+            merged.push((ci.id, cj.id));
+
+            // update distances using Lance Williams algorithm
+            for ck in clusters.iter() {
+                let d = self.method.dist(&ci, &cj, ck, &dmat);
+                dmat.insert(ck.id, id, d);
+
+                // remove unnecessary distances
+                dmat.delete(ck.id, ci.id);
+                dmat.delete(ck.id, cj.id);
+            }
+
+            let new = Cluster::from_clusters(id, ci, cj);
+            id += 1;
+            clusters.push(new);
+        }
+        // store distances
+        self.distances = Some(distances);
+        // store merged history
+        self.merged = Some(merged);
+
+        let mut sorter: BTreeMap<usize, usize> = BTreeMap::new();
+        for (i, c) in clusters.iter().enumerate() {
+            for n in c.nodes.iter() {
+                sorter.insert(*n, i);
+            }
+        }
+        let res: Vec<usize> = sorter.values().cloned().collect();
+        Ok(Vector::new(res))
+    }
+}
+
+
+#[cfg(test)]
+mod tests {
+
+    use super::{AgglomerativeClustering, DistanceMatrix, Metrics};
+
+    #[test]
+    fn test_distance_matrix() {
+        let data = matrix![1., 2.;
+                           2., 3.;
+                           0., 5.;
+                           3., 3.];
+
+        let m = DistanceMatrix::from_mat(&data);
+
+        assert_eq!(m.get(0, 0), 0.);
+
+        assert_eq!(m.get(0, 1), 2.0f64.sqrt());
+        
assert_eq!(m.get(1, 0), 2.0f64.sqrt());
+
+        assert_eq!(m.get(0, 2), 10.0f64.sqrt());
+        assert_eq!(m.get(2, 0), 10.0f64.sqrt());
+
+        assert_eq!(m.get(0, 3), 5.0f64.sqrt());
+        assert_eq!(m.get(3, 0), 5.0f64.sqrt());
+
+        assert_eq!(m.get(1, 1), 0.);
+
+        assert_eq!(m.get(1, 2), 8.0f64.sqrt());
+        assert_eq!(m.get(2, 1), 8.0f64.sqrt());
+
+        assert_eq!(m.get(1, 3), 1.);
+        assert_eq!(m.get(3, 1), 1.);
+
+        assert_eq!(m.get(2, 2), 0.);
+
+        assert_eq!(m.get(2, 3), 13.0f64.sqrt());
+        assert_eq!(m.get(3, 2), 13.0f64.sqrt());
+    }
+
+    #[test]
+    fn test_distances() {
+        // test distances are calculated properly
+        let data = matrix![89., 90., 67. ,46., 50.;
+                           57., 70., 80., 85., 90.;
+                           80., 90., 35., 40., 50.;
+                           40., 60., 50., 45., 55.;
+                           78., 85., 45., 55., 60.;
+                           55., 65., 80., 75., 85.;
+                           90., 85., 88., 92., 95.];
+
+        let mut hclust = AgglomerativeClustering::new(1, Metrics::Single);
+        let _ = hclust.train(&data);
+        let exp = vec![12.409673645990857, 21.307275752662516, 28.478061731796284,
+                       38.1051177665153, 47.10626285325551, 54.31390245600108];
+        assert_eq!(hclust.distances.unwrap(), exp);
+        let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
+        assert_eq!(hclust.merged.unwrap(), exp);
+
+        let mut hclust = AgglomerativeClustering::new(1, Metrics::Complete);
+        let _ = hclust.train(&data);
+        let exp = vec![12.409673645990857, 21.307275752662516, 33.77869150810907,
+                       45.58508528016593, 60.13318551349163, 91.53141537199127];
+        assert_eq!(hclust.distances.unwrap(), exp);
+        let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
+        assert_eq!(hclust.merged.unwrap(), exp);
+
+        let mut hclust = AgglomerativeClustering::new(1, Metrics::Average);
+        let _ = hclust.train(&data);
+        let exp = vec![12.409673645990857, 21.307275752662516, 31.128376619952675,
+                       41.84510152334062, 53.305905710336944, 69.92295649225116];
+        assert_eq!(hclust.distances.unwrap(), exp);
+        let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
+        assert_eq!(hclust.merged.unwrap(), exp);
+ + let mut hclust = AgglomerativeClustering::new(1, Metrics::Centroid); + let _ = hclust.train(&data); + let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045, + 38.7426831118429, 44.021013600051624, 44.02758328256392]; + assert_eq!(hclust.distances.unwrap(), exp); + let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; + assert_eq!(hclust.merged.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(1, Metrics::Median); + let _ = hclust.train(&data); + let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045, + 38.7426831118429, 45.898926771596045, 45.42216730738696]; + assert_eq!(hclust.distances.unwrap(), exp); + let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; + assert_eq!(hclust.merged.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward1); + let _ = hclust.train(&data); + let exp = vec![12.409673645990857, 21.307275752662516, 34.4020769090494, + 51.65691081579053, 66.03152040007744, 150.95171411164773]; + assert_eq!(hclust.distances.unwrap(), exp); + let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; + assert_eq!(hclust.merged.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward2); + let _ = hclust.train(&data); + let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334, + 47.97916214358062, 62.481997407253225, 115.91869071527186]; + assert_eq!(hclust.distances.unwrap(), exp); + let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; + assert_eq!(hclust.merged.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward); + let _ = hclust.train(&data); + let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334, + 47.97916214358062, 62.481997407253225, 115.91869071527186]; + assert_eq!(hclust.distances.unwrap(), exp); + let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; + assert_eq!(hclust.merged.unwrap(), exp); + } +} diff --git 
a/src/lib.rs b/src/lib.rs index a822f58a..4232f6ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ //! - Logistic Regression //! - Generalized Linear Models //! - K-Means Clustering +//! - Agglomerative Clustering //! - Neural Networks //! - Gaussian Process Regression //! - Support Vector Machines @@ -134,6 +135,7 @@ pub mod learning { pub mod lin_reg; pub mod logistic_reg; pub mod k_means; + pub mod agglomerative; pub mod nnet; pub mod gp; pub mod svm; diff --git a/tests/learning/agglomerative.rs b/tests/learning/agglomerative.rs new file mode 100644 index 00000000..247a5c92 --- /dev/null +++ b/tests/learning/agglomerative.rs @@ -0,0 +1,54 @@ +use rm::linalg::{Matrix, Vector}; +use rm::learning::agglomerative::{AgglomerativeClustering, Metrics}; + +#[test] +fn test_cluster() { + let data = Matrix::new(7, 5, vec![89., 90., 67. ,46., 50., + 57., 70., 80., 85., 90., + 80., 90., 35., 40., 50., + 40., 60., 50., 45., 55., + 78., 85., 45., 55., 60., + 55., 65., 80., 75., 85., + 90., 85., 88., 92., 95.]); + + let mut hclust = AgglomerativeClustering::new(3, Metrics::Single); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(3, Metrics::Complete); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(3, Metrics::Average); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(3, Metrics::Centroid); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(3, Metrics::Median); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); + + let mut hclust = 
AgglomerativeClustering::new(3, Metrics::Ward1); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward2); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); + + let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward); + let res = hclust.train(&data); + let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); + assert_eq!(res.unwrap(), exp); +} + diff --git a/tests/lib.rs b/tests/lib.rs index 11309cd6..df3c4bd4 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -8,6 +8,7 @@ pub mod learning { mod lin_reg; mod k_means; mod gp; + mod agglomerative; pub mod optim { mod grad_desc; From 84fe356d848328d1e8523cb5597676bdabe5438c Mon Sep 17 00:00:00 2001 From: sinhrks Date: Wed, 21 Dec 2016 06:52:54 +0900 Subject: [PATCH 2/2] rename linkage, use debug_assert --- src/learning/agglomerative.rs | 70 ++++++++++++++++----------------- tests/learning/agglomerative.rs | 18 ++++----- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/src/learning/agglomerative.rs b/src/learning/agglomerative.rs index 8093ff9c..a5924ea7 100644 --- a/src/learning/agglomerative.rs +++ b/src/learning/agglomerative.rs @@ -5,12 +5,12 @@ //! # Usage //! //! ``` -//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics}; +//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage}; //! use rusty_machine::learning::SupModel; //! use rusty_machine::linalg::{Matrix, Vector}; //! //! let inputs = Matrix::new(4, 2, vec![1., 3., 2., 3., 4., 3., 5., 3.]); -//! let mut agg = AgglomerativeClustering::new(2, Metrics::Single); +//! let mut agg = AgglomerativeClustering::new(2, Linkage::Single); //! //! // Train the model and get the clustering result //! 
let res = agg.train(&inputs).unwrap(); @@ -27,7 +27,7 @@ use learning::{LearningResult}; /// Agglomerative clustering distances #[derive(Debug)] -pub enum Metrics { +pub enum Linkage { /// Single linkage clustering Single, /// Complete linkage clustering @@ -49,7 +49,7 @@ pub enum Metrics { Ward2, } -impl Metrics { +impl Linkage { // calculate distance using Lance-Williams algorithm fn dist(&self, ci: &Cluster, cj: &Cluster, ck: &Cluster, dmat: &DistanceMatrix) -> f64 { @@ -58,38 +58,38 @@ impl Metrics { let djk = dmat.get(ck.id, cj.id); match self { - &Metrics::Single => { + &Linkage::Single => { // 0.5 * dik + 0.5 * djk + 0. * dij - 0.5 * (dik - djk).abs() dik.min(djk) }, - &Metrics::Complete => { + &Linkage::Complete => { // 0.5 * dik + 0.5 * djk + 0. * dij + 0.5 * (dik - djk).abs() dik.max(djk) }, - &Metrics::Average => { + &Linkage::Average => { let s = ci.size + cj.size; ci.size / s * dik + cj.size / s * djk }, - &Metrics::Centroid => { + &Linkage::Centroid => { let s = ci.size + cj.size; let ai = ci.size / s; let aj = cj.size / s; let dij = dmat.get(ci.id, cj.id); ai * dik + aj * djk - ai * aj * dij }, - &Metrics::Median => { + &Linkage::Median => { let dij = dmat.get(ci.id, cj.id); 0.5 * dik + 0.5 * djk - 0.25 * dij }, - &Metrics::Ward1 => { + &Linkage::Ward1 => { let s = ci.size + cj.size + ck.size; let dij = dmat.get(ci.id, cj.id); (ci.size + ck.size) / s * dik + (cj.size + ck.size) / s * djk - ck.size / s * dij }, - &Metrics::Ward | &Metrics::Ward2 => { + &Linkage::Ward | &Linkage::Ward2 => { let s = ci.size + cj.size + ck.size; let dij = dmat.get(ci.id, cj.id); - ((ci.size + ck.size) / s * dik.powf(2.) + (cj.size + ck.size) / s * djk.powf(2.) 
- ck.size / s * dij.powf(2.)).sqrt() + ((ci.size + ck.size) / s * dik * dik + (cj.size + ck.size) / s * djk * djk - ck.size / s * dij * dij).sqrt() } } } @@ -147,11 +147,11 @@ impl DistanceMatrix { unsafe { for i in 0..n { - for j in i..inputs.rows() { + for j in (i + 1)..inputs.rows() { let mut val = 0.; for k in 0..inputs.cols() { - val += (inputs.get_unchecked([i, k]) - - inputs.get_unchecked([j, k])).abs().powf(2.); + let d = inputs.get_unchecked([i, k]) - inputs.get_unchecked([j, k]); + val += d * d; } val = val.sqrt(); data.insert((i, j), val); @@ -177,13 +177,13 @@ impl DistanceMatrix { /// Add distance between i-th and j-th item /// i must be smaller than j fn insert(&mut self, i: usize, j: usize, dist: f64) { - assert!(i < j, "i must be smaller than j"); + debug_assert!(i < j, "i must be smaller than j"); self.data.insert((i, j), dist); } /// Delete distance between i-th and j-th item fn delete(&mut self, i: usize, j: usize) { - assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0"); + debug_assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0"); if i > j { self.data.remove(&(j, i)); } else { @@ -196,7 +196,7 @@ impl DistanceMatrix { #[derive(Debug)] pub struct AgglomerativeClustering { n: usize, - method: Metrics, + linkage: Linkage, // internally stores distances / merged history (currently for testing) distances: Option>, @@ -208,19 +208,19 @@ impl AgglomerativeClustering { /// Constructs an untrained Decision Tree with specified /// /// - `n` - Number of clusters - /// - `method` - Distance metrics + /// - `linkage` - Linkage method /// /// # Examples /// /// ``` - /// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics}; + /// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage}; /// - /// let _ = AgglomerativeClustering::new(3, Metrics::Single); + /// let _ = AgglomerativeClustering::new(3, Linkage::Single); /// ``` - pub fn new(n: 
usize, method: Metrics) -> Self { + pub fn new(n: usize, linkage: Linkage) -> Self { AgglomerativeClustering { n: n, - method: method, + linkage: linkage, distances: None, merged: None @@ -269,7 +269,7 @@ impl AgglomerativeClustering { // update distances using Lance Williams algorithm for ck in clusters.iter() { - let d = self.method.dist(&ci, &cj, ck, &dmat); + let d = self.linkage.dist(&ci, &cj, ck, &dmat); dmat.insert(ck.id, id, d); // remove unnecessary distances @@ -301,7 +301,7 @@ impl AgglomerativeClustering { #[cfg(test)] mod tests { - use super::{AgglomerativeClustering, DistanceMatrix, Metrics}; + use super::{AgglomerativeClustering, DistanceMatrix, Linkage}; #[test] fn test_distance_matrix() { @@ -348,7 +348,7 @@ mod tests { 55., 65., 80., 75., 85.; 90., 85., 88., 92., 95.]; - let mut hclust = AgglomerativeClustering::new(1, Metrics::Single); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Single); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 28.478061731796284, 38.1051177665153, 47.10626285325551, 54.31390245600108]; @@ -356,7 +356,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Complete); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Complete); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 33.77869150810907, 45.58508528016593, 60.13318551349163, 91.53141537199127]; @@ -364,7 +364,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Average); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Average); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 31.128376619952675, 41.84510152334062, 53.305905710336944, 69.92295649225116]; @@ -372,7 +372,7 
@@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Centroid); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Centroid); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045, 38.7426831118429, 44.021013600051624, 44.02758328256392]; @@ -380,7 +380,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Median); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Median); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045, 38.7426831118429, 45.898926771596045, 45.42216730738696]; @@ -388,7 +388,7 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward1); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward1); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 34.4020769090494, 51.65691081579053, 66.03152040007744, 150.95171411164773]; @@ -396,18 +396,18 @@ mod tests { let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward2); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward2); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334, - 47.97916214358062, 62.481997407253225, 115.91869071527186]; + 47.97916214358062, 62.48199740725323, 115.91869071527186]; assert_eq!(hclust.distances.unwrap(), exp); let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); - let mut hclust = 
AgglomerativeClustering::new(1, Metrics::Ward); + let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward); let _ = hclust.train(&data); let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334, - 47.97916214358062, 62.481997407253225, 115.91869071527186]; + 47.97916214358062, 62.48199740725323, 115.91869071527186]; assert_eq!(hclust.distances.unwrap(), exp); let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)]; assert_eq!(hclust.merged.unwrap(), exp); diff --git a/tests/learning/agglomerative.rs b/tests/learning/agglomerative.rs index 247a5c92..6306f1ec 100644 --- a/tests/learning/agglomerative.rs +++ b/tests/learning/agglomerative.rs @@ -1,5 +1,5 @@ use rm::linalg::{Matrix, Vector}; -use rm::learning::agglomerative::{AgglomerativeClustering, Metrics}; +use rm::learning::agglomerative::{AgglomerativeClustering, Linkage}; #[test] fn test_cluster() { @@ -11,42 +11,42 @@ fn test_cluster() { 55., 65., 80., 75., 85., 90., 85., 88., 92., 95.]); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Single); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Single); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Complete); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Complete); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Average); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Average); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Centroid); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Centroid); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); 
- let mut hclust = AgglomerativeClustering::new(3, Metrics::Median); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Median); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward1); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward1); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward2); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward2); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp); - let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward); + let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward); let res = hclust.train(&data); let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]); assert_eq!(res.unwrap(), exp);