From 14f2d3e3c5f4cac997ddb2271360a1cd91942476 Mon Sep 17 00:00:00 2001
From: DearAJ
Date: Thu, 4 Dec 2025 16:51:18 +0800
Subject: [PATCH 1/4] add RAFT

---
 agentlightning/verl/trainer.py | 246 +++++++++++++++++++++++++++++++++
 1 file changed, 246 insertions(+)

diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index cec2e9101..3b14f78dc 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -200,6 +200,10 @@ def _validate(self):
         return test_metrics
 
     def _train_step(self, batch_dict: dict) -> dict:
+        # Check if RAFT mode is enabled
+        if self.config.algorithm.adv_estimator == "raft":
+            return self._train_step_raft(batch_dict)
+
         # Isolate in a separate method to automatically recycle the variables before validation.
         batch: DataProto = DataProto.from_single_dict(batch_dict)
         metrics = {}
@@ -388,6 +392,248 @@ def _train_step(self, batch_dict: dict) -> dict:
 
         return metrics
 
+    def _train_step_raft(self, batch_dict: dict) -> dict:
+        """
+        RAFT training step: Simplified training loop that only trains on r=1 samples.
+
+        RAFT (Rejection sampling Adaptive Fine-Tuning) differs from GRPO/PPO by:
+        1. Rejection sampling: Only keeping samples with reward r=1
+        2. Simple loss: Using standard cross-entropy (NLL) loss instead of advantage-weighted loss
+        3. No critic: No value function estimation needed
+        4. No advantage: No advantage function or GAE computation needed
+        """
+        batch: DataProto = DataProto.from_single_dict(batch_dict)
+        metrics = {}
+        timing_raw = {}
+
+        with _timer("step", timing_raw):
+            # When agent mode is enabled, we read the batch as it is.
+            gen_batch = batch
+
+            # Generate rollouts and collect data
+            with _timer("gen", timing_raw):
+                self.async_rollout_manager.wake_up()
+                self.agent_mode_daemon.set_up_data_and_server(
+                    gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
+                )
+                self.agent_mode_daemon.run_until_all_finished()
+                batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
+                    max_prompt_length=self.config.data.max_prompt_length,
+                    max_response_length=self.config.data.max_response_length,
+                    device=gen_batch.batch["fake_ids"].device,
+                )
+                metrics.update(agent_metrics)
+                self.agent_mode_daemon.clear_data_and_server()
+                self.async_rollout_manager.sleep()
+
+            # RAFT Step 1: Rejection Sampling - Filter to keep only r=1 samples
+            with _timer("rejection_sampling", timing_raw):
+                # Extract rewards from token_level_scores (sum to get sequence-level reward)
+                # The reward is stored at the last token position in token_level_scores
+                sequence_rewards = batch.batch["token_level_scores"].sum(dim=-1)  # (batch_size,)
+
+                # Binary reward: 1.0 for success, 0.0 for failure
+                # In RAFT, we only keep samples with reward == 1.0
+                is_positive_reward = (sequence_rewards == 1.0)
+                positive_indices = is_positive_reward.nonzero(as_tuple=True)[0]
+
+                # Log rejection sampling statistics
+                n_total = len(batch)
+                n_positive = len(positive_indices)
+                n_rejected = n_total - n_positive
+                metrics["raft/n_total_samples"] = n_total
+                metrics["raft/n_positive_samples"] = n_positive
+                metrics["raft/n_rejected_samples"] = n_rejected
+                metrics["raft/rejection_rate"] = n_rejected / n_total if n_total > 0 else 0.0
+                metrics["raft/positive_rate"] = n_positive / n_total if n_total > 0 else 0.0
+
+                # If no positive samples, skip this training step
+                if n_positive == 0:
+                    metrics["raft/loss"] = 0.0
+                    metrics["raft/skipped_no_positive_samples"] = 1
+                    return metrics
+
+                # Filter batch to keep only positive samples
+                positive_batch = batch[positive_indices.cpu().tolist()]
+
+            # RAFT Step 2: Compute response mask for the filtered batch
+            positive_batch.batch["response_mask"] = compute_response_mask(positive_batch)
+
+            # Set uid (required by update_actor, similar to GRPO)
+            # uid is used for algorithm like GRPO, should be aligned to data id
+            if "data_id_list" in positive_batch.non_tensor_batch:
+                positive_batch.non_tensor_batch["uid"] = positive_batch.non_tensor_batch["data_id_list"]
+
+            # Drop samples with prompts that are too long
+            keep_indices = (~positive_batch.batch["is_drop_mask"]).nonzero(as_tuple=True)[0]
+            metrics["raft/n_triplets_prompt_too_long"] = (
+                positive_batch.batch["is_drop_mask"].shape[0] - keep_indices.shape[0]
+            )
+            if len(keep_indices) == 0:
+                metrics["raft/loss"] = 0.0
+                metrics["raft/skipped_all_dropped"] = 1
+                return metrics
+            positive_batch = positive_batch[keep_indices]
+
+            # Round to mini batch size for efficient training
+            mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
+            n_transition = len(positive_batch)
+            random_indices = list(range(n_transition))
+            random.shuffle(random_indices)
+            positive_batch.reorder(torch.tensor(random_indices).type(torch.int32))
+            n_remained_transition = n_transition // mini_batch_size * mini_batch_size
+            positive_batch = positive_batch[list(range(n_remained_transition))]
+            metrics["raft/n_triplets_dropped_remainder"] = n_transition - n_remained_transition
+
+            # Balance batch if enabled
+            if self.config.trainer.balance_batch:
+                self._balance_batch(positive_batch, metrics=metrics)
+
+            # Pad batch for distributed training
+            positive_batch, pad_size = pad_dataproto_to_divisor(positive_batch, self.actor_rollout_wg.world_size)
+
+            # RAFT Step 3: Prepare batch for RAFT loss computation
+            # Remove advantage-related fields since RAFT doesn't use them
+            raft_batch = positive_batch
+            max_response_length = raft_batch.batch["responses"].shape[-1]
+
+            # Unpad before computing loss
+            raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size)
+
+            # RAFT Step 4: Prepare batch for actor update
+            # Need to compute old_log_probs and set required meta_info fields
+            with _timer("prepare_raft_batch", timing_raw):
+                # Ensure uid is set (may have been lost during filtering)
+                if "data_id_list" in raft_batch.non_tensor_batch:
+                    raft_batch.non_tensor_batch["uid"] = raft_batch.non_tensor_batch["data_id_list"]
+
+                # Compute global_token_num (required by update_actor)
+                raft_batch.meta_info["global_token_num"] = torch.sum(raft_batch.batch["attention_mask"], dim=-1).tolist()
+
+                # Pad batch for distributed training before computing log_probs
+                raft_batch, pad_size_prep = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)
+
+                # Compute old_log_probs (required by update_actor, similar to GRPO)
+                # This is needed even for RAFT because update_actor expects this field
+                old_log_prob = self.actor_rollout_wg.compute_log_prob(raft_batch)
+                entropys = old_log_prob.batch["entropys"]
+                response_masks = raft_batch.batch["response_mask"]
+                loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+                entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
+                old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
+                metrics.update(old_log_prob_metrics)
+                old_log_prob.batch.pop("entropys")
+                raft_batch = raft_batch.union(old_log_prob)
+
+                # Set required meta_info fields (similar to GRPO)
+                raft_batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
+                # Temperature is required by update_actor (from config or default 0.7)
+                raft_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.get("temperature", 0.7)
+
+            # Unpad before setting advantages
+            raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size_prep)
+
+            # RAFT Step 5: Pure SFT update
+            # Use standard cross-entropy loss like TRL's SFTTrainer
+            # Reference: trl/trainer/sft_trainer.py compute_loss method
+            with _timer("update_actor_sft", timing_raw):
+                # Prepare inputs for SFT loss computation (like SFTTrainer)
+                # SFTTrainer expects: input_ids, attention_mask, labels
+                input_ids = raft_batch.batch["input_ids"]  # (batch_size, seq_len)
+                attention_mask = raft_batch.batch["attention_mask"]  # (batch_size, seq_len)
+
+                # Create labels: -100 for prompt tokens (ignore in loss), actual token IDs for response
+                # This matches SFTTrainer's label format
+                labels = input_ids.clone()
+                # Shift labels for next-token prediction: predict token[t] given tokens[<t]
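The patch above implements the RAFT step on top of verl's DataProto and worker groups. For reference, here is a minimal sketch of the same conceptual step — rejection sampling followed by plain token-level NLL on the kept responses — written against bare PyTorch tensors; the names (`rewards`, `response_logits`, `response_ids`, `response_mask`) are illustrative and are not part of the patch.

    import torch
    import torch.nn.functional as F

    def raft_step(rewards, response_logits, response_ids, response_mask):
        # Rejection sampling: keep only trajectories whose reward is exactly 1.0
        keep = (rewards == 1.0).nonzero(as_tuple=True)[0]
        if keep.numel() == 0:
            return None  # mirrors the patch: skip the step when nothing survives
        logits = response_logits[keep]      # (k, T, V)
        ids = response_ids[keep]            # (k, T), long token ids
        mask = response_mask[keep].float()  # (k, T), 1 for real response tokens

        # Plain token-level cross-entropy (NLL) on the kept responses
        nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), ids.reshape(-1), reduction="none"
        ).view_as(ids)
        return (nll * mask).sum() / mask.sum().clamp_min(1.0)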
From: DearAJ
Date: Fri, 5 Dec 2025 10:59:49 +0800
Subject: [PATCH 2/4] fix wandb log problem

---
 agentlightning/verl/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index 3b14f78dc..320b9fa1a 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -595,7 +595,10 @@ def _train_step_raft(self, batch_dict: dict) -> dict:
 
                 # Extract and log the SFT loss
                 # Use actor loss from update_actor output (same as GRPO)
-                if "actor/loss" in actor_output_metrics:
+                # Note: update_actor returns "actor/pg_loss" not "actor/loss"
+                if "actor/pg_loss" in actor_output_metrics:
+                    metrics["raft/loss"] = actor_output_metrics["actor/pg_loss"]
+                elif "actor/loss" in actor_output_metrics:
                     metrics["raft/loss"] = actor_output_metrics["actor/loss"]
                 else:
                     # Fallback: use a default value if loss not found
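The fix above is a defensive lookup of the loss key returned by `update_actor`; which key is present ("actor/pg_loss" vs. "actor/loss") depends on the verl version. A standalone sketch of the same fallback pattern on a plain dict (the helper name is illustrative, not part of the patch):

    def extract_raft_loss(actor_output_metrics: dict) -> float:
        # Prefer the key verl's PPO actor emits, then the generic one.
        for key in ("actor/pg_loss", "actor/loss"):
            if key in actor_output_metrics:
                return float(actor_output_metrics[key])
        return 0.0  # mirrors the patch's "default value if loss not found"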
From 6e294e0862893bc621107b2efa5d74067d77f89f Mon Sep 17 00:00:00 2001
From: DearAJ
Date: Sat, 6 Dec 2025 16:21:58 +0800
Subject: [PATCH 3/4] Remove redundant code.

---
 agentlightning/verl/trainer.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index 320b9fa1a..7e105ec6c 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -489,17 +489,11 @@ def _train_step_raft(self, batch_dict: dict) -> dict:
             if self.config.trainer.balance_batch:
                 self._balance_batch(positive_batch, metrics=metrics)
 
-            # Pad batch for distributed training
-            positive_batch, pad_size = pad_dataproto_to_divisor(positive_batch, self.actor_rollout_wg.world_size)
-
             # RAFT Step 3: Prepare batch for RAFT loss computation
             # Remove advantage-related fields since RAFT doesn't use them
             raft_batch = positive_batch
             max_response_length = raft_batch.batch["responses"].shape[-1]
 
-            # Unpad before computing loss
-            raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size)
-
             # RAFT Step 4: Prepare batch for actor update
             # Need to compute old_log_probs and set required meta_info fields
             with _timer("prepare_raft_batch", timing_raw):
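This commit removes a pad/unpad pair that cancels out immediately; the `pad_dataproto_to_divisor` calls that remain (before `compute_log_prob` and before `update_actor`) still guarantee the batch divides evenly across workers. For intuition, a rough sketch of pad-to-divisor on a plain tensor — a hypothetical helper for illustration, not verl's implementation:

    import torch

    def pad_to_divisor(batch: torch.Tensor, divisor: int):
        # Repeat the last row so the batch size divides evenly across workers.
        pad_size = (-batch.size(0)) % divisor
        if pad_size:
            batch = torch.cat([batch, batch[-1:].expand(pad_size, *batch.shape[1:])], dim=0)
        return batch, pad_size

    def unpad(batch: torch.Tensor, pad_size: int):
        # Drop the padded rows again after the distributed call returns.
        return batch[: batch.size(0) - pad_size] if pad_size else batch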
From f51fede8d730ada3189fe705de70d59c3f98c35a Mon Sep 17 00:00:00 2001
From: DearAJ
Date: Sat, 6 Dec 2025 22:39:58 +0800
Subject: [PATCH 4/4] Fix (1) the clip disabling mechanism (2) remove redundant
 code related to "labels"

---
 agentlightning/verl/trainer.py | 46 +++++++++-------------------------
 1 file changed, 12 insertions(+), 34 deletions(-)

diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index 7e105ec6c..3a75a66db 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -528,41 +528,18 @@ def _train_step_raft(self, batch_dict: dict) -> dict:
             raft_batch = unpad_dataproto(raft_batch, pad_size=pad_size_prep)
 
             # RAFT Step 5: Pure SFT update
-            # Use standard cross-entropy loss like TRL's SFTTrainer
-            # Reference: trl/trainer/sft_trainer.py compute_loss method
+            # Use standard cross-entropy loss via PPO with advantages=1.0 and disabled clipping
+            # Note: PPO loss with advantages=1.0 and no clipping becomes equivalent to SFT
             with _timer("update_actor_sft", timing_raw):
-                # Prepare inputs for SFT loss computation (like SFTTrainer)
-                # SFTTrainer expects: input_ids, attention_mask, labels
-                input_ids = raft_batch.batch["input_ids"]  # (batch_size, seq_len)
-                attention_mask = raft_batch.batch["attention_mask"]  # (batch_size, seq_len)
-
-                # Create labels: -100 for prompt tokens (ignore in loss), actual token IDs for response
-                # This matches SFTTrainer's label format
-                labels = input_ids.clone()
-                # Shift labels for next-token prediction: predict token[t] given tokens[<t]
@@ ... @@ def _train_step_raft(self, batch_dict: dict) -> dict:
                 # Pad again for distributed training before update_actor
                 raft_batch, pad_size_actor = pad_dataproto_to_divisor(raft_batch, self.actor_rollout_wg.world_size)
 
-                # Temporarily disable PPO clipping for pure SFT (like SFTTrainer)
-                # Set clip_ratio to 1.0 effectively disables clipping
+                # Temporarily disable PPO clipping for pure SFT
                 original_clip_low = self.config.actor_rollout_ref.actor.get("clip_ratio_low", 0.2)
                 original_clip_high = self.config.actor_rollout_ref.actor.get("clip_ratio_high", 0.3)
-                # Disable clipping: set both ratios to 1.0 (no clipping in pure SFT)
-                self.config.actor_rollout_ref.actor["clip_ratio_low"] = 1.0
-                self.config.actor_rollout_ref.actor["clip_ratio_high"] = 1.0
+                # Disable clipping: set both ratios to a very large value (effectively no clipping)
+                # Using 1000.0 ensures clip(ratio, 1-1, 1+1000) = clip(ratio, 0, 1001)
+                # which doesn't restrict ratio values in [0, +∞) range
+                self.config.actor_rollout_ref.actor["clip_ratio_low"] = 1
+                self.config.actor_rollout_ref.actor["clip_ratio_high"] = 1000
 
                 try:
                     # Update actor with pure SFT loss
-                    # With advantages=1.0 and clip_ratio=1.0, this becomes standard cross-entropy
+                    # With advantages=1.0 and clipping disabled, this becomes standard cross-entropy
                     # This mimics SFTTrainer.compute_loss() behavior
                     actor_output = self.actor_rollout_wg.update_actor(raft_batch)
                     actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
@@ -739,4 +717,4 @@ def fit(self):
                 return
 
             progress_bar.update(1)
-            self.global_steps += 1
+            self.global_steps += 1
\ No newline at end of file
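The series relies on the equivalence described in patch 4's comments: with advantages fixed at 1.0, old log-probs taken from the current (pre-update) policy, and the clip range widened to [0, 1001], the PPO clipped surrogate's gradient matches the plain NLL gradient on the response tokens. A small self-contained check of that claim, assuming the standard clipped-surrogate form (toy logits, not verl code):

    import torch

    torch.manual_seed(0)
    logits = torch.randn(4, 7, requires_grad=True)   # 4 response tokens, vocab size 7
    targets = torch.randint(0, 7, (4,))

    logp = torch.log_softmax(logits, dim=-1)[torch.arange(4), targets]
    old_logp = logp.detach()                         # same (pre-update) policy
    ratio = torch.exp(logp - old_logp)               # == 1 on the first mini-batch
    adv = torch.ones_like(ratio)                     # RAFT: advantage fixed to 1.0

    # Clipped surrogate with clip_ratio_low=1, clip_ratio_high=1000 -> clamp to [0, 1001]
    surrogate = -torch.min(ratio * adv, ratio.clamp(1.0 - 1.0, 1.0 + 1000.0) * adv).mean()
    nll = -logp.mean()                               # plain token-level cross-entropy

    (g_nll,) = torch.autograd.grad(nll, logits, retain_graph=True)
    (g_ppo,) = torch.autograd.grad(surrogate, logits)
    print(torch.allclose(g_nll, g_ppo))              # True: identical update direction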