Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demos/benches/pq.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use criterion::{criterion_group, criterion_main, Criterion};
use quantization::encoded_vectors::{DistanceType, EncodedVectors, VectorParameters};
use quantization::encoded_vectors_pq::EncodedVectorsPQ;
use quantization::encoded_vectors_pq::{CentroidsParameters, EncodedVectorsPQ};
use rand::Rng;

#[cfg(target_arch = "x86_64")]
Expand All @@ -27,7 +27,7 @@ fn encode_bench(c: &mut Criterion) {
distance_type: DistanceType::Dot,
invert: false,
},
2,
CentroidsParameters::KMeans { chunk_size: 2 },
num_cpus::get(),
|| false,
)
Expand Down
6 changes: 4 additions & 2 deletions demos/src/ann_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod ann_benchmark_data;
mod metrics;

use quantization::encoded_vectors::{DistanceType, EncodedVectors};
use quantization::encoded_vectors_pq::EncodedVectorsPQ;
use quantization::encoded_vectors_pq::{CentroidsParameters, EncodedVectorsPQ};
use quantization::{EncodedVectorsU8, VectorParameters};

#[cfg(target_arch = "x86_64")]
Expand Down Expand Up @@ -136,7 +136,9 @@ fn main() {
data_iter,
Vec::<u8>::new(),
&vector_parameters,
args.chunk_size,
CentroidsParameters::KMeans {
chunk_size: args.chunk_size,
},
num_cpus::get(),
|| false,
)
Expand Down
111 changes: 98 additions & 13 deletions quantization/src/encoded_vectors_pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ pub struct EncodedQueryPQ {
lut: Vec<f32>,
}

/// Selects where the product-quantization centroids used by `encode` come from.
#[derive(Debug, Clone, PartialEq)]
pub enum CentroidsParameters {
    /// Compute the centroids with k-means.
    ///
    /// `chunk_size` is the maximum number of consecutive `f32` components of the
    /// original vector that are replaced by a single centroid index.
    KMeans { chunk_size: usize },
    /// Use centroids supplied by the caller instead of running k-means.
    ///
    /// `codebook[chunk][centroid]` holds the components of one centroid for one
    /// chunk of the vector. Every chunk must provide 256 centroids, and the chunk
    /// widths, taken in order, must add up to the vector dimension.
    Custom { codebook: Vec<Vec<Vec<f32>>> },
}

