diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py
index 98c58f330..f3fcbfa17 100644
--- a/agentlightning/verl/daemon.py
+++ b/agentlightning/verl/daemon.py
@@ -72,6 +72,63 @@ def _none_special_token_sequence(ids: List[int]) -> List[int]:
     return is_prefix, (template_mismatch, retoken_mismatch, others_mismatch)
 
 
+def _normalize_image_url(url: str) -> str:
+    if url.startswith("file://"):
+        path = url[len("file://") :]
+        try:
+            with open(path, "rb") as handle:
+                data = handle.read()
+            import hashlib
+
+            return f"sha1:{hashlib.sha1(data).hexdigest()}"
+        except Exception:
+            return path
+    if url.startswith("data:"):
+        import base64
+        import hashlib
+
+        header, _, payload = url.partition(",")
+        try:
+            if ";base64" in header:
+                data = base64.b64decode(payload)
+            else:
+                data = payload.encode("utf-8")
+            return f"sha1:{hashlib.sha1(data).hexdigest()}"
+        except Exception:
+            return f"{header},{payload[:64]}"
+    return url
+
+
+def image_urls_startswith(full_urls: List[str], prefix_urls: List[str]) -> bool:
+    if not prefix_urls:
+        return True
+    if len(full_urls) < len(prefix_urls):
+        return False
+    norm_full = [_normalize_image_url(url) for url in full_urls]
+    norm_prefix = [_normalize_image_url(url) for url in prefix_urls]
+    return norm_full[: len(norm_prefix)] == norm_prefix
+
+
+def log_image_mismatch_detail(
+    full_urls: List[str],
+    prefix_urls: List[str],
+    global_steps: int,
+    rollout_id: str,
+    turn_id: int,
+    log_dir: str | None = None,
+):
+    if log_dir is None:
+        return
+    os.makedirs(log_dir, exist_ok=True)
+    with open(os.path.join(log_dir, "image_mismatch.log"), "a+") as f:
+        print(
+            "-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
+            file=f,
+        )
+        print([_normalize_image_url(u) for u in full_urls], file=f)
+        print([_normalize_image_url(u) for u in prefix_urls], file=f)
+
+
 def log_mismatch_detail(
     diagnostic: Tuple[bool, bool, bool],
     full_ids: List[int],
@@ -913,11 +970,11 @@ def get_train_data_batch(
                 image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
 
         elif self.trace_aggregator.get("level", "transition") == "trajectory":
-            assert not self._use_mrope, "M-RoPE is not supported in trajectory level yet."
             response_mask_list: List[List[int]] = []
             unmerged_count: int = 0
             template_mismatch_count, retoken_mismatch_count, others_mismatch_count = 0, 0, 0
+            image_mismatch_count = 0
             response_per_turn_list: List[int] = []
 
             for rollout_id, sample_info in finished_id_to_sample_info.items():
@@ -926,15 +983,28 @@ def get_train_data_batch(
                 # Identify which turns can be merged based on token ids prefix matching
                 current_merged_trace_idx: List[int] = []
                 current_context: List[int] = []
+                current_image_urls: List[str] = []
                 for turn_index, trace in enumerate(sample_info["trace_list"]):
                     response_per_turn_list.append(len(trace["response_ids"]))
-                    is_prefix, diagnostic = ids_startswith(
+                    token_prefix_ok, diagnostic = ids_startswith(
                         trace["prompt_ids"] + trace["response_ids"],
                         current_context,
                         self.tokenizer,
                         self.trace_aggregator.get("debug", False),
                     )
-                    if not is_prefix and self.trace_aggregator.get("debug", False) == True:
+                    image_prefix_ok = image_urls_startswith(trace.get("image_urls", []), current_image_urls)
+                    if not image_prefix_ok:
+                        image_mismatch_count += 1
+                        if self.trace_aggregator.get("debug", False) == True:
+                            log_image_mismatch_detail(
+                                trace.get("image_urls", []),
+                                current_image_urls,
+                                global_steps,
+                                rollout_id,
+                                turn_index,
+                                self.trace_aggregator.get("mismatch_log_dir", None),
+                            )
+                    if not token_prefix_ok and self.trace_aggregator.get("debug", False) == True:
                         template_mismatch_count += diagnostic[0]
                         retoken_mismatch_count += diagnostic[1]
                         others_mismatch_count += diagnostic[2]
@@ -948,13 +1018,15 @@ def get_train_data_batch(
                             self.trace_aggregator.get("mismatch_log_dir", None),
                         )
 
-                    if is_prefix:
+                    if token_prefix_ok and image_prefix_ok:
                         current_context = trace["prompt_ids"] + trace["response_ids"]
                         current_merged_trace_idx.append(turn_index)
+                        current_image_urls = trace.get("image_urls", [])
                     else:
                         merged_trace_idx.append(current_merged_trace_idx)
                         current_merged_trace_idx = [turn_index]
                         current_context = trace["prompt_ids"] + trace["response_ids"]
+                        current_image_urls = trace.get("image_urls", [])
 
                 if current_merged_trace_idx not in merged_trace_idx:
                     merged_trace_idx.append(current_merged_trace_idx)
@@ -1019,6 +1091,10 @@ def get_train_data_batch(
                 response_mask_list.append(one_response_mask)
                 data_id_list.append(sample_info["data_id"])
                 rollout_id_list.append(rollout_id)
+                if self._use_mrope:
+                    last_trace = sample_info["trace_list"][current_merged_trace_idx[-1]]
+                    image_urls = last_trace.get("image_urls", [])
+                    image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
                 # turn_index_list.append(current_merged_trace_idx)
         else:
             raise ValueError(f"Unknown trace_aggregator level: {self.trace_aggregator.get('level')}")
@@ -1115,9 +1191,11 @@
                 "training/template_mismatch_triplets": template_mismatch_count,  # type: ignore
                 "training/retoken_mismatch_triplets": retoken_mismatch_count,  # type: ignore
                 "training/others_mismatch_triplets": others_mismatch_count,  # type: ignore
+                "training/image_mismatch_triplets": image_mismatch_count,  # type: ignore
                 "training/template_mismatch_ratio": template_mismatch_count / len(response_per_turn_list),  # type: ignore
                 "training/retoken_mismatch_ratio": retoken_mismatch_count / len(response_per_turn_list),  # type: ignore
                 "training/others_mismatch_ratio": others_mismatch_count / len(response_per_turn_list),  # type: ignore
+                "training/image_mismatch_ratio": image_mismatch_count / len(response_per_turn_list),  # type: ignore
             }
 
         if self.trace_aggregator.get("level", "transition") == "trajectory" and self.trace_aggregator.get("debug", False)