From ba0a8ffec878092c1a384c2ed353be0a2fab74b8 Mon Sep 17 00:00:00 2001
From: Zijing Zhang <1092727518@qq.com>
Date: Wed, 21 Jan 2026 13:05:14 +0000
Subject: [PATCH] add GRPO with trajectory-level deduplication support

---
 agentlightning/verl/trainer.py | 110 ++++++++++++++++++++++++++++++---
 1 file changed, 101 insertions(+), 9 deletions(-)

diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index 94b3276f8..8fc257b3d 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -5,6 +5,7 @@ from __future__ import annotations
 
 import random
+from collections import defaultdict
 from contextlib import contextmanager
 from copy import deepcopy
 from pprint import pprint
 
@@ -45,6 +46,72 @@
 ]
 
 
+def compute_grpo_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    traj_index: np.ndarray | None = None,
+    epsilon: float = 1e-6,
+    norm_adv_by_std_in_grpo: bool = True,
+    compute_mean_std_cross_all_data: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Compute advantage for GRPO with trajectory-level deduplication support.
+
+    This is a minimal extension of VeRL's GRPO implementation, adding support for
+    trajectory-level deduplication via `traj_index` and `compute_mean_std_cross_all_data`.
+
+    Args:
+        token_level_rewards: Shape (bs, response_length).
+        response_mask: Shape (bs, response_length).
+        index: Group index array (e.g., data_id).
+        traj_index: Trajectory index array (e.g., rollout_id). If None, no deduplication.
+        epsilon: Small value for numerical stability.
+        norm_adv_by_std_in_grpo: If True, normalize by std (original GRPO). If False, Dr.GRPO style.
+        compute_mean_std_cross_all_data: If True (default), use all samples for the group mean/std.
+            If False, use only unique (index, traj_index) trajectories for the group mean/std.
+
+    Returns:
+        Tuple of (advantages, returns), both of shape (bs, response_length).
+ """ + scores = token_level_rewards.sum(dim=-1) + + id2score: dict = defaultdict(list) + id2mean: dict = {} + id2std: dict = {} + seen_pairs: set = set() + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + # Trajectory deduplication: skip if (index, traj_index) already seen + if traj_index is not None and (index[i], traj_index[i]) in seen_pairs: + continue + id2score[index[i]].append(scores[i]) + # Mark as seen only when compute_mean_std_cross_all_data is False + if traj_index is not None and not compute_mean_std_cross_all_data: + seen_pairs.add((index[i], traj_index[i])) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + scores_tensor = torch.stack(id2score[idx]) + id2mean[idx] = torch.mean(scores_tensor) + id2std[idx] = torch.std(scores_tensor) + else: + raise ValueError(f"no score in prompt index: {idx}") + + for i in range(bsz): + if norm_adv_by_std_in_grpo: + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + else: + scores[i] = scores[i] - id2mean[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + @contextmanager def _timer(name: str, timing_raw: Dict[str, float]): with Timer(name=name, logger=None) as timer: @@ -355,15 +422,40 @@ def _train_step(self, batch_dict: dict) -> dict: "norm_adv_by_std_in_grpo", True ) # GRPO adv normalization factor - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - config=self.config.algorithm, - ) + # compute_mean_std_cross_all_data: trajectory-level advantage computation + # Currently only supported for GRPO algorithm + compute_mean_std_cross_all_data = self.config.algorithm.get("compute_mean_std_cross_all_data", True) + if not compute_mean_std_cross_all_data: + assert self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO, ( + f"compute_mean_std_cross_all_data=False is only supported for GRPO, " + f"got {self.config.algorithm.adv_estimator}" + ) + + # Use local GRPO implementation when compute_mean_std_cross_all_data is disabled + if self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO: + if "response_mask" not in batch.batch: + batch.batch["response_mask"] = compute_response_mask(batch) + traj_index = batch.non_tensor_batch["rollout_id_list"] + advantages, returns = compute_grpo_outcome_advantage( + token_level_rewards=batch.batch["token_level_rewards"], + response_mask=batch.batch["response_mask"], + index=batch.non_tensor_batch["uid"], + traj_index=traj_index, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + compute_mean_std_cross_all_data=compute_mean_std_cross_all_data, + ) + batch.batch["advantages"] = advantages + batch.batch["returns"] = returns + else: + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) # Calculate the metrics before processing. Refer to the comments of function `compute_data_metrics` for details. metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic, suffix="_before_processing"))