From e434da67de3759636a530e626c8d0f26e2340057 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Thu, 13 Nov 2025 23:11:01 +0200 Subject: [PATCH 1/6] ues flags target arch and os for more elegant. --- build.rs | 156 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 88 insertions(+), 68 deletions(-) diff --git a/build.rs b/build.rs index 1d94bfe..b202673 100644 --- a/build.rs +++ b/build.rs @@ -9,22 +9,10 @@ fn get_catboost_version() -> String { env::var("CATBOOST_VERSION").unwrap_or_else(|_| "1.2.8".to_string()) } -fn get_platform_info() -> (String, String) { +fn get_arch_from_target() -> &'static str { let target = env::var("TARGET").unwrap(); - // Determine OS - let os = if target.contains("apple-darwin") { - "darwin" - } else if target.contains("linux") { - "linux" - } else if target.contains("windows") { - "windows" - } else { - panic!("Unsupported target: {}", target); - }; - - // Determine architecture - let arch = if target.contains("x86_64") { + if target.contains("x86_64") { "x86_64" } else if target.contains("aarch64") || target.contains("arm64") { "aarch64" @@ -32,9 +20,41 @@ fn get_platform_info() -> (String, String) { "i686" } else { panic!("Unsupported architecture for target: {}", target); - }; + } +} + +fn get_os_name() -> &'static str { + #[cfg(target_os = "macos")] + { + "darwin" + } - (os.to_string(), arch.to_string()) + #[cfg(target_os = "linux")] + { + "linux" + } + + #[cfg(target_os = "windows")] + { + "windows" + } +} + +fn get_lib_filename() -> &'static str { + #[cfg(target_os = "windows")] + { + "catboostmodel.dll" + } + + #[cfg(target_os = "macos")] + { + "libcatboostmodel.dylib" + } + + #[cfg(target_os = "linux")] + { + "libcatboostmodel.so" + } } fn download_model_interface_headers(out_dir: &Path) -> Result<(), Box> { @@ -66,7 +86,8 @@ fn download_model_interface_headers(out_dir: &Path) -> Result<(), Box Result<(), Box> { - let (os, arch) = get_platform_info(); + let os = get_os_name(); + let arch = get_arch_from_target(); let version = get_catboost_version(); // Create the library directory early @@ -91,7 +112,7 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box ( "libcatboostmodel.so".to_string(), format!( @@ -165,7 +186,7 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box ( "libcatboostmodel.so".to_string(), format!( @@ -317,15 +338,8 @@ fn main() { .write_to_file(out_dir.join("bindings.rs")) .expect("Couldn't write bindings."); - // 1. Get platform info using your existing function - let (os, _arch) = get_platform_info(); - - // 2. Determine the library filename based on the OS - let lib_filename = match os.as_str() { - "windows" => "catboostmodel.dll", - "darwin" => "libcatboostmodel.dylib", // "darwin" comes from your function - _ => "libcatboostmodel.so", // Default to Linux/Unix - }; + // 1. Get the library filename based on the OS (using compile-time cfg) + let lib_filename = get_lib_filename(); // 3. Copy the library from OUT_DIR/libs to the final target directory let lib_source_path = out_dir.join("libs").join(lib_filename); @@ -342,7 +356,8 @@ fn main() { // On macOS/Linux, change the install name/soname to use @loader_path/$ORIGIN // This needs to be done on the source library in OUT_DIR before linking - if os == "darwin" { + #[cfg(target_os = "macos")] + { use std::process::Command; let _ = Command::new("install_name_tool") .arg("-id") @@ -355,7 +370,10 @@ fn main() { .arg(format!("@loader_path/{}", lib_filename)) .arg(&lib_dest_path) .status(); - } else if os == "linux" { + } + + #[cfg(target_os = "linux")] + { use std::process::Command; // Use patchelf to set soname to just the library filename on Linux (if available) // This is optional - if patchelf is not installed, we just skip it @@ -379,51 +397,53 @@ fn main() { ); // 5. Set the rpath for the run-time linker based on the OS - match os.as_str() { - "darwin" => { - // For macOS, add multiple rpath entries for IDE compatibility - println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); - println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../.."); - println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path"); - println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path/../.."); + #[cfg(target_os = "macos")] + { + // For macOS, add multiple rpath entries for IDE compatibility + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../.."); + println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path"); + println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path/../.."); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); + // Add the target directory to rpath as well + if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { println!( - "cargo:rustc-link-arg=-Wl,-rpath,{}", - lib_search_path.display() + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() ); - // Add the target directory to rpath as well - if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { - println!( - "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", - target_root.display() - ); - println!( - "cargo:rustc-link-arg=-Wl,-rpath,{}/release", - target_root.display() - ); - } } - "linux" => { - // For Linux, use $ORIGIN - println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); - println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); + } + + #[cfg(target_os = "linux")] + { + // For Linux, use $ORIGIN + println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN"); + println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../.."); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}", + lib_search_path.display() + ); + // Add the target directory to rpath as well + if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { println!( - "cargo:rustc-link-arg=-Wl,-rpath,{}", - lib_search_path.display() + "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", + target_root.display() + ); + println!( + "cargo:rustc-link-arg=-Wl,-rpath,{}/release", + target_root.display() ); - // Add the target directory to rpath as well - if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) { - println!( - "cargo:rustc-link-arg=-Wl,-rpath,{}/debug", - target_root.display() - ); - println!( - "cargo:rustc-link-arg=-Wl,-rpath,{}/release", - target_root.display() - ); - } } - _ => {} // No rpath needed for Windows } + // No rpath needed for Windows + println!("cargo:rustc-link-lib=dylib=catboostmodel"); } From 42b07396ce672bde6a5d5369afa81dd4d7b5e902 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Mon, 17 Nov 2025 22:08:28 +0200 Subject: [PATCH 2/6] add polars support. --- Cargo.toml | 4 + src/lib.rs | 5 + src/polars_ext.rs | 307 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 316 insertions(+) create mode 100644 src/polars_ext.rs diff --git a/Cargo.toml b/Cargo.toml index 6c6d3c4..0dff2e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ categories = ["science"] readme = "README.md" rust-version = "1.85" +[dependencies] +polars = { version = "0.45", optional = true, default-features = false, features = ["dtype-full"] } + [build-dependencies] bindgen = "0.72.0" ureq = "2.0" @@ -19,6 +22,7 @@ zip = "0.6" [features] gpu = [] +polars = ["dep:polars"] [[example]] name = "basic_usage" diff --git a/src/lib.rs b/src/lib.rs index 8c54142..3303798 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,3 +12,8 @@ pub use crate::features::{ mod model; pub use crate::model::Model; + +#[cfg(feature = "polars")] +mod polars_ext; +#[cfg(feature = "polars")] +pub use crate::polars_ext::ModelPolarsExt; diff --git a/src/polars_ext.rs b/src/polars_ext.rs new file mode 100644 index 0000000..2fba265 --- /dev/null +++ b/src/polars_ext.rs @@ -0,0 +1,307 @@ +use crate::error::{CatBoostError, CatBoostResult}; +use crate::Model; +use polars::prelude::*; + +/// Extension trait for CatBoost Model to support Polars DataFrames +pub trait ModelPolarsExt { + /// Predict using a Polars DataFrame as input (numeric features only) + /// + /// This method efficiently converts the DataFrame to the format CatBoost expects. + /// All numeric columns will be used as float features. + /// + /// # Arguments + /// * `df` - Input DataFrame with numeric features + /// + /// # Returns + /// A vector of prediction values + /// + /// # Example + /// ```no_run + /// # use catboost_rust::{Model, ModelPolarsExt}; + /// # use polars::prelude::*; + /// let model = Model::load("model.cbm").unwrap(); + /// + /// let df = df! { + /// "feature1" => [1.0f32, 2.0, 3.0], + /// "feature2" => [4.0f32, 5.0, 6.0], + /// }.unwrap(); + /// + /// let predictions = model.predict_dataframe(&df).unwrap(); + /// ``` + fn predict_dataframe(&self, df: &DataFrame) -> CatBoostResult>; + + /// Predict using specific columns from a Polars DataFrame + /// + /// # Arguments + /// * `df` - Input DataFrame + /// * `columns` - Column names to use as features (in order) + fn predict_dataframe_with_columns( + &self, + df: &DataFrame, + columns: &[&str], + ) -> CatBoostResult>; + + /// Predict using a DataFrame with both float and categorical features + /// + /// # Arguments + /// * `df` - Input DataFrame + /// * `float_columns` - Names of columns to use as float features + /// * `cat_columns` - Names of columns to use as categorical features (must be String type) + fn predict_dataframe_with_types( + &self, + df: &DataFrame, + float_columns: &[&str], + cat_columns: &[&str], + ) -> CatBoostResult>; +} + +impl ModelPolarsExt for Model { + fn predict_dataframe(&self, df: &DataFrame) -> CatBoostResult> { + let float_features = dataframe_to_float_features(df)?; + let cat_features: Vec> = vec![vec![]; df.height()]; + + self.calc_model_prediction(float_features, cat_features) + } + + fn predict_dataframe_with_columns( + &self, + df: &DataFrame, + columns: &[&str], + ) -> CatBoostResult> { + let column_names: Vec = columns.iter().map(|s| s.to_string()).collect(); + let selected = df.select(column_names).map_err(|e| CatBoostError { + description: format!("Failed to select columns: {}", e), + })?; + + self.predict_dataframe(&selected) + } + + fn predict_dataframe_with_types( + &self, + df: &DataFrame, + float_columns: &[&str], + cat_columns: &[&str], + ) -> CatBoostResult> { + // Extract float features + let float_col_names: Vec = float_columns.iter().map(|s| s.to_string()).collect(); + let float_df = df.select(float_col_names).map_err(|e| CatBoostError { + description: format!("Failed to select float columns: {}", e), + })?; + let float_features = dataframe_to_float_features(&float_df)?; + + // Extract categorical features + let cat_features = if cat_columns.is_empty() { + vec![vec![]; df.height()] + } else { + let cat_col_names: Vec = cat_columns.iter().map(|s| s.to_string()).collect(); + let cat_df = df.select(cat_col_names).map_err(|e| CatBoostError { + description: format!("Failed to select categorical columns: {}", e), + })?; + dataframe_to_cat_features(&cat_df)? + }; + + self.calc_model_prediction(float_features, cat_features) + } +} + +/// Convert a Polars DataFrame to CatBoost float features format (Vec>) +/// +/// Each inner Vec represents one row of features. +fn dataframe_to_float_features(df: &DataFrame) -> CatBoostResult>> { + let num_rows = df.height(); + let num_features = df.width(); + + if num_rows == 0 || num_features == 0 { + return Err(CatBoostError { + description: "DataFrame has zero rows or columns".to_string(), + }); + } + + let mut result = Vec::with_capacity(num_rows); + + // Process row by row + for row_idx in 0..num_rows { + let mut row_features = Vec::with_capacity(num_features); + + for col in df.get_columns() { + let series = col.as_materialized_series(); + let value = extract_f32_value(series, row_idx)?; + row_features.push(value); + } + + result.push(row_features); + } + + Ok(result) +} + +/// Convert a Polars DataFrame to CatBoost categorical features format (Vec>) +fn dataframe_to_cat_features(df: &DataFrame) -> CatBoostResult>> { + let num_rows = df.height(); + let num_features = df.width(); + + if num_rows == 0 || num_features == 0 { + return Err(CatBoostError { + description: "DataFrame has zero rows or columns".to_string(), + }); + } + + let mut result = Vec::with_capacity(num_rows); + + // Process row by row + for row_idx in 0..num_rows { + let mut row_features = Vec::with_capacity(num_features); + + for col in df.get_columns() { + let series = col.as_materialized_series(); + let value = extract_string_value(series, row_idx)?; + row_features.push(value); + } + + result.push(row_features); + } + + Ok(result) +} + +/// Extract an f32 value from a Series at the given index +fn extract_f32_value(series: &Series, idx: usize) -> CatBoostResult { + use DataType::*; + + match series.dtype() { + Float32 => { + let ca = series.f32().map_err(|e| CatBoostError { + description: format!("Failed to cast to f32: {}", e), + })?; + ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + }) + } + Float64 => { + let ca = series.f64().map_err(|e| CatBoostError { + description: format!("Failed to cast to f64: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + Int8 => { + let ca = series.i8().map_err(|e| CatBoostError { + description: format!("Failed to cast to i8: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + Int16 => { + let ca = series.i16().map_err(|e| CatBoostError { + description: format!("Failed to cast to i16: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + Int32 => { + let ca = series.i32().map_err(|e| CatBoostError { + description: format!("Failed to cast to i32: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + Int64 => { + let ca = series.i64().map_err(|e| CatBoostError { + description: format!("Failed to cast to i64: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + UInt8 => { + let ca = series.u8().map_err(|e| CatBoostError { + description: format!("Failed to cast to u8: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + UInt16 => { + let ca = series.u16().map_err(|e| CatBoostError { + description: format!("Failed to cast to u16: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + UInt32 => { + let ca = series.u32().map_err(|e| CatBoostError { + description: format!("Failed to cast to u32: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + UInt64 => { + let ca = series.u64().map_err(|e| CatBoostError { + description: format!("Failed to cast to u64: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? as f32) + } + Boolean => { + let ca = series.bool().map_err(|e| CatBoostError { + description: format!("Failed to cast to bool: {}", e), + })?; + Ok(if ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? { + 1.0 + } else { + 0.0 + }) + } + dt => Err(CatBoostError { + description: format!( + "Unsupported data type for float conversion: {}", + dt + ), + }), + } +} + +/// Extract a String value from a Series at the given index +fn extract_string_value(series: &Series, idx: usize) -> CatBoostResult { + use DataType::*; + + match series.dtype() { + String => { + let ca = series.str().map_err(|e| CatBoostError { + description: format!("Failed to cast to String: {}", e), + })?; + Ok(ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })?.to_string()) + } + // Convert numeric types to strings for categorical features + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { + let value = extract_f32_value(series, idx)?; + Ok(format!("{}", value as i64)) + } + Boolean => { + let ca = series.bool().map_err(|e| CatBoostError { + description: format!("Failed to cast to bool: {}", e), + })?; + let val = ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })?; + Ok(if val { "true".to_string() } else { "false".to_string() }) + } + dt => Err(CatBoostError { + description: format!( + "Unsupported data type for categorical conversion: {}", + dt + ), + }), + } +} From 87b7f8d42c8f90c3aa6a8ef6ed5ace0a53dc24ed Mon Sep 17 00:00:00 2001 From: aryehlev Date: Mon, 17 Nov 2025 23:26:43 +0200 Subject: [PATCH 3/6] simplify casting. --- src/polars_ext.rs | 116 +++++++++++++++++++++++++++++++--------------- 1 file changed, 79 insertions(+), 37 deletions(-) diff --git a/src/polars_ext.rs b/src/polars_ext.rs index 2fba265..d4ba779 100644 --- a/src/polars_ext.rs +++ b/src/polars_ext.rs @@ -1,5 +1,5 @@ -use crate::error::{CatBoostError, CatBoostResult}; use crate::Model; +use crate::error::{CatBoostError, CatBoostResult}; use polars::prelude::*; /// Extension trait for CatBoost Model to support Polars DataFrames @@ -107,6 +107,7 @@ impl ModelPolarsExt for Model { /// Convert a Polars DataFrame to CatBoost float features format (Vec>) /// /// Each inner Vec represents one row of features. +/// Optimized column-by-column conversion for better cache locality. fn dataframe_to_float_features(df: &DataFrame) -> CatBoostResult>> { let num_rows = df.height(); let num_features = df.width(); @@ -117,25 +118,37 @@ fn dataframe_to_float_features(df: &DataFrame) -> CatBoostResult>> }); } - let mut result = Vec::with_capacity(num_rows); + // Pre-allocate result rows + let mut result: Vec> = (0..num_rows) + .map(|_| Vec::with_capacity(num_features)) + .collect(); - // Process row by row - for row_idx in 0..num_rows { - let mut row_features = Vec::with_capacity(num_features); + // Process column by column - cast to Float32 for simplicity and speed + for column in df.get_columns() { + let series = column.as_materialized_series(); - for col in df.get_columns() { - let series = col.as_materialized_series(); - let value = extract_f32_value(series, row_idx)?; - row_features.push(value); - } + // Cast to Float32 - Polars handles all type conversions efficiently + let f32_series = series.cast(&DataType::Float32).map_err(|e| CatBoostError { + description: format!("Failed to cast column to f32: {}", e), + })?; - result.push(row_features); + let ca = f32_series.f32().map_err(|e| CatBoostError { + description: format!("Failed to get f32 array: {}", e), + })?; + + for (row_idx, opt_val) in ca.iter().enumerate() { + let val = opt_val.ok_or_else(|| CatBoostError { + description: format!("Null value at row {}", row_idx), + })?; + result[row_idx].push(val); + } } Ok(result) } /// Convert a Polars DataFrame to CatBoost categorical features format (Vec>) +/// Optimized column-by-column conversion for better cache locality. fn dataframe_to_cat_features(df: &DataFrame) -> CatBoostResult>> { let num_rows = df.height(); let num_features = df.width(); @@ -146,19 +159,45 @@ fn dataframe_to_cat_features(df: &DataFrame) -> CatBoostResult>> }); } - let mut result = Vec::with_capacity(num_rows); - - // Process row by row - for row_idx in 0..num_rows { + // Fast path for single row + if num_rows == 1 { let mut row_features = Vec::with_capacity(num_features); - for col in df.get_columns() { let series = col.as_materialized_series(); - let value = extract_string_value(series, row_idx)?; + let value = extract_string_value(series, 0)?; row_features.push(value); } + return Ok(vec![row_features]); + } - result.push(row_features); + // Pre-allocate result rows + let mut result: Vec> = (0..num_rows) + .map(|_| Vec::with_capacity(num_features)) + .collect(); + + // Process column by column for better cache locality + for col in df.get_columns() { + let series = col.as_materialized_series(); + + // For String columns, use direct iteration + if matches!(series.dtype(), DataType::String) { + let ca = series.str().map_err(|e| CatBoostError { + description: format!("Failed to cast to String: {}", e), + })?; + + for (row_idx, opt_val) in ca.iter().enumerate() { + let val = opt_val.ok_or_else(|| CatBoostError { + description: format!("Null value at row {}", row_idx), + })?; + result[row_idx].push(val.to_string()); + } + } else { + // Fallback for other types + for row_idx in 0..num_rows { + let value = extract_string_value(series, row_idx)?; + result[row_idx].push(value); + } + } } Ok(result) @@ -253,19 +292,18 @@ fn extract_f32_value(series: &Series, idx: usize) -> CatBoostResult { let ca = series.bool().map_err(|e| CatBoostError { description: format!("Failed to cast to bool: {}", e), })?; - Ok(if ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? { - 1.0 - } else { - 0.0 - }) + Ok( + if ca.get(idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? { + 1.0 + } else { + 0.0 + }, + ) } dt => Err(CatBoostError { - description: format!( - "Unsupported data type for float conversion: {}", - dt - ), + description: format!("Unsupported data type for float conversion: {}", dt), }), } } @@ -279,9 +317,12 @@ fn extract_string_value(series: &Series, idx: usize) -> CatBoostResult { let ca = series.str().map_err(|e| CatBoostError { description: format!("Failed to cast to String: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })?.to_string()) + Ok(ca + .get(idx) + .ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", idx), + })? + .to_string()) } // Convert numeric types to strings for categorical features Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { @@ -295,13 +336,14 @@ fn extract_string_value(series: &Series, idx: usize) -> CatBoostResult { let val = ca.get(idx).ok_or_else(|| CatBoostError { description: format!("Null value at index {}", idx), })?; - Ok(if val { "true".to_string() } else { "false".to_string() }) + Ok(if val { + "true".to_string() + } else { + "false".to_string() + }) } dt => Err(CatBoostError { - description: format!( - "Unsupported data type for categorical conversion: {}", - dt - ), + description: format!("Unsupported data type for categorical conversion: {}", dt), }), } } From 5df8634638834dcfc2dffcdfadb36dae61bdc35c Mon Sep 17 00:00:00 2001 From: aryehlev Date: Mon, 17 Nov 2025 23:49:51 +0200 Subject: [PATCH 4/6] add macro for checking value. --- src/polars_ext.rs | 80 +++++++++++++++++++---------------------------- 1 file changed, 32 insertions(+), 48 deletions(-) diff --git a/src/polars_ext.rs b/src/polars_ext.rs index d4ba779..3f36c3a 100644 --- a/src/polars_ext.rs +++ b/src/polars_ext.rs @@ -203,6 +203,24 @@ fn dataframe_to_cat_features(df: &DataFrame) -> CatBoostResult>> Ok(result) } +/// Helper to extract a value from a ChunkedArray with explicit bounds and null checking +macro_rules! get_checked_value { + ($ca:expr, $idx:expr) => {{ + if $idx >= $ca.len() { + return Err(CatBoostError { + description: format!( + "Index {} out of bounds (length: {})", + $idx, + $ca.len() + ), + }); + } + $ca.get($idx).ok_or_else(|| CatBoostError { + description: format!("Null value at index {}", $idx), + })? + }}; +} + /// Extract an f32 value from a Series at the given index fn extract_f32_value(series: &Series, idx: usize) -> CatBoostResult { use DataType::*; @@ -212,95 +230,68 @@ fn extract_f32_value(series: &Series, idx: usize) -> CatBoostResult { let ca = series.f32().map_err(|e| CatBoostError { description: format!("Failed to cast to f32: {}", e), })?; - ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - }) + Ok(get_checked_value!(ca, idx)) } Float64 => { let ca = series.f64().map_err(|e| CatBoostError { description: format!("Failed to cast to f64: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } Int8 => { let ca = series.i8().map_err(|e| CatBoostError { description: format!("Failed to cast to i8: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } Int16 => { let ca = series.i16().map_err(|e| CatBoostError { description: format!("Failed to cast to i16: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } Int32 => { let ca = series.i32().map_err(|e| CatBoostError { description: format!("Failed to cast to i32: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } Int64 => { let ca = series.i64().map_err(|e| CatBoostError { description: format!("Failed to cast to i64: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } UInt8 => { let ca = series.u8().map_err(|e| CatBoostError { description: format!("Failed to cast to u8: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } UInt16 => { let ca = series.u16().map_err(|e| CatBoostError { description: format!("Failed to cast to u16: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } UInt32 => { let ca = series.u32().map_err(|e| CatBoostError { description: format!("Failed to cast to u32: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } UInt64 => { let ca = series.u64().map_err(|e| CatBoostError { description: format!("Failed to cast to u64: {}", e), })?; - Ok(ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? as f32) + Ok(get_checked_value!(ca, idx) as f32) } Boolean => { let ca = series.bool().map_err(|e| CatBoostError { description: format!("Failed to cast to bool: {}", e), })?; - Ok( - if ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? { - 1.0 - } else { - 0.0 - }, - ) + let val = get_checked_value!(ca, idx); + Ok(if val { 1.0 } else { 0.0 }) } dt => Err(CatBoostError { description: format!("Unsupported data type for float conversion: {}", dt), @@ -317,12 +308,7 @@ fn extract_string_value(series: &Series, idx: usize) -> CatBoostResult { let ca = series.str().map_err(|e| CatBoostError { description: format!("Failed to cast to String: {}", e), })?; - Ok(ca - .get(idx) - .ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })? - .to_string()) + Ok(get_checked_value!(ca, idx).to_string()) } // Convert numeric types to strings for categorical features Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { @@ -333,9 +319,7 @@ fn extract_string_value(series: &Series, idx: usize) -> CatBoostResult { let ca = series.bool().map_err(|e| CatBoostError { description: format!("Failed to cast to bool: {}", e), })?; - let val = ca.get(idx).ok_or_else(|| CatBoostError { - description: format!("Null value at index {}", idx), - })?; + let val = get_checked_value!(ca, idx); Ok(if val { "true".to_string() } else { From 3fa43e335484be0dbe3972919e3aed08d8d7c514 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Mon, 17 Nov 2025 23:55:51 +0200 Subject: [PATCH 5/6] fix cargo fmt. --- src/polars_ext.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/polars_ext.rs b/src/polars_ext.rs index 3f36c3a..0a7125a 100644 --- a/src/polars_ext.rs +++ b/src/polars_ext.rs @@ -208,11 +208,7 @@ macro_rules! get_checked_value { ($ca:expr, $idx:expr) => {{ if $idx >= $ca.len() { return Err(CatBoostError { - description: format!( - "Index {} out of bounds (length: {})", - $idx, - $ca.len() - ), + description: format!("Index {} out of bounds (length: {})", $idx, $ca.len()), }); } $ca.get($idx).ok_or_else(|| CatBoostError { From cf4635d4ae7669c0a134dafe3ad3db8419d3de75 Mon Sep 17 00:00:00 2001 From: aryehlev Date: Tue, 18 Nov 2025 00:05:10 +0200 Subject: [PATCH 6/6] add vounds checks. --- src/polars_ext.rs | 74 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 8 deletions(-) diff --git a/src/polars_ext.rs b/src/polars_ext.rs index 0a7125a..8cdec93 100644 --- a/src/polars_ext.rs +++ b/src/polars_ext.rs @@ -82,12 +82,25 @@ impl ModelPolarsExt for Model { float_columns: &[&str], cat_columns: &[&str], ) -> CatBoostResult> { + // Ensure at least one feature type is provided + if float_columns.is_empty() && cat_columns.is_empty() { + return Err(CatBoostError { + description: "Must provide at least one float or categorical column".to_string(), + }); + } + // Extract float features - let float_col_names: Vec = float_columns.iter().map(|s| s.to_string()).collect(); - let float_df = df.select(float_col_names).map_err(|e| CatBoostError { - description: format!("Failed to select float columns: {}", e), - })?; - let float_features = dataframe_to_float_features(&float_df)?; + let float_features = if float_columns.is_empty() { + // No float features - create empty vectors for each row + vec![vec![]; df.height()] + } else { + let float_col_names: Vec = + float_columns.iter().map(|s| s.to_string()).collect(); + let float_df = df.select(float_col_names).map_err(|e| CatBoostError { + description: format!("Failed to select float columns: {}", e), + })?; + dataframe_to_float_features(&float_df)? + }; // Extract categorical features let cat_features = if cat_columns.is_empty() { @@ -307,9 +320,54 @@ fn extract_string_value(series: &Series, idx: usize) -> CatBoostResult { Ok(get_checked_value!(ca, idx).to_string()) } // Convert numeric types to strings for categorical features - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { - let value = extract_f32_value(series, idx)?; - Ok(format!("{}", value as i64)) + // Handle each integer type directly to avoid precision loss + Int8 => { + let ca = series.i8().map_err(|e| CatBoostError { + description: format!("Failed to cast to i8: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) + } + Int16 => { + let ca = series.i16().map_err(|e| CatBoostError { + description: format!("Failed to cast to i16: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) + } + Int32 => { + let ca = series.i32().map_err(|e| CatBoostError { + description: format!("Failed to cast to i32: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) + } + Int64 => { + let ca = series.i64().map_err(|e| CatBoostError { + description: format!("Failed to cast to i64: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) + } + UInt8 => { + let ca = series.u8().map_err(|e| CatBoostError { + description: format!("Failed to cast to u8: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) + } + UInt16 => { + let ca = series.u16().map_err(|e| CatBoostError { + description: format!("Failed to cast to u16: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) + } + UInt32 => { + let ca = series.u32().map_err(|e| CatBoostError { + description: format!("Failed to cast to u32: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) + } + UInt64 => { + let ca = series.u64().map_err(|e| CatBoostError { + description: format!("Failed to cast to u64: {}", e), + })?; + Ok(get_checked_value!(ca, idx).to_string()) } Boolean => { let ca = series.bool().map_err(|e| CatBoostError {