Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion recipe/deepeyes/deepeyes.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __getitem__(self, item):
return row_dict


def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> float:
def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None, **kwargs) -> float:
"""
Compute reward score for model solutions with robust handling of various formats.

Expand Down
9 changes: 5 additions & 4 deletions recipe/transfer_queue/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ def generate_sequences(self, prompts: BatchMeta) -> BatchMeta:
BatchMeta: Output batch metadata.
"""

if self.rm_micro_batch_size and len(prompts) % self.rm_micro_batch_size != 0:
raise ValueError(
f"The length of prompts {len(prompts)} cannot divide the world size of rm_wg {self.rm_micro_batch_size}"
)
if self.config.actor_rollout_ref.rollout.free_cache_engine:
self.wake_up()
if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:
self.reward_model_manager.wake_up()

chunkes = prompts.chunk(len(self.agent_loop_workers))
outputs = ray.get(
[
Expand All @@ -46,6 +45,8 @@ def generate_sequences(self, prompts: BatchMeta) -> BatchMeta:
output = BatchMeta.concat(outputs)
if self.config.actor_rollout_ref.rollout.free_cache_engine:
self.sleep()
if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:
self.reward_model_manager.sleep()

# calculate performance metrics
metrics = [output.extra_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]]
Expand Down
14 changes: 10 additions & 4 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(self, config: DictConfig) -> None:


class AgentLoopBase(ABC):
"""An agent loop takes a input message, chat with OpenAI compatible LLM server and interact with various
"""An agent loop takes an input message, chat with OpenAI compatible LLM server and interact with various
environments."""

_class_initialized = False
Expand Down Expand Up @@ -608,16 +608,16 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
meta_info={"metrics": metrics, "reward_extra_keys": reward_extra_keys},
)

def create_transferqueue_client(self, controller_infos, storage_infos, role):
def create_transferqueue_client(self, controller_infos, role):
"""Create a client for data system(transfer queue)."""
from verl.single_controller.ray.base import get_random_string
from verl.utils.transferqueue_utils import create_transferqueue_client

client_name = get_random_string(length=6)
create_transferqueue_client(
client_id=f"{role}_worker_{client_name}",
controller_infos=controller_infos,
storage_infos=storage_infos,
controller_info=controller_infos,
config=self.config,
)


Expand Down Expand Up @@ -734,6 +734,11 @@ def _initialize_llm_servers(self):
def _init_agent_loop_workers(self):
self.agent_loop_workers = []
num_workers = self.config.actor_rollout_ref.rollout.agent.num_workers
runtime_env = {
"env_vars": {
"TRANSFER_QUEUE_ENABLE": "1" if self.config.transfer_queue.enable else "0",
}
}

node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"].get("CPU", 0) > 0]
for i in range(num_workers):
Expand All @@ -745,6 +750,7 @@ def _init_agent_loop_workers(self):
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id, soft=True
),
runtime_env=runtime_env,
).remote(self.config, self.server_handles, self.reward_router_address)
)

Expand Down
4 changes: 2 additions & 2 deletions verl/utils/transferqueue_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ class BatchMeta:

_TRANSFER_QUEUE_CLIENT = None

is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False)
is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", "0") == "1"


def create_transferqueue_client(
client_id: str,
controller_info: dict[Any, "ZMQServerInfo"],
controller_info: "ZMQServerInfo",
config,
) -> None:
global _TRANSFER_QUEUE_CLIENT
Expand Down
Loading