4 changes: 4 additions & 0 deletions Cargo.toml
@@ -10,6 +10,9 @@ categories = ["science"]
 readme = "README.md"
 rust-version = "1.85"

+[dependencies]
+polars = { version = "0.45", optional = true, default-features = false, features = ["dtype-full"] }
+
 [build-dependencies]
 bindgen = "0.72.0"
 ureq = "2.0"
@@ -19,6 +22,7 @@ zip = "0.6"

 [features]
 gpu = []
+polars = ["dep:polars"]

 [[example]]
 name = "basic_usage"
156 changes: 88 additions & 68 deletions build.rs
@@ -9,32 +9,52 @@ fn get_catboost_version() -> String {
     env::var("CATBOOST_VERSION").unwrap_or_else(|_| "1.2.8".to_string())
 }

-fn get_platform_info() -> (String, String) {
+fn get_arch_from_target() -> &'static str {
     let target = env::var("TARGET").unwrap();

-    // Determine OS
-    let os = if target.contains("apple-darwin") {
-        "darwin"
-    } else if target.contains("linux") {
-        "linux"
-    } else if target.contains("windows") {
-        "windows"
-    } else {
-        panic!("Unsupported target: {}", target);
-    };
-
-    // Determine architecture
-    let arch = if target.contains("x86_64") {
+    if target.contains("x86_64") {
         "x86_64"
     } else if target.contains("aarch64") || target.contains("arm64") {
         "aarch64"
     } else if target.contains("i686") || target.contains("i586") {
         "i686"
     } else {
         panic!("Unsupported architecture for target: {}", target);
-    };
+    }
+}
+
+fn get_os_name() -> &'static str {
+    #[cfg(target_os = "macos")]
+    {
+        "darwin"
+    }

-    (os.to_string(), arch.to_string())
+    #[cfg(target_os = "linux")]
+    {
+        "linux"
+    }
+
+    #[cfg(target_os = "windows")]
+    {
+        "windows"
+    }
+}
Comment on lines +26 to +41

⚠️ Potential issue | 🟠 Major

Using #[cfg(target_os = …)] in get_os_name ties behavior to the host, not the Cargo target.

In a build script, cfg(target_os) is evaluated for the host that the script itself is compiled for, while TARGET and CARGO_CFG_TARGET_OS describe the compilation target. get_os_name() currently uses cfg(target_os), and its result forms the download URL and the platform tuple in download_compiled_library. Cross‑compilation (e.g., building Linux target binaries from macOS) will therefore download the library for the wrong platform or fail with a misleading "Unsupported platform" error.

Consider deriving the OS name from CARGO_CFG_TARGET_OS or by parsing TARGET (similar to get_arch_from_target). For example:

-fn get_os_name() -> &'static str {
-    #[cfg(target_os = "macos")]
-    {
-        "darwin"
-    }
-    #[cfg(target_os = "linux")]
-    {
-        "linux"
-    }
-    #[cfg(target_os = "windows")]
-    {
-        "windows"
-    }
-}
+fn get_os_name() -> String {
+    let target_os = std::env::var("CARGO_CFG_TARGET_OS")
+        .expect("CARGO_CFG_TARGET_OS not set by Cargo");
+    match target_os.as_str() {
+        "macos" => "darwin".to_string(),
+        "linux" => "linux".to_string(),
+        "windows" => "windows".to_string(),
+        other => panic!("Unsupported target OS: {}", other),
+    }
+}

You’d then update call sites to work with a String instead of &'static str, for example by matching on os.as_str() as the previous code did.
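
As a hedged alternative to the diff above: every match arm is a string literal, so the helper could keep returning &'static str, which would leave the call sites in download_compiled_library unchanged. The sketch below assumes that design. map_target_os is a hypothetical helper, not part of the PR; it is split out only so the mapping can be exercised without Cargo setting the environment variable.

```rust
use std::env;

/// Hypothetical helper (not in the PR): maps a Cargo target OS name, i.e. the
/// value of CARGO_CFG_TARGET_OS, to the OS name used by this build script.
/// Returns None for targets the script does not handle.
fn map_target_os(target_os: &str) -> Option<&'static str> {
    match target_os {
        "macos" => Some("darwin"),
        "linux" => Some("linux"),
        "windows" => Some("windows"),
        _ => None,
    }
}

/// Reads the target OS that Cargo exports to build scripts when they run.
/// The return type stays &'static str because every arm above is a literal.
fn get_os_name() -> &'static str {
    let target_os =
        env::var("CARGO_CFG_TARGET_OS").expect("CARGO_CFG_TARGET_OS not set by Cargo");
    map_target_os(&target_os)
        .unwrap_or_else(|| panic!("Unsupported target OS: {}", target_os))
}
```

With this shape, match (os, arch) and match os in download_compiled_library would compile as written in the PR, since os is still a &'static str.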

🤖 Prompt for AI Agents
In build.rs around lines 26 to 41, get_os_name currently uses #[cfg(target_os =
"...")] which reflects the build host rather than the compilation target;
replace this with logic that reads CARGO_CFG_TARGET_OS (or parses TARGET) with
env::var when the build script runs and returns an owned String (not &'static str) to
represent the target OS name (e.g., "linux", "darwin", "windows"); update any
call sites (like download_compiled_library and platform tuple construction) to
accept/consume a String (or &str) instead of &'static str and ensure the string
mapping matches the conventions used elsewhere (same canonical names as
get_arch_from_target).


+
+fn get_lib_filename() -> &'static str {
+    #[cfg(target_os = "windows")]
+    {
+        "catboostmodel.dll"
+    }
+
+    #[cfg(target_os = "macos")]
+    {
+        "libcatboostmodel.dylib"
+    }
+
+    #[cfg(target_os = "linux")]
+    {
+        "libcatboostmodel.so"
+    }
 }
Comment on lines +43 to 58

⚠️ Potential issue | 🟠 Major

get_lib_filename should also key off the Cargo target OS instead of the host.

Same cross‑compilation concern as get_os_name: #[cfg(target_os = "...")] in a build script evaluates to the host OS. Using it to choose the library filename will misbehave when cross‑compiling (e.g., targeting Windows from Linux).

You can align this with the suggested get_os_name change by either:

  • returning the filename directly from get_os_name (e.g., ("darwin", "libcatboostmodel.dylib")), or
  • introducing a CARGO_CFG_TARGET_OS‑based helper, similar to the previous diff, and using that string to pick the correct filename.
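
A minimal sketch of the second option, assuming the filenames shown in the diff and a helper keyed on CARGO_CFG_TARGET_OS. lib_filename_for is a hypothetical name introduced here for illustration; it is not part of the PR.

```rust
/// Hypothetical helper (not in the PR): picks the CatBoost shared-library
/// filename for a Cargo target OS name, i.e. the value of CARGO_CFG_TARGET_OS.
/// Returns None for targets the build script does not handle.
fn lib_filename_for(target_os: &str) -> Option<&'static str> {
    match target_os {
        "windows" => Some("catboostmodel.dll"),
        "macos" => Some("libcatboostmodel.dylib"),
        "linux" => Some("libcatboostmodel.so"),
        _ => None,
    }
}

/// Target-aware replacement for the cfg-based version: the OS is read from
/// the environment when the build script runs, so it follows the target
/// rather than the host.
fn get_lib_filename() -> &'static str {
    let target_os = std::env::var("CARGO_CFG_TARGET_OS")
        .expect("CARGO_CFG_TARGET_OS not set by Cargo");
    lib_filename_for(&target_os)
        .unwrap_or_else(|| panic!("Unsupported target OS: {}", target_os))
}
```

Because the arms are literals, the return type can stay &'static str and main() would not need to change.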
🤖 Prompt for AI Agents
In build.rs around lines 43 to 58, get_lib_filename currently uses host-targeted
#[cfg(target_os = "...")] which breaks cross-compilation; replace this logic so
it keys off the Cargo target OS instead (use env::var("CARGO_CFG_TARGET_OS") or
reuse/extend the proposed get_os_name helper to return the target OS or a (os,
filename) pair) and map the target os string ("windows", "macos"/"darwin",
"linux") to the correct filename ("catboostmodel.dll", "libcatboostmodel.dylib",
"libcatboostmodel.so"); ensure the function uses the target string to select the
filename and keep the return type compatible (&'static str, or adjust callers if
you change signature).


 fn download_model_interface_headers(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
@@ -66,7 +86,8 @@ fn download_model_interface_headers(out_dir: &Path) -> Result<(), Box<dyn std::e
 }

 fn download_compiled_library(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
-    let (os, arch) = get_platform_info();
+    let os = get_os_name();
+    let arch = get_arch_from_target();
     let version = get_catboost_version();
Comment on lines +89 to 91

⚠️ Potential issue | 🟠 Major

Platform selection in download_compiled_library depends on the same OS source; align with target OS.

The match (os, arch) and later match os blocks assume os is the target OS. With the current get_os_name implementation that’s only true for native builds. For cross‑builds, you’ll end up in the wrong arm of the match or in the _ => Unsupported platform branches even for valid target triples.

Once get_os_name() is switched to derive from the Cargo target (see earlier comment), this function should work as‑is again. After that, it might be worth adding a short debug println! of os/arch and TARGET to help diagnose platform mismatches when users hit the Unsupported platform errors.

Also applies to: 113-187, 189-246
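
One hedged way to add the diagnostic mentioned above: Cargo normally hides a build script's plain stdout unless it is run with -vv, so this sketch uses the cargo:warning= directive, which Cargo provides for surfacing messages. The helper names and the message wording are assumptions made for illustration, not part of the PR.

```rust
use std::env;

/// Hypothetical helper (not in the PR): formats the platform details as a
/// Cargo warning directive so the values can be inspected when platform
/// selection goes wrong.
fn platform_diagnostic(os: &str, arch: &str, target: &str) -> String {
    format!("cargo:warning=catboost build: os={os}, arch={arch}, TARGET={target}")
}

/// Emits the diagnostic line. Intended to be called just before returning an
/// "Unsupported platform" error from the download step.
fn report_platform(os: &str, arch: &str) {
    // TARGET is set by Cargo for build scripts; fall back to a placeholder so
    // the diagnostic itself never panics.
    let target = env::var("TARGET").unwrap_or_else(|_| "<unset>".to_string());
    println!("{}", platform_diagnostic(os, arch, &target));
}
```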


     // Create the library directory early
@@ -91,7 +112,7 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box<dyn std::error::E
     // Determine download URL based on version and platform
     let (lib_filename, download_url) = if use_new_format {
         // v1.2+ format with platform and version in filename
-        match (os.as_str(), arch.as_str()) {
+        match (os, arch) {
             ("linux", "x86_64") => (
                 "libcatboostmodel.so".to_string(),
                 format!(
@@ -165,7 +186,7 @@ fn download_compiled_library(out_dir: &Path) -> Result<(), Box<dyn std::error::E
         }
     } else {
         // v1.0.x - v1.1.x format with simple filenames
-        match os.as_str() {
+        match os {
             "linux" => (
                 "libcatboostmodel.so".to_string(),
                 format!(
@@ -317,15 +338,8 @@ fn main() {
         .write_to_file(out_dir.join("bindings.rs"))
         .expect("Couldn't write bindings.");

-    // 1. Get platform info using your existing function
-    let (os, _arch) = get_platform_info();
-
-    // 2. Determine the library filename based on the OS
-    let lib_filename = match os.as_str() {
-        "windows" => "catboostmodel.dll",
-        "darwin" => "libcatboostmodel.dylib", // "darwin" comes from your function
-        _ => "libcatboostmodel.so",           // Default to Linux/Unix
-    };
+    // 1. Get the library filename based on the OS (using compile-time cfg)
+    let lib_filename = get_lib_filename();

     // 3. Copy the library from OUT_DIR/libs to the final target directory
     let lib_source_path = out_dir.join("libs").join(lib_filename);
@@ -342,7 +356,8 @@ fn main() {

     // On macOS/Linux, change the install name/soname to use @loader_path/$ORIGIN
     // This needs to be done on the source library in OUT_DIR before linking
-    if os == "darwin" {
+    #[cfg(target_os = "macos")]
+    {
         use std::process::Command;
         let _ = Command::new("install_name_tool")
             .arg("-id")
@@ -355,7 +370,10 @@ fn main() {
             .arg(format!("@loader_path/{}", lib_filename))
             .arg(&lib_dest_path)
             .status();
-    } else if os == "linux" {
+    }
+
+    #[cfg(target_os = "linux")]
+    {
         use std::process::Command;
         // Use patchelf to set soname to just the library filename on Linux (if available)
         // This is optional - if patchelf is not installed, we just skip it
@@ -379,51 +397,53 @@ fn main() {
     );

     // 5. Set the rpath for the run-time linker based on the OS
-    match os.as_str() {
-        "darwin" => {
-            // For macOS, add multiple rpath entries for IDE compatibility
-            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
-            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../..");
-            println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
-            println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path/../..");
+    #[cfg(target_os = "macos")]
+    {
+        // For macOS, add multiple rpath entries for IDE compatibility
+        println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
+        println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../..");
+        println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path");
+        println!("cargo:rustc-link-arg=-Wl,-rpath,@loader_path/../..");
+        println!(
+            "cargo:rustc-link-arg=-Wl,-rpath,{}",
+            lib_search_path.display()
+        );
+        // Add the target directory to rpath as well
+        if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) {
             println!(
-                "cargo:rustc-link-arg=-Wl,-rpath,{}",
-                lib_search_path.display()
+                "cargo:rustc-link-arg=-Wl,-rpath,{}/debug",
+                target_root.display()
             );
+            println!(
+                "cargo:rustc-link-arg=-Wl,-rpath,{}/release",
+                target_root.display()
+            );
-            // Add the target directory to rpath as well
-            if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) {
-                println!(
-                    "cargo:rustc-link-arg=-Wl,-rpath,{}/debug",
-                    target_root.display()
-                );
-                println!(
-                    "cargo:rustc-link-arg=-Wl,-rpath,{}/release",
-                    target_root.display()
-                );
-            }
         }
-        "linux" => {
-            // For Linux, use $ORIGIN
-            println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN");
-            println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../..");
+    }
+
+    #[cfg(target_os = "linux")]
+    {
+        // For Linux, use $ORIGIN
+        println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN");
+        println!("cargo:rustc-link-arg=-Wl,-rpath,$ORIGIN/../..");
+        println!(
+            "cargo:rustc-link-arg=-Wl,-rpath,{}",
+            lib_search_path.display()
+        );
+        // Add the target directory to rpath as well
+        if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) {
             println!(
-                "cargo:rustc-link-arg=-Wl,-rpath,{}",
-                lib_search_path.display()
+                "cargo:rustc-link-arg=-Wl,-rpath,{}/debug",
+                target_root.display()
             );
+            println!(
+                "cargo:rustc-link-arg=-Wl,-rpath,{}/release",
+                target_root.display()
+            );
-            // Add the target directory to rpath as well
-            if let Some(target_root) = out_dir.ancestors().find(|p| p.ends_with("target")) {
-                println!(
-                    "cargo:rustc-link-arg=-Wl,-rpath,{}/debug",
-                    target_root.display()
-                );
-                println!(
-                    "cargo:rustc-link-arg=-Wl,-rpath,{}/release",
-                    target_root.display()
-                );
-            }
         }
-        _ => {} // No rpath needed for Windows
     }
+
+    // No rpath needed for Windows

     println!("cargo:rustc-link-lib=dylib=catboostmodel");
 }
5 changes: 5 additions & 0 deletions src/lib.rs
@@ -12,3 +12,8 @@ pub use crate::features::{

 mod model;
 pub use crate::model::Model;
+
+#[cfg(feature = "polars")]
+mod polars_ext;
+#[cfg(feature = "polars")]
+pub use crate::polars_ext::ModelPolarsExt;