created first version of api. ffi calls and example missing CU-861n1bn07 #13
base: train_with_val_data
Conversation
leofidus left a comment:
I think this is definitely on the right track.
```rust
/// Builder for the Booster.
///
/// Uses the TypeState pattern to make sure that training data is added,
/// so that validation can be synced properly and params are present for training.
#[derive(Default, Clone)]
pub struct BoosterBuilder<T: Clone, P: Clone> {
    train_data: T,
    val_data: Vec<DataSet>,
    params: P, // after #3 should this be a struct
}
```
It's a neat pattern, and it looks like Rustdoc doesn't have issues with it either.
```rust
impl<P: Clone> BoosterBuilder<TrainDataAdded, P> {
    pub fn add_val_data(mut self, val: DataSet) -> Self {
        self.val_data.push(val);
        self
    }
}
```
People tend to add validation data after training data, but is there a reason this isn't just implemented on BoosterBuilder<T,P>?
I guess it helps with validation to restrict it
My reason is that I'm not sure if loading the datasets in fit() is the best way to approach it. Restricting it like that always keeps the possibility open to load the dataset at the add_val_data call directly. This would make duplicate() also a lot more efficient.
```rust
pub struct Booster {
    handle: lightgbm_sys::BoosterHandle,
    train_data: DataSet,
    validation_data: Vec<DataSet>,
}
```
Are these meant to be LoadedDatasets? I would assume fit loads the data
What about Boosters that were trained in advance and are loaded from file? What would their train_data and validation_data be? (and does it even make sense to hold onto these potentially huge datasets?)
Yes, you're right. This is a refactoring artifact and should just be the Dataset pointers. Now that you say it, it probably doesn't make sense to add them to the booster, if we build them with fit() anyway.
```rust
pub enum DataFormat {
    File {
        path: String,
    },
    Vecs {
        x: InputMatrix,
        y: OutputVec,
    },
    #[cfg(feature = "dataframe")]
    DataFrame {
        df: DataFrame,
        y_column: String,
    },
}
```
I guess this is to make Datasets clonable?
It feels like it makes datasets and error handling a bit more complicated, compared to just loading them directly. It would also prevent future load_* functions that only take a reference to properly laid-out data (maybe loading nalgebra arrays, if they support that?).
I think (suspect) you can implement clone on Dataset by calling LGBM_DatasetCreateByReference(h_old, rows, &mut h_new) followed by LGBM_DatasetAddFeaturesFrom(h_new, h_old). But maybe that's completely wrong, the documentation is incredibly vague.
matthiasvedder left a comment:
Looks very promising overall.
src/booster/builder.rs (Outdated)
```rust
#[derive(Clone)]
pub struct TrainDataAdded(DataSet); // this should not implement Default, so it can safely be used for construction
#[derive(Default, Clone)]
pub struct TrainDataNotAdded;
```
Thinking about the resulting API, I was wondering whether these structs could have names where the important part stands out more? Like WithTrainData and NoTrainData. The difference is the very first word, instead of (not) having a Not added in the middle of a fairly long type name.
Same for Params.
The problem with the TypeState pattern is that you can't set custom error messages. However, the state structs do appear in the compiler error when you try to call a method from a different implementation. Example:

```text
25  | pub struct BoosterBuilder<T: Clone, P: Clone> {
    | --------------------------------------------- method `fit` not found for this struct
...
109 |     let builder = Booster::builder().fit();
    |                                      ^^^ method not found in `BoosterBuilder<TrainDataNotAdded, ParamsNotAdded>`
    |
    = note: the method was found for
            - `BoosterBuilder<TrainDataAdded, ParamsAdded>`
```

Since this is the only point where the user actually encounters these structs, I named them so that the error message sounds natural.
But your suggestion would work as well.
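For reference, here is a minimal, self-contained sketch of the TypeState builder being discussed. The names mirror the PR but the bodies are illustrative stand-ins (plain `String` instead of `DataSet` and real params): `fit()` only exists on the fully-specified state, so calling it too early produces exactly the kind of compile error quoted above.

```rust
// Marker types encoding the builder's state. The "NotAdded" markers
// derive Default so only the empty builder can be constructed directly.
#[derive(Default, Clone)]
pub struct TrainDataNotAdded;
#[derive(Clone)]
pub struct TrainDataAdded(String); // stand-in for a DataSet

#[derive(Default, Clone)]
pub struct ParamsNotAdded;
#[derive(Clone)]
pub struct ParamsAdded(String); // stand-in for training params

#[derive(Default, Clone)]
pub struct BoosterBuilder<T: Clone, P: Clone> {
    train_data: T,
    params: P,
}

impl<P: Clone> BoosterBuilder<TrainDataNotAdded, P> {
    // Consumes the builder and returns one in the "train data added" state.
    pub fn add_train_data(self, data: String) -> BoosterBuilder<TrainDataAdded, P> {
        BoosterBuilder { train_data: TrainDataAdded(data), params: self.params }
    }
}

impl<T: Clone> BoosterBuilder<T, ParamsNotAdded> {
    pub fn add_params(self, params: String) -> BoosterBuilder<T, ParamsAdded> {
        BoosterBuilder { train_data: self.train_data, params: ParamsAdded(params) }
    }
}

// `fit` exists only once both pieces are present.
impl BoosterBuilder<TrainDataAdded, ParamsAdded> {
    pub fn fit(self) -> String {
        format!("trained on {} with {}", self.train_data.0, self.params.0)
    }
}

fn main() {
    let result = BoosterBuilder::default()
        .add_train_data("train.csv".into())
        .add_params("lr=0.1".into())
        .fit();
    println!("{result}");
    // BoosterBuilder::default().fit(); // would not compile: method not found
}
```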
```rust
/// Returns the Builder and a clone from it. Useful if you want to train 2 models with
/// only a couple differences.
pub fn duplicate(self) -> (Self, Self) {
    (self.clone(), self)
}
```
Can you elaborate why this function returns two instances of Self and why self is the second one?
Calling this like let (other, me) = me.duplicate(); feels a bit weird at first glance.
(self, self.clone()) would work as well.
I added this function, so that you can call it after you defined everything that 2 boosters have in common and then add the differences, like this:
```rust
let (bst_low_lr, bst_high_lr) = Booster::builder()
    .add_train_data(dataset)
    .add_val_data(another_dataset)
    .add_val_data(also_a_dataset)
    .duplicate();
let bst_low_lr = bst_low_lr.add_params(params_a).fit()?;
let bst_high_lr = bst_high_lr.add_params(params_b).fit()?;
```
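The shape of `duplicate()` can be shown with a runnable toy builder (illustrative names and field types, not the PR's actual API): the shared setup is written once, then the two halves diverge.

```rust
#[derive(Clone, Debug, PartialEq)]
struct Builder {
    val_data: Vec<String>,
    params: Option<String>,
}

impl Builder {
    fn new() -> Self {
        Builder { val_data: Vec::new(), params: None }
    }

    fn add_val_data(mut self, v: &str) -> Self {
        self.val_data.push(v.to_string());
        self
    }

    fn add_params(mut self, p: &str) -> Self {
        self.params = Some(p.to_string());
        self
    }

    /// Returns the builder and a clone of it, so shared setup is written once.
    fn duplicate(self) -> (Self, Self) {
        (self.clone(), self)
    }
}

fn main() {
    let (a, b) = Builder::new().add_val_data("val.csv").duplicate();
    // Both halves share the validation data added before duplicate()...
    let a = a.add_params("lr=0.01");
    let b = b.add_params("lr=0.3");
    // ...but diverge afterwards.
    assert_eq!(a.val_data, b.val_data);
    assert_ne!(a.params, b.params);
    println!("shared val data, distinct params");
}
```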
Your example does look clean.
It would restrict us to 2 boosters, though. I don't know if comparing 3 or more boosters makes any sense.
Eventually, examples like these should be part of the docs; they help understand the API better.
If you want more than 2 boosters, you'd probably use clone again. For example, if you have a Vec of params you want to test, you could do:

```rust
let src_bst = Booster::builder()
    .add_train_data(dataset)
    .add_val_data(another_dataset)
    .add_val_data(also_a_dataset);
let boosters = params.into_iter()
    .map(|p| src_bst.clone().add_params(p).fit())
    .filter_map(|booster| booster.ok());
```

duplicate() is a bit of a special case, but I think it's nice to have.
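The clone-per-params loop can also be sketched as a self-contained toy (illustrative types, not the PR's API; `fit` here fails on an empty params string just to exercise the `filter_map` error path):

```rust
#[derive(Clone)]
struct Builder {
    train_data: String,
}

impl Builder {
    // Toy fit: fails for an empty params string, to exercise the error path.
    fn fit(self, params: &str) -> Result<String, String> {
        if params.is_empty() {
            Err("empty params".to_string())
        } else {
            Ok(format!("model({}, {})", self.train_data, params))
        }
    }
}

fn main() {
    let src_bst = Builder { train_data: "train.csv".into() };
    let params = vec!["lr=0.01", "", "lr=0.3"];
    // Clone the shared builder once per parameter set, keep only successful fits.
    let boosters: Vec<String> = params
        .iter()
        .map(|&p| src_bst.clone().fit(p))
        .filter_map(|b| b.ok())
        .collect();
    assert_eq!(boosters.len(), 2);
    println!("{boosters:?}");
}
```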
I think the rewrite is far enough along that we can accept this PR. Any feedback?
```rust
    validation_data: Vec<LoadedDataSet>,
}

// exchange params method as well? does this make sense?
```
leftover comment from the development stage?
````rust
/// # Ok(())}
/// ```
pub fn predict(&self, x: &Matrixf64) -> Result<Matrixf64, LgbmError> {
    let prediction_params = ""; // do we need this?
````
do we?
```rust
    .collect())
}

/// this should take &mut self, because it changes the model
```
not really a doc comment.
```rust
/// This should not reset the already existing submodels.
/// Pass an empty array as validation data, if you don't want to validate the train results.
/// TODO validate this after implemented
pub fn finetune(
```
What should happen with this code? Delete it?
```rust
/// The DatasetHandle is returned by the lightgbm ffi.
pub struct LoadedDataSet {
    pub(crate) handle: DatasetHandle,
    dataset: DataSet, // this can maybe be removed
```
Since clippy warns about it, it should be removed.
Task linked: CU-861n1bn07 LightGBM API Rewrite
I created the first version of the new API and would like some feedback. The booster is built with the TypeState pattern, so logic errors when building the booster are caught by the compiler.
I also added `add_params` as a builder method, so that you can have either different params or different datasets after duplicating a builder. FFI code and tests are still missing.
You can also ignore the changes in `old_booster.rs` and `old_dataset.rs`; refactoring got a little messy.