Skip to content

Conversation

@CallMeMSL
Copy link

Changes:

This implements the addition of a validation set for training and the retrieval of the train/validation score. A bug in the JSON parser for the train params has been fixed, which lets you now specify an array of values for a key. Additionally, bins can now be aligned between datasets, that are loaded from a file.

Problems:

With this implementation, only one validation set can be added.
This pull request breaks the current API by additionally requiring an Optional for the train and from_file methods.
Additionally, I found an instance where the addition of a big validation dataset leads to a segfault in the train step. See the ignored` test.

Possible Solutions

API Changes & only one validation set: This could be changed by changing the API to a builder pattern (as already mentioned). I thought of something along those lines:

let booster = Booster::new(&params)
    .train_data(...)
    .validation_data(...)
    .validation_data(...)
    .validation_data(...)
    .train_step_callback(...)
    .train_step_callback(...)
    .fit() // or fit_predict()

The callbacks could also be used for #6 and maybe should be done together with #3.
As for the broken test: I am not sure what is causing the segfault and I hope that I just somehow just load/handle the validation data wrong.

@CallMeMSL CallMeMSL requested a review from leofidus May 5, 2023 12:34
@CallMeMSL CallMeMSL linked an issue May 5, 2023 that may be closed by this pull request
src/booster.rs Outdated
Comment on lines 279 to 298
let out_strs = (0..num_metrics)
.map(|_| {
CString::new(" ".repeat(metric_name_length))
.unwrap()
.into_raw() as *mut c_char
})
.collect::<Vec<_>>();
lgbm_call!(lightgbm_sys::LGBM_BoosterGetEvalNames(
self.handle,
num_metrics,
&mut num_eval_names,
metric_name_length as u64,
&mut out_buffer_len,
out_strs.as_ptr() as *mut *mut c_char
))?;
let output: Vec<String> = out_strs
.into_iter()
.map(|s| unsafe { CString::from_raw(s).into_string().unwrap() })
.take(num_eval_names as usize)
.collect();
Copy link

@leofidus leofidus May 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know the string logic is taken from feature_name(),j but if you are getting segfaults this (and feature_names) might be worth investigating. It is a bit suspicious, especially how the CString::from_raw documentation says you aren't supposed to change the string's length, which this probably does.

A better solution might be to allocate the strings as Vec<u8> initialized with 0s, and read them in with CString::from_vec_with_nul

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made a PR that should fix this, and makes the eval_names test pass: #11

@leofidus
Copy link

leofidus commented May 5, 2023

Possible Solutions

API Changes & only one validation set: This could be changed by changing the API to a builder pattern (as already mentioned). I thought of something along those lines:

let booster = Booster::new(&params)
    .train_data(...)
    .validation_data(...)
    .validation_data(...)
    .validation_data(...)
    .train_step_callback(...)
    .train_step_callback(...)
    .fit() // or fit_predict()

The callbacks could also be used for #6 and maybe should be done together with #3. As for the broken test: I am not sure what is causing the segfault and I hope that I just somehow just load/handle the validation data wrong.

That would make a lot of sense. Doesn't have to be part of this PR, we can also do a general cleanup PR that rewrites APIs to make more sense (another point would e.g. be the pervasive use of &str instead of AsRef<Path>, which the original author copied over from the XGBoost crate).

I'd maybe make the params part of the fit function though (unless that causes any problems) to make it easier to train multiple models with the same data but different parameters.

@CallMeMSL CallMeMSL changed the title Train with validation data Train with validation data CU-861n1bn7h Jul 24, 2023
@leofidus
Copy link

Task linked: CU-861n1bn7h LightGBM Validation Data

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Allow adding validation data, return metrics

3 participants