Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 96 additions & 27 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,37 +69,82 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box<dyn std::error::E
let (os, arch) = get_platform_info();
let version = get_catboost_version();

// CORRECT: These URLs and filenames point to the required shared libraries.
let (lib_filename, download_url) = match (os.as_str(), arch.as_str()) {
("linux", "x86_64") => (
"libcatboostmodel.so".to_string(), // The correct library name for the linker
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-x86_64-{}.so",
version,version
// Parse version to determine URL format
// v1.0.x - v1.1.x use simple filenames
// v1.2+ use platform-specific versioned filenames
let version_parts: Vec<&str> = version.split('.').collect();
let major: u32 = version_parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(1);
let minor: u32 = version_parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);

let use_new_format = major > 1 || (major == 1 && minor >= 2);

// Determine download URL based on version and platform
let (lib_filename, download_url) = if use_new_format {
// v1.2+ format with platform and version in filename
match (os.as_str(), arch.as_str()) {
("linux", "x86_64") => (
"libcatboostmodel.so".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-x86_64-{}.so",
version, version
),
),
("linux", "aarch64") => (
"libcatboostmodel.so".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-aarch64-{}.so",
version, version
),
),
("darwin", "x86_64") | ("darwin", "aarch64") => (
"libcatboostmodel.dylib".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-darwin-universal2-{}.dylib",
version, version
),
),
),
("linux", "aarch64") => (
"libcatboostmodel.so".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-aarch64-{}.so",
version, version
("windows", "x86_64") => (
"catboostmodel.dll".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-x86_64-{}.dll",
version, version
),
),
),
("darwin", "x86_64") | ("darwin", "aarch64") => (
"libcatboostmodel.dylib".to_string(), // The correct library name for macOS
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-darwin-universal2-{}.dylib",
version, version
("windows", "aarch64") => (
"catboostmodel.dll".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-aarch64-{}.dll",
version, version
),
),
),
("windows", "x86_64") => (
"catboostmodel.dll".to_string(), // The correct library name for Windows
format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll",
version
_ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()),
}
} else {
// v1.0.x - v1.1.x format with simple filenames
match os.as_str() {
"linux" => (
"libcatboostmodel.so".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.so",
version
),
),
"darwin" => (
"libcatboostmodel.dylib".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.dylib",
version
),
),
),
_ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()),
"windows" => (
"catboostmodel.dll".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll",
version
),
),
_ => return Err(format!("Unsupported platform: {}", os).into()),
}
};

println!(
Expand Down Expand Up @@ -136,6 +181,30 @@ fn main() {
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let cb_model_interface_root = out_dir.join("libs/model_interface");

// Parse version for feature detection
let version = get_catboost_version();
let version_parts: Vec<&str> = version.split('.').collect();
let major: u32 = version_parts.get(0).and_then(|s| s.parse().ok()).unwrap_or(1);
let minor: u32 = version_parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
let patch: u32 = version_parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0);

// Emit cfg flags for version-specific features
// v1.1.1+: Embedding features support
if major > 1 || (major == 1 && minor > 1) || (major == 1 && minor == 1 && patch >= 1) {
println!("cargo:rustc-cfg=catboost_embeddings");
}

// v1.2+: Text features count function
if major > 1 || (major == 1 && minor >= 2) {
println!("cargo:rustc-cfg=catboost_text_count");
}

// v1.2.3+: Staged predictions and feature indices
if major > 1 || (major == 1 && minor > 2) || (major == 1 && minor == 2 && patch >= 3) {
println!("cargo:rustc-cfg=catboost_staged_prediction");
println!("cargo:rustc-cfg=catboost_feature_indices");
}

// Download the model interface headers
if let Err(e) = download_model_interface_headers(&out_dir) {
eprintln!("Failed to download model interface headers: {}", e);
Expand Down
81 changes: 64 additions & 17 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,54 @@ impl Model {
.collect::<Vec<_>>();

let mut prediction = vec![0.0; object_count.unwrap() * self.get_dimensions_count()];
CatBoostError::check_return_value(unsafe {
sys::CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures(
self.handle,
object_count.unwrap(),
float_features_ptr.as_mut_ptr(),
if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() },
hashed_cat_features_ptr.as_mut_ptr(),
if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() },
text_features_ptr.as_mut_ptr(),
if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() },
embedding_features_ptr.as_mut_ptr(),
embedding_dimensions.as_mut_ptr(),
embedding_dimensions.len(),
prediction.as_mut_ptr(),
prediction.len(),
)
})?;

#[cfg(catboost_embeddings)]
{
// v1.1.1+: Use function with embedding support
CatBoostError::check_return_value(unsafe {
sys::CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures(
self.handle,
object_count.unwrap(),
float_features_ptr.as_mut_ptr(),
if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() },
hashed_cat_features_ptr.as_mut_ptr(),
if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() },
text_features_ptr.as_mut_ptr(),
if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() },
embedding_features_ptr.as_mut_ptr(),
embedding_dimensions.as_mut_ptr(),
embedding_dimensions.len(),
prediction.as_mut_ptr(),
prediction.len(),
)
})?;
}

#[cfg(not(catboost_embeddings))]
{
// v1.0.x: Use function without embedding support (embeddings will be ignored)
if !features.embedding_features.as_ref().is_empty() {
return Err(CatBoostError {
description: "Embedding features are not supported in this CatBoost version. Please use v1.1.1 or later.".to_string()
});
}

CatBoostError::check_return_value(unsafe {
sys::CalcModelPredictionWithHashedCatFeaturesAndTextFeatures(
self.handle,
object_count.unwrap(),
float_features_ptr.as_mut_ptr(),
if features.float_features.as_ref().is_empty() { 0 } else { features.float_features.as_ref()[0].as_ref().len() },
hashed_cat_features_ptr.as_mut_ptr(),
if features.cat_features.as_ref().is_empty() { 0 } else { features.cat_features.as_ref()[0].as_ref().len() },
text_features_ptr.as_mut_ptr(),
if features.text_features.as_ref().is_empty() { 0 } else { features.text_features.as_ref()[0].as_ref().len() },
prediction.as_mut_ptr(),
prediction.len(),
)
})?;
}

Ok(prediction)
}

Expand Down Expand Up @@ -237,15 +268,31 @@ impl Model {
}

/// Get expected text feature count for model
/// Only available in CatBoost v1.2+
#[cfg(catboost_text_count)]
pub fn get_text_features_count(&self) -> usize {
unsafe { sys::GetTextFeaturesCount(self.handle) }
}

/// Get expected text feature count for model (returns 0 for older versions)
#[cfg(not(catboost_text_count))]
pub fn get_text_features_count(&self) -> usize {
0
}

/// Get expected embedding feature count for model
/// Only available in CatBoost v1.1.1+
#[cfg(catboost_embeddings)]
pub fn get_embedding_features_count(&self) -> usize {
unsafe { sys::GetEmbeddingFeaturesCount(self.handle) }
}

/// Get expected embedding feature count for model (returns 0 for older versions)
#[cfg(not(catboost_embeddings))]
pub fn get_embedding_features_count(&self) -> usize {
0
}

/// Get number of trees in model
pub fn get_tree_count(&self) -> usize {
unsafe { sys::GetTreeCount(self.handle)}
Expand Down