feature(sunjx): add high entropy token selection #6
Conversation
lightrft/trainer/fast_exp_maker.py
Outdated
    info,
    kl,
)
exp.action_entropy = output.action_entropy
Is it possible to include action_entropy in the ExperienceVL and Experience definitions?
+1, we should enable it in the dataclass definition and set the default value to None
And the creation below can also include action_entropy, so we don't need the extra assignment code.
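For illustration, a minimal sketch of what that could look like; the field name follows the diff above, but the other fields are placeholders, not the real Experience/ExperienceVL definition:

```python
# Sketch only: the real Experience / ExperienceVL dataclasses have more fields.
# The point is the optional action_entropy field with a None default.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Experience:
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    action_mask: torch.Tensor
    action_entropy: Optional[torch.Tensor] = None  # only set when entropy filtering is enabled
```

The construction site could then pass action_entropy=output.action_entropy directly and drop the separate assignment.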
| parser.add_argument("--use_cpg_loss", action="store_true", default=False, help="whether to use the clipped policy gradient loss from CPGD") | ||
|
|
||
| # High-entropy token filtering (from "Beyond the 80/20 Rule" paper) | ||
| parser.add_argument("--high_entropy_token_ratio", type=float, default=0.0, help="Ratio of high-entropy tokens to use for gradient updates (0.0 means use all tokens, 0.2 means use top 20% highest entropy tokens). Common value when enabled: 0.2. Based on 'Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning' (https://arxiv.org/abs/2506.01939)") |
Implemented a configuration option that allows high-entropy tokens within the stored trajectory to be saved with a distinct special token/marker.
…iew in PolicyLoss
@@ -0,0 +1,532 @@
<!DOCTYPE html>
We should move this visualization to the examples/entropy_viz directory.
quick_view.sh
Outdated
@@ -0,0 +1,13 @@
#!/bin/bash
Please use English for the statements and comments.
lightrft/trainer/spmd_ppo_trainer.py
Outdated
# Create entropy_mask if high_entropy_token_ratio > 0 and action_entropy is available
entropy_mask = None
if hasattr(experience, 'action_entropy') and experience.action_entropy is not None:
    if hasattr(self.actor, 'high_entropy_token_ratio') and self.actor.high_entropy_token_ratio > 0.0:
Why not use self.high_entropy_token_ratio directly?
lightrft/trainer/spmd_ppo_trainer.py
Outdated
entropy_mask = None
if hasattr(experience, 'action_entropy') and experience.action_entropy is not None:
    if hasattr(self.actor, 'high_entropy_token_ratio') and self.actor.high_entropy_token_ratio > 0.0:
        from lightrft.models.utils import create_high_entropy_mask
Move the import to the top of the file.
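For readers following the thread, here is a rough sketch of what a helper like create_high_entropy_mask could do, based on the --high_entropy_token_ratio semantics above; the actual implementation in lightrft.models.utils is not shown in this PR excerpt, so the signature and details are assumptions:

```python
# Assumed sketch, not the real lightrft.models.utils implementation:
# keep only the top-`ratio` fraction of valid action tokens ranked by entropy.
import torch


def create_high_entropy_mask(action_entropy: torch.Tensor,
                             action_mask: torch.Tensor,
                             ratio: float) -> torch.Tensor:
    """action_entropy / action_mask: (batch, num_actions); ratio e.g. 0.2."""
    if ratio <= 0.0:
        # 0.0 means "use all tokens", mirroring the CLI default above
        return action_mask.bool()

    entropy_mask = torch.zeros_like(action_mask, dtype=torch.bool)
    for i in range(action_entropy.size(0)):
        valid = action_mask[i].bool()
        num_valid = int(valid.sum().item())
        if num_valid == 0:
            continue
        k = max(1, int(num_valid * ratio))
        # Padded positions get -inf so they are never selected
        scores = action_entropy[i].masked_fill(~valid, float("-inf"))
        topk_idx = scores.topk(k).indices
        entropy_mask[i, topk_idx] = True
    return entropy_mask
```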
lightrft/models/actor_language.py
Outdated
if return_output:
    # Include action_entropy in output if computed
    if action_entropy is not None:
        output_dict = dict(output)
Why do we need to transform output into a dict? What is the original type of output?
self.model is an AutoModelForCausalLM, and its forward() method returns a subclass of ModelOutput (such as CausalLMOutputWithPast).
Why is it necessary to convert it to a dictionary? ModelOutput is a fixed dataclass:
• It supports dictionary-style access: output["logits"] is valid.
• However, it does not support directly adding new keys: output["action_entropy"] = value will fail because the fields are fixed.
• Therefore, converting it to a regular dictionary is required before adding action_entropy.
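As a small illustration of the pattern being described (the surrounding function is simplified and the return shape is an assumption, not the actual actor_language code):

```python
# Simplified sketch of the conversion discussed above.
# `output` is a CausalLMOutputWithPast (a ModelOutput subclass).
if return_output:
    if action_entropy is not None:
        output_dict = dict(output)                       # copy the fixed fields into a plain dict
        output_dict["action_entropy"] = action_entropy   # plain dicts accept new keys
        return action_log_probs, output_dict
    return action_log_probs, output
```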
OK, add some comments here to explain the data type
lightrft/models/loss.py
Outdated
self.use_dapo = use_dapo
self.use_cpg_loss = use_cpg_loss
self.high_entropy_token_ratio = high_entropy_token_ratio
self.entropy_mask = entropy_mask
Why do we need entropy_mask in the init function? Do we have a default setting for the mask?
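For illustration, one alternative along these lines would be to accept entropy_mask per forward call with a None default instead of storing it at construction time; this is only a sketch under assumed argument names, not the PR's actual PolicyLoss:

```python
# Sketch only: a PPO-style clipped policy loss that takes entropy_mask per call.
import torch
import torch.nn as nn


class PolicyLoss(nn.Module):
    def __init__(self, clip_eps: float = 0.2, high_entropy_token_ratio: float = 0.0):
        super().__init__()
        self.clip_eps = clip_eps
        self.high_entropy_token_ratio = high_entropy_token_ratio

    def forward(self, log_probs, old_log_probs, advantages,
                action_mask=None, entropy_mask=None):
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)

        # Merge the optional high-entropy mask into the regular action mask
        mask = action_mask
        if entropy_mask is not None:
            mask = entropy_mask if mask is None else (mask.bool() & entropy_mask.bool())
        if mask is None:
            return loss.mean()
        mask = mask.float()
        return (loss * mask).sum() / mask.sum().clamp(min=1.0)
```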
…mask merging logic in PolicyLoss
No description provided.