#[derive(Serialize, Deserialize)]
struct Metadata {
centroids: Vec<Vec<f32>>,
Expand All @@ -50,31 +55,43 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
/// * `data` - iterator over original vector data
/// * `storage_builder` - encoding result storage builder
/// * `vector_parameters` - parameters of original vector data (dimension, distance, etc.)
/// * `chunk_size` - Max size of f32 chunk that replaced by centroid index (in original vector dimension)
/// * `centroid_parameters` - parameters of centroids (chunk size or custom centroids)
/// * `max_kmeans_threads` - Max allowed threads for kmeans and encoding process
/// * `stop_condition` - Function that returns `true` if encoding should be stopped
pub fn encode<'a>(
data: impl Iterator<Item = &'a [f32]> + Clone + Send,
mut storage_builder: impl EncodedStorageBuilder<TStorage> + Send,
vector_parameters: &VectorParameters,
chunk_size: usize,
centroid_parameters: CentroidsParameters,
max_kmeans_threads: usize,
stop_condition: impl Fn() -> bool + Sync,
) -> Result<Self, EncodingError> {
debug_assert!(validate_vector_parameters(data.clone(), vector_parameters).is_ok());

// first, divide vector into chunks
let vector_division = Self::get_vector_division(vector_parameters.dim, chunk_size);
let vector_division = match &centroid_parameters {
CentroidsParameters::KMeans { chunk_size } => {
Self::get_vector_division(vector_parameters.dim, *chunk_size)
}
CentroidsParameters::Custom { codebook } => {
Self::get_vector_division_from_codebook(codebook, vector_parameters.dim)?
}
};

// then, find flattened centroid positions
let centroids = Self::find_centroids(
data.clone(),
&vector_division,
vector_parameters,
CENTROIDS_COUNT,
max_kmeans_threads,
&stop_condition,
)?;
let centroids = match centroid_parameters {
CentroidsParameters::KMeans { .. } => Self::find_centroids(
data.clone(),
&vector_division,
vector_parameters,
CENTROIDS_COUNT,
max_kmeans_threads,
&stop_condition,
)?,
CentroidsParameters::Custom { codebook } => {
Self::get_centroids_from_codebook(&codebook, vector_parameters.dim)?
}
};

// finally, encode data
#[allow(clippy::redundant_clone)]
Expand Down Expand Up @@ -108,9 +125,31 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {

pub fn get_quantized_vector_size(
vector_parameters: &VectorParameters,
chunk_size: usize,
centroid_parameters: &CentroidsParameters,
) -> usize {
(0..vector_parameters.dim).step_by(chunk_size).count()
match centroid_parameters {
CentroidsParameters::KMeans { chunk_size } => {
(0..vector_parameters.dim).step_by(*chunk_size).count()
}
CentroidsParameters::Custom { codebook } => codebook.len(),
}
}

/// Export the stored centroids in codebook form.
///
/// The result is indexed as `codebook[chunk][centroid]`: for every range in
/// `metadata.vector_division` it holds, for each stored centroid, a copy of
/// that centroid's components inside the range.
pub fn get_codebook(&self) -> Vec<Vec<Vec<f32>>> {
    let centroids = &self.metadata.centroids;
    self.metadata
        .vector_division
        .iter()
        .map(|chunk_range| {
            centroids
                .iter()
                .map(|centroid| centroid[chunk_range.clone()].to_vec())
                .collect()
        })
        .collect()
}

/// Returns a shared reference to the `encoded_vectors` storage held by this instance.
pub fn get_encoded_vectors(&self) -> &TStorage {
&self.encoded_vectors
}

fn get_vector_division(dim: usize, chunk_size: usize) -> Vec<Range<usize>> {
Expand All @@ -120,6 +159,52 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
.collect()
}

/// Derive the per-chunk component ranges of a vector from a custom codebook.
///
/// Chunk `k` of the codebook covers as many consecutive vector components as
/// its centroids are wide; the returned ranges are laid out back to back
/// starting at component 0.
///
/// # Errors
///
/// Returns `EncodingError::ArgumentsError` when
/// * a chunk contains no centroids (its width cannot be determined),
/// * the centroids inside one chunk do not all have the same width, or
/// * the chunk widths do not add up to `dim`.
fn get_vector_division_from_codebook(
    codebook: &[Vec<Vec<f32>>],
    dim: usize,
) -> Result<Vec<Range<usize>>, EncodingError> {
    let mut vector_division: Vec<Range<usize>> = Vec::with_capacity(codebook.len());
    let mut start = 0;
    for (chunk_index, chunk_centroids) in codebook.iter().enumerate() {
        // The width of a chunk is taken from its first centroid, so an empty
        // chunk must be rejected instead of indexed.
        let chunk_size = chunk_centroids.first().map(Vec::len).ok_or_else(|| {
            EncodingError::ArgumentsError(format!(
                "Codebook chunk {} has no centroids",
                chunk_index
            ))
        })?;
        // A ragged chunk would silently produce a division that does not
        // describe all of its centroids.
        if chunk_centroids.iter().any(|c| c.len() != chunk_size) {
            return Err(EncodingError::ArgumentsError(format!(
                "Codebook chunk {} has centroids of different sizes",
                chunk_index
            )));
        }
        let end = start + chunk_size;
        vector_division.push(start..end);
        start = end;
    }

    if start == dim {
        Ok(vector_division)
    } else {
        Err(EncodingError::ArgumentsError(format!(
            "Codebook does not match vector dimension {}",
            dim
        )))
    }
}

/// Flatten a custom codebook into full-width centroids.
///
/// `codebook[chunk][centroid]` holds the components of one centroid for one
/// chunk. The result contains 256 vectors of length `dim`; vector `i` is the
/// concatenation, in chunk order, of entry `i` of every chunk.
///
/// # Errors
///
/// Returns `EncodingError::ArgumentsError` (instead of panicking) when
/// * the codebook has no chunks,
/// * the first chunk does not contain exactly 256 centroids,
/// * any other chunk contains a different number of centroids, or
/// * a concatenated centroid is not exactly `dim` components long.
fn get_centroids_from_codebook(
    codebook: &[Vec<Vec<f32>>],
    dim: usize,
) -> Result<Vec<Vec<f32>>, EncodingError> {
    let centroids_count = codebook
        .first()
        .map(Vec::len)
        .ok_or_else(|| EncodingError::ArgumentsError("Codebook is empty".to_string()))?;
    if centroids_count != 256 {
        return Err(EncodingError::ArgumentsError(format!(
            "Centroids count in codebook {} does not equal 256",
            centroids_count
        )));
    }
    // Every chunk is indexed with the same centroid number below, so all of
    // them must be as long as the first one.
    if let Some(chunk_index) = codebook.iter().position(|c| c.len() != centroids_count) {
        return Err(EncodingError::ArgumentsError(format!(
            "Codebook chunk {} has {} centroids, expected {}",
            chunk_index,
            codebook[chunk_index].len(),
            centroids_count
        )));
    }

    let mut centroids = Vec::with_capacity(centroids_count);
    for i in 0..centroids_count {
        let mut centroid = Vec::with_capacity(dim);
        for chunk_centroids in codebook {
            centroid.extend_from_slice(&chunk_centroids[i]);
        }
        if centroid.len() != dim {
            return Err(EncodingError::ArgumentsError(format!(
                "Codebook centroid {} has size {} that does not match vector dimension {}",
                i,
                centroid.len(),
                dim
            )));
        }
        centroids.push(centroid);
    }

    Ok(centroids)
}

/// Encode whole storage
///
/// # Arguments
Expand Down
3 changes: 2 additions & 1 deletion quantization/tests/empty_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod metrics;
mod tests {
use quantization::{
encoded_vectors::{DistanceType, EncodedVectors, VectorParameters},
encoded_vectors_pq::CentroidsParameters,
encoded_vectors_u8::EncodedVectorsU8,
EncodedVectorsPQ,
};
Expand Down Expand Up @@ -65,7 +66,7 @@ mod tests {
vector_data.iter().map(|v| v.as_slice()),
Vec::<u8>::new(),
&vector_parameters,
2,
CentroidsParameters::KMeans { chunk_size: 2 },
1,
|| false,
)
Expand Down
3 changes: 2 additions & 1 deletion quantization/tests/stop_condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod tests {

use quantization::{
encoded_vectors::{DistanceType, VectorParameters},
encoded_vectors_pq::CentroidsParameters,
encoded_vectors_u8::EncodedVectorsU8,
EncodedVectorsPQ, EncodingError,
};
Expand Down Expand Up @@ -76,7 +77,7 @@ mod tests {
(0..vector_parameters.count).map(|_| zero_vector.as_slice()),
Vec::<u8>::new(),
&vector_parameters,
2,
CentroidsParameters::KMeans { chunk_size: 2 },
1,
|| stopped_ref.load(Ordering::Relaxed),
)
Expand Down
60 changes: 52 additions & 8 deletions quantization/tests/test_pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod tests {

use quantization::{
encoded_vectors::{DistanceType, EncodedVectors, VectorParameters},
encoded_vectors_pq::EncodedVectorsPQ,
encoded_vectors_pq::{CentroidsParameters, EncodedVectorsPQ},
};
use rand::{Rng, SeedableRng};

Expand Down Expand Up @@ -35,7 +35,7 @@ mod tests {
distance_type: DistanceType::Dot,
invert: false,
},
1,
CentroidsParameters::KMeans { chunk_size: 1 },
1,
|| false,
)
Expand Down Expand Up @@ -67,7 +67,7 @@ mod tests {
distance_type: DistanceType::L2,
invert: false,
},
1,
CentroidsParameters::KMeans { chunk_size: 1 },
1,
|| false,
)
Expand Down Expand Up @@ -99,7 +99,7 @@ mod tests {
distance_type: DistanceType::Dot,
invert: true,
},
1,
CentroidsParameters::KMeans { chunk_size: 1 },
1,
|| false,
)
Expand Down Expand Up @@ -131,7 +131,7 @@ mod tests {
distance_type: DistanceType::L2,
invert: true,
},
1,
CentroidsParameters::KMeans { chunk_size: 1 },
1,
|| false,
)
Expand Down Expand Up @@ -162,7 +162,7 @@ mod tests {
distance_type: DistanceType::Dot,
invert: false,
},
1,
CentroidsParameters::KMeans { chunk_size: 1 },
1,
|| false,
)
Expand Down Expand Up @@ -192,7 +192,7 @@ mod tests {
distance_type: DistanceType::Dot,
invert: true,
},
1,
CentroidsParameters::KMeans { chunk_size: 1 },
1,
|| false,
)
Expand All @@ -205,6 +205,50 @@ mod tests {
}
}

