From 287de76e9c5ab73420cadb955614469f1f475e28 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 01:05:11 +0200 Subject: [PATCH 1/2] support all verions. --- build.rs | 116 +++++++++++++++++++++++++++++++++++++++------------ src/model.rs | 81 +++++++++++++++++++++++++++-------- 2 files changed, 153 insertions(+), 44 deletions(-) diff --git a/build.rs b/build.rs index 52e8380..0444d74 100644 --- a/build.rs +++ b/build.rs @@ -69,37 +69,75 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box ( - "libcatboostmodel.so".to_string(), // The correct library name for the linker - format!( - "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-x86_64-{}.so", - version,version + // Parse version to determine URL format + // v1.0.x - v1.1.x use simple filenames + // v1.2+ use platform-specific versioned filenames + let version_parts: Vec<&str> = version.split('.').collect(); + let major: u32 = version_parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(1); + let minor: u32 = version_parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); + + let use_new_format = major > 1 || (major == 1 && minor >= 2); + + // Determine download URL based on version and platform + let (lib_filename, download_url) = if use_new_format { + // v1.2+ format with platform and version in filename + match (os.as_str(), arch.as_str()) { + ("linux", "x86_64") => ( + "libcatboostmodel.so".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-x86_64-{}.so", + version, version + ), ), - ), - ("linux", "aarch64") => ( - "libcatboostmodel.so".to_string(), - format!( - "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-aarch64-{}.so", - version, version + ("linux", "aarch64") => ( + "libcatboostmodel.so".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-aarch64-{}.so", + version, version + ), ), - ), - ("darwin", "x86_64") | ("darwin", "aarch64") => ( - "libcatboostmodel.dylib".to_string(), // The correct library name for macOS - format!( - "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-darwin-universal2-{}.dylib", - version, version + ("darwin", "x86_64") | ("darwin", "aarch64") => ( + "libcatboostmodel.dylib".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-darwin-universal2-{}.dylib", + version, version + ), ), - ), - ("windows", "x86_64") => ( - "catboostmodel.dll".to_string(), // The correct library name for Windows - format!( - "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", - version + ("windows", "x86_64") => ( + "catboostmodel.dll".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", + version + ), + ), + _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), + } + } else { + // v1.0.x - v1.1.x format with simple filenames + match os.as_str() { + "linux" => ( + "libcatboostmodel.so".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.so", + version + ), ), - ), - _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()), + "darwin" => ( + "libcatboostmodel.dylib".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.dylib", + version + ), + ), + "windows" => ( + "catboostmodel.dll".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", + version + ), + ), + _ => return Err(format!("Unsupported platform: {}", os).into()), + } }; println!( @@ -136,6 +174,30 @@ fn main() { let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let cb_model_interface_root = out_dir.join("libs/model_interface"); + // Parse version for feature detection + let version = get_catboost_version(); + let version_parts: Vec<&str> = version.split('.').collect(); + let major: u32 = version_parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(1); + let minor: u32 = version_parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); + let patch: u32 = version_parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0); + + // Emit cfg flags for version-specific features + // v1.1.1+: Embedding features support + if major > 1 || (major == 1 && minor > 1) || (major == 1 && minor == 1 && patch >= 1) { + println!("cargo:rustc-cfg=catboost_embeddings"); + } + + // v1.2+: Text features count function + if major > 1 || (major == 1 && minor >= 2) { + println!("cargo:rustc-cfg=catboost_text_count"); + } + + // v1.2.3+: Staged predictions and feature indices + if major > 1 || (major == 1 && minor > 2) || (major == 1 && minor == 2 && patch >= 3) { + println!("cargo:rustc-cfg=catboost_staged_prediction"); + println!("cargo:rustc-cfg=catboost_feature_indices"); + } + // Download the model interface headers if let Err(e) = download_model_interface_headers(&out_dir) { eprintln!("Failed to download model interface headers: {}", e); diff --git a/src/model.rs b/src/model.rs index 8a9835c..9b19180 100644 --- a/src/model.rs +++ b/src/model.rs @@ -183,23 +183,54 @@ impl Model { .collect::>(); let mut prediction = vec![0.0; object_count.unwrap() * self.get_dimensions_count()]; - CatBoostError::check_return_value(unsafe { - sys::CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures( - self.handle, - object_count.unwrap(), - float_features_ptr.as_mut_ptr(), - if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() }, - hashed_cat_features_ptr.as_mut_ptr(), - if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() }, - text_features_ptr.as_mut_ptr(), - if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() }, - embedding_features_ptr.as_mut_ptr(), - embedding_dimensions.as_mut_ptr(), - embedding_dimensions.len(), - prediction.as_mut_ptr(), - prediction.len(), - ) - })?; + + #[cfg(catboost_embeddings)] + { + // v1.1.1+: Use function with embedding support + CatBoostError::check_return_value(unsafe { + sys::CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures( + self.handle, + object_count.unwrap(), + float_features_ptr.as_mut_ptr(), + if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() }, + hashed_cat_features_ptr.as_mut_ptr(), + if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() }, + text_features_ptr.as_mut_ptr(), + if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() }, + embedding_features_ptr.as_mut_ptr(), + embedding_dimensions.as_mut_ptr(), + embedding_dimensions.len(), + prediction.as_mut_ptr(), + prediction.len(), + ) + })?; + } + + #[cfg(not(catboost_embeddings))] + { + // v1.0.x: Use function without embedding support (embeddings will be ignored) + if !features.embedding_features.as_ref().is_empty() { + return Err(CatBoostError { + description: "Embedding features are not supported in this CatBoost version. Please use v1.1.1 or later.".to_string() + }); + } + + CatBoostError::check_return_value(unsafe { + sys::CalcModelPredictionWithHashedCatFeaturesAndTextFeatures( + self.handle, + object_count.unwrap(), + float_features_ptr.as_mut_ptr(), + if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() }, + hashed_cat_features_ptr.as_mut_ptr(), + if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() }, + text_features_ptr.as_mut_ptr(), + if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() }, + prediction.as_mut_ptr(), + prediction.len(), + ) + })?; + } + Ok(prediction) } @@ -237,15 +268,31 @@ impl Model { } /// Get expected text feature count for model + /// Only available in CatBoost v1.2+ + #[cfg(catboost_text_count)] pub fn get_text_features_count(&self) -> usize { unsafe { sys::GetTextFeaturesCount(self.handle) } } + /// Get expected text feature count for model (returns 0 for older versions) + #[cfg(not(catboost_text_count))] + pub fn get_text_features_count(&self) -> usize { + 0 + } + /// Get expected embedding feature count for model + /// Only available in CatBoost v1.1.1+ + #[cfg(catboost_embeddings)] pub fn get_embedding_features_count(&self) -> usize { unsafe { sys::GetEmbeddingFeaturesCount(self.handle) } } + /// Get expected embedding feature count for model (returns 0 for older versions) + #[cfg(not(catboost_embeddings))] + pub fn get_embedding_features_count(&self) -> usize { + 0 + } + /// Get number of trees in model pub fn get_tree_count(&self) -> usize { unsafe { sys::GetTreeCount(self.handle)} From 14a4f87aafa0c89a34a6328eca70586ca53c9d42 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Sun, 9 Nov 2025 09:16:45 +0200 Subject: [PATCH 2/2] change winbdows download url. --- build.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/build.rs b/build.rs index 0444d74..4a0658e 100644 --- a/build.rs +++ b/build.rs @@ -106,8 +106,15 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box ( "catboostmodel.dll".to_string(), format!( - "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll", - version + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-x86_64-{}.dll", + version, version + ), + ), + ("windows", "aarch64") => ( + "catboostmodel.dll".to_string(), + format!( + "https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-aarch64-{}.dll", + version, version ), ), _ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()),