[WIP] Fine tune with LayerSkip #20
Conversation
mostafaelhoushi
left a comment
Thanks @ariG23498! Looks great, and I like the breakdown.
I have made some comments and suggestions.
fine-tune.py
Outdated
```python
super().__init__(*args, **kwargs)
self.model = model
self.num_layers = model.config.num_hidden_layers
self.early_exit_layer = 0
```
mostafaelhoushi
We don't have to do it in this PR, but in the future, `self.early_exit_layer` could be a list of layers.
ariG23498
This is an interesting idea. If I am understanding this correctly, you mean we could have a list of indices for the layers we want to exit from early, e.g. `self.early_exit_layers = [0, 4, 8]`?
mostafaelhoushi
Yes. That is actually what we referred to as "rotational curriculum" in the paper.
ariG23498
I can create an issue and then do it in another PR if you want.
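For reference, a minimal sketch of what a rotational curriculum over a list of exit layers could look like: the trainer asks for the next exit layer once per training step, so the early-exit loss rotates through the list. All names here are illustrative, not the paper's actual implementation:

```python
from itertools import cycle

class RotationalEarlyExit:
    """Rotate through a fixed list of early-exit layer indices, one per step."""

    def __init__(self, early_exit_layers):
        self._layers = cycle(early_exit_layers)  # e.g. [0, 4, 8]

    def next_layer(self):
        return next(self._layers)

rotation = RotationalEarlyExit([0, 4, 8])
for step in range(6):
    exit_layer = rotation.next_layer()
    # here the trainer would compute:
    # loss = main_loss + scale * early_exit_loss(hidden_states[exit_layer])
    print(f"step {step}: early-exit loss at layer {exit_layer}")
```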
mostafaelhoushi
left a comment
Other things we might need:
- add a line to save the checkpoint to a file (the path could be a CLI argument as well; see the sketch below)
- update README with example command to train a model
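For the checkpoint-saving bullet, a minimal sketch, assuming the path comes in as a CLI argument and the script holds a HuggingFace model and tokenizer (the flag names mirror the README command; the default checkpoint is a small stand-in):

```python
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", default="gpt2")  # small stand-in checkpoint
parser.add_argument("--output_dir", default="./checkpoints")
args = parser.parse_args()

model = AutoModelForCausalLM.from_pretrained(args.ckpt)
tokenizer = AutoTokenizer.from_pretrained(args.ckpt)

# ... fine-tuning loop would run here ...

# Persist the fine-tuned weights and tokenizer to the CLI-provided path.
model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
```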
Running the generation benchmark leads to:

```
### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]
Orig Time: 0.8662526607513428
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
```

The response `[IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ]` was reproduced verbatim at every exit layer, so only the timings differ:

| Layer | Layerskip Time (s) | Speedup |
| ----- | ------------------ | ------- |
| 1 | 0.5897870063781738 | 1.4687550783305965 |
| 2 | 0.35465383529663086 | 2.4425300801464984 |
| 3 | 0.36078310012817383 | 2.4010344731878877 |
| 4 | 0.3151836395263672 | 2.748406173788329 |
| 5 | 0.3399159908294678 | 2.5484316246420207 |
| 6 | 0.3960268497467041 | 2.1873584109395403 |
| 7 | 0.38035011291503906 | 2.2775138782326128 |
| 8 | 0.4112529754638672 | 2.106374208658952 |
| 9 | 0.3891477584838867 | 2.226025055691568 |
| 10 | 0.41824865341186523 | 2.0711427369457924 |
| 11 | 0.43691492080688477 | 1.9826575369675328 |
| 12 | 0.44853711128234863 | 1.931284254885316 |
| 13 | 0.45146870613098145 | 1.9187435341310743 |
| 14 | 0.4784407615661621 | 1.8105745378292801 |
| 15 | 0.4980161190032959 | 1.7394068739883695 |
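Numbers like these could be produced with a simple early-exit timing loop. A minimal sketch, assuming a LayerSkip-trained checkpoint and a Llama-family model whose decoder layers live in `model.model.layers`; the checkpoint name and the layer-truncation trick are illustrative, not necessarily what the script does:

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "facebook/layerskip-llama2-7B"  # assumption: any LayerSkip-trained checkpoint
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt)
inputs = tokenizer(
    "### Instruction: Set alarm for 6am every day\n\n### Response:",
    return_tensors="pt",
).to(model.device)

def timed_generate():
    start = time.time()
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    return tokenizer.decode(out[0], skip_special_tokens=True), time.time() - start

text, orig_time = timed_generate()
print(text, "\nOrig Time:", orig_time)

full_layers = model.model.layers  # keep a handle so the full model can be restored
for k in range(1, len(full_layers)):
    model.model.layers = full_layers[:k]  # exit after the first k decoder layers
    text, t = timed_generate()
    print(text, f"\nLayerskip Time: {t} For Layer: {k} Speedup: {orig_time / t}")
model.model.layers = full_layers
```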
README.md
Outdated
## Fine Tune

To train any supported HuggingFace model with the LayerSkip approach:

```bash
torchrun finetune_layerskip.py \
    --ckpt facebook/llama2-7B \
    --ds_ckpt some_dataset \
    --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
    --lr 1e-4 \
    --batch_size 8 \
    --epochs 3 \
    --early_exit_loss_scale 1.0 \
    --eval_freq 50 \
    --output_dir ./checkpoints
```
Updating the command here to train.py. But can you also check that the command works? We could put a command that users can copy and paste into their terminal (e.g., use the TopV2 dataset rather than some_dataset).
Suggested change:

````diff
-## Fine Tune
-To train any supported HuggingFace model with the LayerSkip approach:
-```bash
-torchrun finetune_layerskip.py \
-    --ckpt facebook/llama2-7B \
-    --ds_ckpt some_dataset \
-    --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
-    --lr 1e-4 \
-    --batch_size 8 \
-    --epochs 3 \
-    --early_exit_loss_scale 1.0 \
-    --eval_freq 50 \
-    --output_dir ./checkpoints
-```
+## Train
+To train any supported HuggingFace model with the LayerSkip approach:
+```bash
+torchrun train.py \
+    --ckpt meta-llama/Llama-2-7b-hf \
+    --ds_ckpt some_dataset \
+    --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
+    --lr 1e-4 \
+    --batch_size 8 \
+    --epochs 3 \
+    --early_exit_loss_scale 1.0 \
+    --eval_freq 50 \
+    --output_dir ./checkpoints
+```
````
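For context on `--early_exit_loss_scale`: in LayerSkip-style training, the usual final-layer language-modeling loss is augmented with a loss computed from early-exit logits, obtained by passing an intermediate hidden state through the shared final norm and LM head. A minimal single-exit sketch; the paper uses a layer curriculum and per-layer scaling, which this collapses into one exit layer and a flat scale, and all names are illustrative:

```python
import torch.nn.functional as F

def layerskip_loss(hidden_states, norm, lm_head, labels, exit_layer, scale=1.0):
    """hidden_states: the (num_layers + 1)-tuple a HF model returns with
    output_hidden_states=True; labels: [batch, seq] token ids (-100 = ignore)."""

    def lm_loss(h):
        logits = lm_head(norm(h))  # early exits share the final norm and LM head
        return F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),  # shift: predict token t+1
            labels[:, 1:].reshape(-1),
            ignore_index=-100,
        )

    main_loss = lm_loss(hidden_states[-1])           # standard final-layer loss
    early_loss = lm_loss(hidden_states[exit_layer])  # early-exit loss
    return main_loss + scale * early_loss
```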
This PR is aimed at adding a fine-tuning script with LayerSkip.
@mostafaelhoushi would you like to review the current state and let me know what you think of it, and what you would need in the next iteration?
(Note: This is a WIP and does not yet support a lot of the goodies required for efficient training.)