#[test]
fn test_custom_centroids() {
    // Encoding with a codebook exported from a k-means run must yield the
    // same encoded storage as that k-means run itself.
    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
    let vector_data: Vec<Vec<f32>> = (0..VECTORS_COUNT)
        .map(|_| (0..VECTOR_DIM).map(|_| rng.gen()).collect())
        .collect();

    let vector_parameters = VectorParameters {
        dim: VECTOR_DIM,
        count: VECTORS_COUNT,
        distance_type: DistanceType::Dot,
        invert: true,
    };

    // Reference run: centroids are found by k-means.
    let kmeans_encoded = EncodedVectorsPQ::encode(
        vector_data.iter().map(Vec::as_slice),
        vec![],
        &vector_parameters,
        CentroidsParameters::KMeans { chunk_size: 1 },
        1,
        || false,
    )
    .unwrap();

    // Second run: reuse the centroids of the reference run as a custom codebook.
    let codebook_encoded = EncodedVectorsPQ::encode(
        vector_data.iter().map(Vec::as_slice),
        vec![],
        &vector_parameters,
        CentroidsParameters::Custom {
            codebook: kmeans_encoded.get_codebook(),
        },
        1,
        || false,
    )
    .unwrap();

    let expected_storage = kmeans_encoded.get_encoded_vectors().to_owned();
    let actual_storage = codebook_encoded.get_encoded_vectors().to_owned();
    assert_eq!(expected_storage, actual_storage);
}

// ignore this test because it takes a long time
// this test should be started separately or with the `--test-threads=1` flag
// because `num_threads::num_threads()` is used to check that all encode threads finished
Expand Down Expand Up @@ -243,7 +287,7 @@ mod tests {
distance_type: DistanceType::Dot,
invert: false,
},
1,
CentroidsParameters::KMeans { chunk_size: 1 },
5,
|| false,
)
Expand Down