diff --git a/demos/benches/pq.rs b/demos/benches/pq.rs index 5543609..06d3d16 100644 --- a/demos/benches/pq.rs +++ b/demos/benches/pq.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, Criterion}; use quantization::encoded_vectors::{DistanceType, EncodedVectors, VectorParameters}; -use quantization::encoded_vectors_pq::EncodedVectorsPQ; +use quantization::encoded_vectors_pq::{CentroidsParameters, EncodedVectorsPQ}; use rand::Rng; #[cfg(target_arch = "x86_64")] @@ -27,7 +27,7 @@ fn encode_bench(c: &mut Criterion) { distance_type: DistanceType::Dot, invert: false, }, - 2, + CentroidsParameters::KMeans { chunk_size: 2 }, num_cpus::get(), || false, ) diff --git a/demos/src/ann_benchmark.rs b/demos/src/ann_benchmark.rs index 39101d1..94fee89 100644 --- a/demos/src/ann_benchmark.rs +++ b/demos/src/ann_benchmark.rs @@ -2,7 +2,7 @@ mod ann_benchmark_data; mod metrics; use quantization::encoded_vectors::{DistanceType, EncodedVectors}; -use quantization::encoded_vectors_pq::EncodedVectorsPQ; +use quantization::encoded_vectors_pq::{CentroidsParameters, EncodedVectorsPQ}; use quantization::{EncodedVectorsU8, VectorParameters}; #[cfg(target_arch = "x86_64")] @@ -136,7 +136,9 @@ fn main() { data_iter, Vec::::new(), &vector_parameters, - args.chunk_size, + CentroidsParameters::KMeans { + chunk_size: args.chunk_size, + }, num_cpus::get(), || false, ) diff --git a/quantization/src/encoded_vectors_pq.rs b/quantization/src/encoded_vectors_pq.rs index f35e19e..d1a5d42 100644 --- a/quantization/src/encoded_vectors_pq.rs +++ b/quantization/src/encoded_vectors_pq.rs @@ -36,6 +36,11 @@ pub struct EncodedQueryPQ { lut: Vec, } +pub enum CentroidsParameters { + KMeans { chunk_size: usize }, + Custom { codebook: Vec>> }, +} + #[derive(Serialize, Deserialize)] struct Metadata { centroids: Vec>, @@ -50,31 +55,43 @@ impl EncodedVectorsPQ { /// * `data` - iterator over original vector data /// * `storage_builder` - encoding result storage builder /// * `vector_parameters` - parameters of original vector data (dimension, distance, ect) - /// * `chunk_size` - Max size of f32 chunk that replaced by centroid index (in original vector dimension) + /// * `centroid_parameters` - parameters of centroids (chunk size or custom centroids) /// * `max_threads` - Max allowed threads for kmeans and encodind process /// * `stop_condition` - Function that returns `true` if encoding should be stopped pub fn encode<'a>( data: impl Iterator + Clone + Send, mut storage_builder: impl EncodedStorageBuilder + Send, vector_parameters: &VectorParameters, - chunk_size: usize, + centroid_parameters: CentroidsParameters, max_kmeans_threads: usize, stop_condition: impl Fn() -> bool + Sync, ) -> Result { debug_assert!(validate_vector_parameters(data.clone(), vector_parameters).is_ok()); // first, divide vector into chunks - let vector_division = Self::get_vector_division(vector_parameters.dim, chunk_size); + let vector_division = match ¢roid_parameters { + CentroidsParameters::KMeans { chunk_size } => { + Self::get_vector_division(vector_parameters.dim, *chunk_size) + } + CentroidsParameters::Custom { codebook } => { + Self::get_vector_division_from_codebook(codebook, vector_parameters.dim)? + } + }; // then, find flattened centroid positions - let centroids = Self::find_centroids( - data.clone(), - &vector_division, - vector_parameters, - CENTROIDS_COUNT, - max_kmeans_threads, - &stop_condition, - )?; + let centroids = match centroid_parameters { + CentroidsParameters::KMeans { .. } => Self::find_centroids( + data.clone(), + &vector_division, + vector_parameters, + CENTROIDS_COUNT, + max_kmeans_threads, + &stop_condition, + )?, + CentroidsParameters::Custom { codebook } => { + Self::get_centroids_from_codebook(&codebook, vector_parameters.dim)? + } + }; // finally, encode data #[allow(clippy::redundant_clone)] @@ -108,9 +125,31 @@ impl EncodedVectorsPQ { pub fn get_quantized_vector_size( vector_parameters: &VectorParameters, - chunk_size: usize, + centroid_parameters: &CentroidsParameters, ) -> usize { - (0..vector_parameters.dim).step_by(chunk_size).count() + match centroid_parameters { + CentroidsParameters::KMeans { chunk_size } => { + (0..vector_parameters.dim).step_by(*chunk_size).count() + } + CentroidsParameters::Custom { codebook } => codebook.len(), + } + } + + /// Get codebook. Converts internal centroid format into codebook format + pub fn get_codebook(&self) -> Vec>> { + let mut result = vec![]; + for range in &self.metadata.vector_division { + let mut chunk_centroids = vec![]; + for i in 0..self.metadata.centroids.len() { + chunk_centroids.push(self.metadata.centroids[i][range.clone()].to_owned()); + } + result.push(chunk_centroids); + } + result + } + + pub fn get_encoded_vectors(&self) -> &TStorage { + &self.encoded_vectors } fn get_vector_division(dim: usize, chunk_size: usize) -> Vec> { @@ -120,6 +159,52 @@ impl EncodedVectorsPQ { .collect() } + fn get_vector_division_from_codebook( + codebook: &[Vec>], + dim: usize, + ) -> Result>, EncodingError> { + let mut vector_division: Vec> = vec![]; + for chunk_centroids in codebook { + let chunk_size = chunk_centroids[0].len(); + let start = vector_division.last().map(|x| x.end).unwrap_or(0); + let range = start..start + chunk_size; + vector_division.push(range); + } + if vector_division.last().map(|x| x.end).unwrap_or(0) == dim { + Ok(vector_division) + } else { + Err(EncodingError::ArgumentsError(format!( + "Codebook does not match vector dimension {}", + dim + ))) + } + } + + fn get_centroids_from_codebook( + codebook: &[Vec>], + dim: usize, + ) -> Result>, EncodingError> { + let centroids_count = codebook[0].len(); + if centroids_count != 256 { + return Err(EncodingError::ArgumentsError(format!( + "Centroids count in codebook {} does not equal 256", + centroids_count + ))); + } + + let mut centroids = vec![]; + for i in 0..centroids_count { + let mut centroid = vec![]; + for chunk_centroids in codebook { + centroid.extend_from_slice(&chunk_centroids[i]); + } + assert_eq!(centroid.len(), dim); + centroids.push(centroid); + } + + Ok(centroids) + } + /// Encode whole storage /// /// # Arguments diff --git a/quantization/tests/empty_storage.rs b/quantization/tests/empty_storage.rs index 605fedf..4d6165d 100644 --- a/quantization/tests/empty_storage.rs +++ b/quantization/tests/empty_storage.rs @@ -5,6 +5,7 @@ mod metrics; mod tests { use quantization::{ encoded_vectors::{DistanceType, EncodedVectors, VectorParameters}, + encoded_vectors_pq::CentroidsParameters, encoded_vectors_u8::EncodedVectorsU8, EncodedVectorsPQ, }; @@ -65,7 +66,7 @@ mod tests { vector_data.iter().map(|v| v.as_slice()), Vec::::new(), &vector_parameters, - 2, + CentroidsParameters::KMeans { chunk_size: 2 }, 1, || false, ) diff --git a/quantization/tests/stop_condition.rs b/quantization/tests/stop_condition.rs index b588376..2c56b58 100644 --- a/quantization/tests/stop_condition.rs +++ b/quantization/tests/stop_condition.rs @@ -10,6 +10,7 @@ mod tests { use quantization::{ encoded_vectors::{DistanceType, VectorParameters}, + encoded_vectors_pq::CentroidsParameters, encoded_vectors_u8::EncodedVectorsU8, EncodedVectorsPQ, EncodingError, }; @@ -76,7 +77,7 @@ mod tests { (0..vector_parameters.count).map(|_| zero_vector.as_slice()), Vec::::new(), &vector_parameters, - 2, + CentroidsParameters::KMeans { chunk_size: 2 }, 1, || stopped_ref.load(Ordering::Relaxed), ) diff --git a/quantization/tests/test_pq.rs b/quantization/tests/test_pq.rs index ef4836a..712a4d6 100644 --- a/quantization/tests/test_pq.rs +++ b/quantization/tests/test_pq.rs @@ -7,7 +7,7 @@ mod tests { use quantization::{ encoded_vectors::{DistanceType, EncodedVectors, VectorParameters}, - encoded_vectors_pq::EncodedVectorsPQ, + encoded_vectors_pq::{CentroidsParameters, EncodedVectorsPQ}, }; use rand::{Rng, SeedableRng}; @@ -35,7 +35,7 @@ mod tests { distance_type: DistanceType::Dot, invert: false, }, - 1, + CentroidsParameters::KMeans { chunk_size: 1 }, 1, || false, ) @@ -67,7 +67,7 @@ mod tests { distance_type: DistanceType::L2, invert: false, }, - 1, + CentroidsParameters::KMeans { chunk_size: 1 }, 1, || false, ) @@ -99,7 +99,7 @@ mod tests { distance_type: DistanceType::Dot, invert: true, }, - 1, + CentroidsParameters::KMeans { chunk_size: 1 }, 1, || false, ) @@ -131,7 +131,7 @@ mod tests { distance_type: DistanceType::L2, invert: true, }, - 1, + CentroidsParameters::KMeans { chunk_size: 1 }, 1, || false, ) @@ -162,7 +162,7 @@ mod tests { distance_type: DistanceType::Dot, invert: false, }, - 1, + CentroidsParameters::KMeans { chunk_size: 1 }, 1, || false, ) @@ -192,7 +192,7 @@ mod tests { distance_type: DistanceType::Dot, invert: true, }, - 1, + CentroidsParameters::KMeans { chunk_size: 1 }, 1, || false, ) @@ -205,6 +205,50 @@ mod tests { } } + #[test] + fn test_custom_centroids() { + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let mut vector_data: Vec> = vec![]; + for _ in 0..VECTORS_COUNT { + vector_data.push((0..VECTOR_DIM).map(|_| rng.gen()).collect()); + } + + let encoded = EncodedVectorsPQ::encode( + vector_data.iter().map(Vec::as_slice), + vec![], + &VectorParameters { + dim: VECTOR_DIM, + count: VECTORS_COUNT, + distance_type: DistanceType::Dot, + invert: true, + }, + CentroidsParameters::KMeans { chunk_size: 1 }, + 1, + || false, + ) + .unwrap(); + let codebook = encoded.get_codebook(); + + let encoded_custom = EncodedVectorsPQ::encode( + vector_data.iter().map(Vec::as_slice), + vec![], + &VectorParameters { + dim: VECTOR_DIM, + count: VECTORS_COUNT, + distance_type: DistanceType::Dot, + invert: true, + }, + CentroidsParameters::Custom { codebook }, + 1, + || false, + ) + .unwrap(); + + let data_orig = encoded.get_encoded_vectors().to_owned(); + let data_custom = encoded_custom.get_encoded_vectors().to_owned(); + assert_eq!(data_orig, data_custom); + } + // ignore this test because it requires long time // this test should be started separately of with `--test-threads=1` flag // because `num_threads::num_threads()` is used to check that all encode threads finished @@ -243,7 +287,7 @@ mod tests { distance_type: DistanceType::Dot, invert: false, }, - 1, + CentroidsParameters::KMeans { chunk_size: 1 }, 5, || false, )