diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index a8e7b229..0bd31455 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -1,7 +1,9 @@ import os +import logging +import time from typing import List, Optional -from peewee import CharField, Model, SqliteDatabase +from peewee import CharField, Model, SqliteDatabase, FloatField from playhouse.sqlite_ext import JSONField from eval_protocol.models import EvaluationRow @@ -18,6 +20,7 @@ def __init__(self, db_path: str): os.makedirs(os.path.dirname(db_path), exist_ok=True) self._db_path = db_path self._db = SqliteDatabase(self._db_path, pragmas={"journal_mode": "wal"}) + self._logger = logging.getLogger(__name__) class BaseModel(Model): class Meta: @@ -26,12 +29,26 @@ class Meta: class EvaluationRow(BaseModel): # type: ignore rollout_id = CharField(unique=True) data = JSONField() + updated_at = FloatField(default=lambda: time.time()) self._EvaluationRow = EvaluationRow self._db.connect() # Use safe=True to avoid errors when tables/indexes already exist self._db.create_tables([EvaluationRow], safe=True) + # Attempt to add updated_at column for existing installations + try: + columns = {c.name for c in self._db.get_columns(self._EvaluationRow._meta.table_name)} + if "updated_at" not in columns: + self._db.execute_sql( + f'ALTER TABLE "{self._EvaluationRow._meta.table_name}" ADD COLUMN "updated_at" REAL' + ) + # Backfill with current time + now_ts = time.time() + self._EvaluationRow.update(updated_at=now_ts).execute() + except Exception: + # Best-effort; ignore if migration not needed or fails + pass @property def db_path(self) -> str: @@ -44,16 +61,60 @@ def upsert_row(self, data: dict) -> None: with self._db.atomic("EXCLUSIVE"): if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists(): - self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute() + self._EvaluationRow.update(data=data, updated_at=time.time()).where( + self._EvaluationRow.rollout_id == rollout_id + ).execute() else: - self._EvaluationRow.create(rollout_id=rollout_id, data=data) + self._EvaluationRow.create(rollout_id=rollout_id, data=data, updated_at=time.time()) def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]: + # Build base query if rollout_id is None: - query = self._EvaluationRow.select().dicts() + model_query = self._EvaluationRow.select().order_by(self._EvaluationRow.updated_at.desc()) else: - query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.rollout_id == rollout_id) - results = list(query) + model_query = self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id) + + # Log SQL for debugging + try: + sql_text, sql_params = model_query.sql() + self._logger.debug( + "[SQLITE_READ_ROWS] db=%s sql=%s params=%s", self._db_path, sql_text, sql_params + ) + except Exception as e: + self._logger.debug("[SQLITE_READ_ROWS] Failed to render SQL for debug: %s", e) + + # Execute and collect results + results = list(model_query.dicts()) + + # Debug: summarize results + try: + count = len(results) + sample = results[:3] + sample_rollout_ids = [] + sample_updated = [] + for r in sample: + # r is a row dict with keys: rollout_id, data, updated_at + rid = r.get("rollout_id") + # updated_at may be missing on very old rows; guard accordingly + up_at = r.get("updated_at", None) + # 
Prefer rollout_id from nested data if available + try: + rid_nested = r.get("data", {}).get("execution_metadata", {}).get("rollout_id") + if rid_nested: + rid = rid_nested + except Exception: + pass + sample_rollout_ids.append(str(rid)) + sample_updated.append(up_at) + self._logger.debug( + "[SQLITE_READ_ROWS] fetched_rows=%d sample_rollout_ids=%s sample_updated_at=%s", + count, + sample_rollout_ids, + sample_updated, + ) + except Exception as e: + self._logger.debug("[SQLITE_READ_ROWS] Failed to summarize results for debug: %s", e) + return [result["data"] for result in results] def delete_row(self, rollout_id: str) -> int: diff --git a/eval_protocol/pytest/integrations/__init__.py b/eval_protocol/pytest/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/eval_protocol/pytest/integrations/openenv_trl.py b/eval_protocol/pytest/integrations/openenv_trl.py new file mode 100644 index 00000000..e5b1cf0c --- /dev/null +++ b/eval_protocol/pytest/integrations/openenv_trl.py @@ -0,0 +1,317 @@ +""" +TRL + OpenEnv Integration Helper + +This module exposes a single helper to build a TRL-compatible rollout_func +using the OpenEnvRolloutProcessor. It converts dataset prompts → EvaluationRows, +executes rollouts with concurrency, and converts results back to TRL's format. +""" + +from __future__ import annotations + +import asyncio +from typing import Any, Callable, Dict, List, Optional, Type +import re + +from eval_protocol.models import EvaluationRow, InputMetadata +from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig +from trl import GRPOConfig + + +def create_openenv_rollout_func( + env_factory: Callable[[], Any] | None, + prompt_builder: Callable[[Any, int, list[str]], Any], + action_parser: Callable[[str], Any], + model: str = "gpt-4o-mini", + max_steps: int = 8, + *, + completion_params: Dict[str, Any] | None = None, + concurrency: int | None = None, + # Allow any rollout processor to be used + processor_cls: Optional[Type[Any]] = OpenEnvRolloutProcessor, + processor_kwargs: Optional[Dict[str, Any]] = None, + # Optional environment integration (build a default env_factory if not provided) + env_client_cls: Optional[Type[Any]] = None, + tasks: List[str] | None = None, + miniwob_url: str | None = None, + docker_image: str = "browsergym-env:latest", + # HTTP client/direct server options (match HTTPEnvClient interface) + env_base_url: Optional[str] = None, + request_timeout_s: float = 15.0, + default_headers: Optional[Dict[str, str]] = None, + # Docker provider passthrough (match HTTPEnvClient.from_docker_image) + provider: Any | None = None, + docker_port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + # BrowserGym-specific convenience flags mapped to env vars + benchmark: str = "miniwob", + headless: bool = True, + viewport_width: int = 1280, + viewport_height: int = 720, + timeout_ms: int = 10000, +): + """ + Build a TRL-compatible rollout_func backed by a rollout processor (default: OpenEnvRolloutProcessor). + + Args: + env_factory: Callable yielding an OpenEnv HTTPEnvClient instance. If None, a default + BrowserGym env factory is built using tasks/miniwob_url/docker_image. + prompt_builder: (observation, step, history) -> Any content + Content should be directly compatible with the LLM client + (string or OpenAI-style content list/dict). The processor will + not modify it. 
+ action_parser: (llm_response: str) -> env action object (e.g., BrowserGymAction) + model: LLM model identifier + max_steps: Maximum environment steps per rollout + completion_params: Extra/override completion parameters to pass through + (e.g., {"temperature": 0.2, "top_p": 0.9, ...}). + These merge over defaults inferred from GRPO args. + concurrency: Max number of concurrent rollouts. Defaults to + args.per_device_train_batch_size if not provided. + processor_cls: Rollout processor class to instantiate. Defaults to OpenEnvRolloutProcessor. + processor_kwargs: Extra kwargs forwarded to the processor constructor. These override any + automatically derived kwargs below. + env_client_cls: Optional environment client class to instantiate (generic). + If provided and env_factory is None: + - If env_base_url is set: env_client_cls(base_url=..., request_timeout_s=..., default_headers=...) + - Else: env_client_cls.from_docker_image(docker_image, provider=..., **docker_kwargs) + tasks: Optional list of BrowserGym task names to rotate over. If provided and env_factory is None, + we will select one task per num_generations group. + miniwob_url: MiniWoB base URL (e.g., "http://172.17.0.1:8888/miniwob") for containers. + docker_image: Docker image to use for per-rollout BrowserGym containers. + env_base_url: If provided, connect directly to an existing env server via HTTP. + request_timeout_s: HTTP client timeout (seconds). + default_headers: Default headers for HTTP requests (auth/trace). + provider: Optional Docker provider to use for from_docker_image. + docker_port: Optional host port binding override (provider-dependent). + env_vars: Extra environment vars for the container; merged with BrowserGym defaults. + benchmark: BrowserGym benchmark name ('miniwob', 'webarena', etc.) mapped to env var. + headless: Headless mode mapped to env var. + viewport_width/height: Browser viewport mapped to env vars. + timeout_ms: Action timeout mapped to env var. + + Returns: + rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List] + """ + + def resolve_fireworks_model(model_str: str) -> str: + """ + Resolve a Fireworks deployment resource to its active deployed model name. + Accepts plain resource strings or LiteLLM-style prefixed ids, e.g.: + - "accounts//deployments/" + - "fireworks_ai/accounts//deployments/" + - "fireworks_ai/accounts/fireworks/models/qwen3-8b#accounts//deployments/" + Returns original string on any error or when resolution is not applicable. 
+ """ + try: + if not isinstance(model_str, str) or not model_str: + return model_str + prefix = "" + raw = model_str + if model_str.startswith("fireworks_ai/"): + prefix = "fireworks_ai/" + raw = model_str[len(prefix):] + m = re.search(r"(accounts/[^/\s]+/deployments/[^#\s]+)", raw) + if not m and "#" in raw: + right = raw.split("#", 1)[1] + m = re.search(r"(accounts/[^/\s]+/deployments/[^#\s]+)", right) + if not m: + return model_str + deployment_res = m.group(1) + try: + from fireworks.gateway import Gateway + from fireworks.control_plane.generated.protos_grpcio.gateway.deployed_model_pb2 import ( # type: ignore + DeployedModel as SyncDeployedModel, + ListDeployedModelsRequest as SyncListDeployedModelsRequest, + ) + except Exception: + return model_str + gateway = Gateway() + req = SyncListDeployedModelsRequest(filter=f'deployment="{deployment_res}"') + resp = gateway.list_deployed_models_sync(req) + if getattr(resp, "total_size", 0) <= 0: + return model_str + deployed = resp.deployed_models[0] + if getattr(deployed, "state", None) is not None: + if deployed.state != SyncDeployedModel.DEPLOYED: + return model_str + resolved_name = getattr(deployed, "name", None) + if not resolved_name: + return model_str + return prefix + resolved_name if prefix else resolved_name + except Exception: + return model_str + + def rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List]: + # 1) Prompts → EvaluationRows (one per generation per prompt) + num_generations = getattr(args, "num_generations", 8) + evaluation_rows: List[EvaluationRow] = [] + # Build rows contiguous per prompt: for each prompt, add num_generations rows + for prompt in prompts: + for _ in range(num_generations): + evaluation_rows.append( + EvaluationRow( + messages=[{"role": "user", "content": prompt}], + input_metadata=InputMetadata( + completion_params={"model": model} + ), + ) + ) + + # 2) Build rollout config + base_params: Dict[str, Any] = { + "model": model, + "temperature": getattr(args, "temperature", 0.0), + "max_tokens": getattr(args, "max_completion_length", 100), + } + if completion_params: + base_params.update(completion_params) + + max_concurrency = concurrency if concurrency is not None else getattr(args, "per_device_train_batch_size", 1) + + config = RolloutProcessorConfig( + completion_params=base_params, + mcp_config_path="", + semaphore=asyncio.Semaphore(max_concurrency), + steps=max_steps, + ) + + # 3) Execute rollouts + # 3) Instantiate rollout processor (pluggable) + Processor = processor_cls or OpenEnvRolloutProcessor # type: ignore[assignment] + _kwargs: Dict[str, Any] = dict(processor_kwargs or {}) + # If using OpenEnvRolloutProcessor (or compatible), supply env/prompt/action args unless overridden + _kwargs.setdefault("env_factory", env_factory) + _kwargs.setdefault("prompt_builder", prompt_builder) + _kwargs.setdefault("action_parser", action_parser) + # Environment args (only used by processors that support them) + _kwargs.setdefault("env_client_cls", env_client_cls) + _kwargs.setdefault("tasks", tasks) + _kwargs.setdefault("miniwob_url", miniwob_url) + _kwargs.setdefault("docker_image", docker_image) + _kwargs.setdefault("env_base_url", env_base_url) + _kwargs.setdefault("request_timeout_s", request_timeout_s) + _kwargs.setdefault("default_headers", default_headers) + _kwargs.setdefault("provider", provider) + _kwargs.setdefault("docker_port", docker_port) + _kwargs.setdefault("env_vars", env_vars) + _kwargs.setdefault("benchmark", benchmark) + _kwargs.setdefault("headless", 
headless) + _kwargs.setdefault("viewport_width", viewport_width) + _kwargs.setdefault("viewport_height", viewport_height) + _kwargs.setdefault("timeout_ms", timeout_ms) + _kwargs.setdefault("num_generations", num_generations) + + processor = Processor(**_kwargs) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + async def _run_all(): + tasks = processor(evaluation_rows, config) + return await asyncio.gather(*tasks) + completed_rows = loop.run_until_complete(_run_all()) + finally: + loop.close() + + # 4) Convert EvaluationRows → TRL expected dict + all_prompt_ids_per_row: List[List[int]] = [] + all_completion_ids: List[List[int]] = [] + all_logprobs: List[List[float]] = [] + step_rewards: List[List[float]] = [] + + non_empty_rewards = 0 + total_rewards_sum = 0.0 + rows_with_rewards = 0 + + # Prefer tokenizer on the processing_class if present + tokenizer = getattr(processing_class, "tokenizer", None) + if tokenizer is None: + tokenizer = processing_class + encode_fn = getattr(tokenizer, "encode", None) + eos_id = getattr(tokenizer, "eos_token_id", None) + if eos_id is None: + eos_id = 0 + if encode_fn is None: + pass + + for idx, row in enumerate(completed_rows): + prompt_ids: List[int] = [] + completion_ids: List[int] = [] + logprobs: List[float] = [] + + rewards: List[float] = [] + for msg in row.messages: + if msg.role == "user": + tokens = encode_fn(msg.content or "") if encode_fn else [] + prompt_ids.extend(tokens) + elif msg.role == "assistant": + tokens = encode_fn(msg.content or "") if encode_fn else [] + completion_ids.extend(tokens) + logprobs.extend([0.0] * len(tokens)) # placeholder + elif msg.role == "system": + try: + content = msg.content or "" + if isinstance(content, str) and content.startswith("__ep_step_rewards__:"): + payload = content.split(":", 1)[1] + import json as _json + rewards = _json.loads(payload) or [] + except Exception: + pass + + # Fallback to execution metadata (older processors) + if not rewards: + if hasattr(row.execution_metadata, "extra") and getattr(row.execution_metadata, "extra"): + try: + rewards = row.execution_metadata.extra.get("step_rewards", []) or [] + except Exception: + rewards = [] + + all_prompt_ids_per_row.append(prompt_ids if prompt_ids else [0]) + all_completion_ids.append(completion_ids if completion_ids else [eos_id]) + all_logprobs.append(logprobs if logprobs else [0.0]) + step_rewards.append(rewards if rewards else [0.0]) + + if rewards: + non_empty_rewards += 1 + rows_with_rewards += 1 + try: + total_rewards_sum += float(sum(rewards)) + except Exception: + pass + + if rows_with_rewards > 0: + avg_sum = total_rewards_sum / rows_with_rewards + else: + avg_sum = 0.0 + + # TRL expects 'prompt_ids' at the unique-prompt level in the vLLM-server path. + # Our processor produced per-row (i.e., per-generation) entries; collapse to unique prompts. 
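A toy run of the collapse implemented just below, assuming (as the row-building loop above guarantees) that entries are contiguous per prompt with `num_generations` entries each:

```python
# Toy data: 2 unique prompts, num_generations=3, rows ordered contiguously per prompt.
num_generations = 3
all_prompt_ids_per_row = [[11, 12], [11, 12], [11, 12], [7], [7], [7]]

num_unique = len(all_prompt_ids_per_row) // num_generations  # -> 2
prompt_ids_unique = [all_prompt_ids_per_row[i * num_generations] for i in range(num_unique)]
assert prompt_ids_unique == [[11, 12], [7]]  # one token-id list per unique prompt
```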
+ try: + if num_generations > 0 and len(all_prompt_ids_per_row) % num_generations == 0: + num_unique = len(all_prompt_ids_per_row) // num_generations + prompt_ids_unique = [ + all_prompt_ids_per_row[i * num_generations] for i in range(num_unique) + ] + else: + # Fallback: de-duplicate while preserving order + seen = set() + prompt_ids_unique = [] + for p in all_prompt_ids_per_row: + t = tuple(p) + if t in seen: + continue + seen.add(t) + prompt_ids_unique.append(p) + except Exception: + prompt_ids_unique = all_prompt_ids_per_row + + return { + "prompt_ids": prompt_ids_unique, + "completion_ids": all_completion_ids, + "logprobs": all_logprobs, + "step_rewards": step_rewards, + } + + return rollout_func + + diff --git a/eval_protocol/pytest/integrations/openenv_trl_vllm.py b/eval_protocol/pytest/integrations/openenv_trl_vllm.py new file mode 100644 index 00000000..476a56e7 --- /dev/null +++ b/eval_protocol/pytest/integrations/openenv_trl_vllm.py @@ -0,0 +1,351 @@ +""" +Lightweight vLLM + OpenEnv Integration + +Simplified integration using vLLM for inference with proper multi-turn completion splitting. +No Fireworks inference, no hot reload - just vLLM. +""" + +from __future__ import annotations + +import asyncio +from typing import Any, Callable, Dict, List, Optional, Type + +from eval_protocol.models import EvaluationRow, InputMetadata, Message +from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.utils.evaluation_row_utils import ( + filter_longest_conversation, + multi_turn_assistant_to_ground_truth, + assistant_to_ground_truth, +) +from trl import GRPOConfig + + +def create_openenv_vllm_rollout_func( + env_factory: Callable[[], Any] | None, + prompt_builder: Callable[[Any, int, list[str]], Any], + action_parser: Callable[[str], Any], + vllm_base_url: str = "http://localhost:8000", + max_steps: int = 8, + *, + split_mode: str = "multi_turn", # "multi_turn", "last_turn", "longest", or None + completion_params: Dict[str, Any] | None = None, + concurrency: int | None = None, + processor_cls: Optional[Type[Any]] = OpenEnvRolloutProcessor, + processor_kwargs: Optional[Dict[str, Any]] = None, + # Environment configuration + env_client_cls: Optional[Type[Any]] = None, + tasks: List[str] | None = None, + miniwob_url: str | None = None, + docker_image: str = "browsergym-env:latest", + env_base_url: Optional[str] = None, + request_timeout_s: float = 15.0, + default_headers: Optional[Dict[str, str]] = None, + provider: Any | None = None, + docker_port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + benchmark: str = "miniwob", + headless: bool = True, + viewport_width: int = 1280, + viewport_height: int = 720, + timeout_ms: int = 10000, +): + """ + Build a TRL-compatible rollout_func using vLLM inference with OpenEnv. 
+ + This is a lightweight version that: + - Uses vLLM client directly (no Fireworks, no hot reload) + - Properly splits completions using evaluation_row_utils helpers + - Works with TRL's GRPO trainer + + Args: + env_factory: Callable yielding an OpenEnv HTTPEnvClient instance + prompt_builder: (observation, step, history) -> content for LLM + action_parser: (llm_response: str) -> env action object + vllm_base_url: Base URL for vLLM server (e.g., "http://localhost:8000") + max_steps: Maximum environment steps per rollout + split_mode: How to split completions: + - "multi_turn": Split each assistant message as separate row (multi_turn_assistant_to_ground_truth) + - "last_turn": Extract last assistant message as ground truth (assistant_to_ground_truth) + - "longest": Keep only longest conversation (filter_longest_conversation) + - None: No splitting, return all rows as-is + completion_params: Extra completion parameters (temperature, max_tokens, etc.) + concurrency: Max concurrent rollouts (defaults to per_device_train_batch_size) + processor_cls: Rollout processor class (default: OpenEnvRolloutProcessor) + processor_kwargs: Extra kwargs for processor + env_client_cls: Environment client class + tasks: List of task names to rotate through + miniwob_url: MiniWoB base URL + docker_image: Docker image for environments + env_base_url: Direct HTTP connection to existing server + request_timeout_s: HTTP timeout + default_headers: HTTP headers + provider: Docker provider + docker_port: Host port binding + env_vars: Environment variables for container + benchmark: BrowserGym benchmark name + headless: Headless browser mode + viewport_width/height: Browser viewport size + timeout_ms: Action timeout + + Returns: + rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List] + + Example: + ```python + from trl import GRPOConfig, GRPOTrainer + from trl.extras.vllm_client import VLLMClient + from envs.browsergym_env import BrowserGymEnv, BrowserGymAction + + # Start vLLM server first: + # CUDA_VISIBLE_DEVICES=0,1 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 + + def make_env(): + return BrowserGymEnv.from_docker_image( + "browsergym-env:latest", + env_vars={"BROWSERGYM_BENCHMARK": "miniwob"} + ) + + def build_prompt(obs, step, history): + return f"Step {step}\\nGoal: {obs.goal}\\n{obs.text[:500]}" + + def parse_action(text): + return BrowserGymAction(action_str=text) + + rollout_func = create_openenv_vllm_rollout_func( + env_factory=make_env, + prompt_builder=build_prompt, + action_parser=parse_action, + vllm_base_url="http://localhost:8000", + tasks=["click-test", "click-button", "enter-text"], + split_mode="multi_turn", # Split each turn for training + ) + + training_args = GRPOConfig( + output_dir="outputs/vllm-training", + per_device_train_batch_size=2, + num_generations=4, + ) + + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-7B", + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, + ) + ``` + """ + + # Import vLLM client (will be used for generation) + try: + from trl.extras.vllm_client import VLLMClient + except ImportError: + raise ImportError( + "vLLM client not available. Install with: pip install trl[vllm]" + ) + + # Initialize vLLM client + vllm_client = VLLMClient(base_url=vllm_base_url) + + def rollout_func(prompts: List[str], args: GRPOConfig, processing_class) -> Dict[str, List]: + """ + Execute rollouts and return TRL-compatible results. + + Flow: + 1. Prompts → EvaluationRows (num_generations per prompt) + 2. 
Execute rollouts via OpenEnvRolloutProcessor + 3. Split completions using evaluation_row_utils + 4. Generate completions via vLLM for each split row + 5. Convert to TRL format + """ + num_generations = getattr(args, "num_generations", 8) + + # 1) Build evaluation rows (one per generation per prompt) + evaluation_rows: List[EvaluationRow] = [] + for prompt in prompts: + for gen_idx in range(num_generations): + evaluation_rows.append( + EvaluationRow( + messages=[Message(role="user", content=prompt)], + input_metadata=InputMetadata( + completion_params={}, + extra={"generation_idx": gen_idx} + ), + ) + ) + + # 2) Build processor config + base_params: Dict[str, Any] = { + "temperature": getattr(args, "temperature", 1.0), + "max_tokens": getattr(args, "max_completion_length", 100), + } + if completion_params: + base_params.update(completion_params) + + max_concurrency = concurrency if concurrency is not None else getattr( + args, "per_device_train_batch_size", 1 + ) + + config = RolloutProcessorConfig( + completion_params=base_params, + mcp_config_path="", + semaphore=asyncio.Semaphore(max_concurrency), + steps=max_steps, + ) + + # 3) Execute rollouts using OpenEnvRolloutProcessor + Processor = processor_cls or OpenEnvRolloutProcessor + _kwargs: Dict[str, Any] = dict(processor_kwargs or {}) + _kwargs.setdefault("env_factory", env_factory) + _kwargs.setdefault("prompt_builder", prompt_builder) + _kwargs.setdefault("action_parser", action_parser) + _kwargs.setdefault("env_client_cls", env_client_cls) + _kwargs.setdefault("tasks", tasks) + _kwargs.setdefault("miniwob_url", miniwob_url) + _kwargs.setdefault("docker_image", docker_image) + _kwargs.setdefault("env_base_url", env_base_url) + _kwargs.setdefault("request_timeout_s", request_timeout_s) + _kwargs.setdefault("default_headers", default_headers) + _kwargs.setdefault("provider", provider) + _kwargs.setdefault("docker_port", docker_port) + _kwargs.setdefault("env_vars", env_vars) + _kwargs.setdefault("benchmark", benchmark) + _kwargs.setdefault("headless", headless) + _kwargs.setdefault("viewport_width", viewport_width) + _kwargs.setdefault("viewport_height", viewport_height) + _kwargs.setdefault("timeout_ms", timeout_ms) + _kwargs.setdefault("num_generations", num_generations) + + processor = Processor(**_kwargs) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + async def _run_all(): + tasks = processor(evaluation_rows, config) + return await asyncio.gather(*tasks) + + completed_rows = loop.run_until_complete(_run_all()) + finally: + loop.close() + + # 4) Split completions based on split_mode + if split_mode == "multi_turn": + # Split each assistant message into separate rows + split_rows = multi_turn_assistant_to_ground_truth(completed_rows) + elif split_mode == "last_turn": + # Extract last assistant message as ground truth + split_rows = assistant_to_ground_truth(completed_rows) + elif split_mode == "longest": + # Keep only longest conversation per rollout_id + split_rows = filter_longest_conversation(completed_rows) + elif split_mode is None: + # No splitting + split_rows = completed_rows + else: + raise ValueError( + f"Invalid split_mode: {split_mode}. 
" + "Must be 'multi_turn', 'last_turn', 'longest', or None" + ) + + print(f"[OpenEnvVLLM] Split {len(completed_rows)} rows → {len(split_rows)} rows (mode={split_mode})") + + # 5) Generate completions via vLLM for each split row + # Build messages for vLLM chat endpoint + all_messages: List[List[Dict]] = [] + for row in split_rows: + messages = [{"role": msg.role, "content": msg.content} for msg in row.messages] + all_messages.append(messages) + + # Call vLLM to generate completions + # Check if we have conversational format + is_conversational = all_messages and isinstance(all_messages[0], list) + + vllm_params = { + "n": 1, # One completion per split row + "temperature": base_params["temperature"], + "max_tokens": base_params["max_tokens"], + } + + # Add any extra vLLM parameters from completion_params + if completion_params: + for key in ["top_p", "top_k", "min_p", "repetition_penalty"]: + if key in completion_params: + vllm_params[key] = completion_params[key] + + if is_conversational: + print(f"[OpenEnvVLLM] Calling vLLM chat endpoint with {len(all_messages)} conversations") + vllm_response = vllm_client.chat( + messages=all_messages, + **vllm_params, + ) + else: + # Convert messages to prompts for generate endpoint + prompts_for_vllm = [] + for msgs in all_messages: + # Simple concatenation (you may want to use a chat template here) + prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in msgs) + prompts_for_vllm.append(prompt_text) + + print(f"[OpenEnvVLLM] Calling vLLM generate endpoint with {len(prompts_for_vllm)} prompts") + vllm_response = vllm_client.generate( + prompts=prompts_for_vllm, + **vllm_params, + ) + + # 6) Convert to TRL format + prompt_ids = vllm_response["prompt_ids"] + completion_ids = vllm_response["completion_ids"] + logprobs = vllm_response["logprobs"] + + # Extract step rewards from completed rows + step_rewards: List[List[float]] = [] + for row in split_rows: + rewards: List[float] = [] + + # Look for rewards in system messages (sentinel pattern) + for msg in row.messages: + if msg.role == "system": + try: + content = msg.content or "" + if isinstance(content, str) and content.startswith("__ep_step_rewards__:"): + import json + payload = content.split(":", 1)[1] + rewards = json.loads(payload) or [] + break + except Exception: + pass + + # Fallback to execution metadata + if not rewards and hasattr(row.execution_metadata, "extra"): + try: + rewards = row.execution_metadata.extra.get("step_rewards", []) or [] + except Exception: + pass + + step_rewards.append(rewards if rewards else [0.0]) + + # Compute statistics + total_reward = sum(sum(r) for r in step_rewards) + avg_reward = total_reward / len(step_rewards) if step_rewards else 0.0 + print(f"[OpenEnvVLLM] Total reward: {total_reward:.2f}, Avg: {avg_reward:.2f}") + + # TRL expects prompt_ids at unique-prompt level (not per-generation) + # Deduplicate while preserving order + seen_prompts = set() + prompt_ids_unique = [] + for p_ids in prompt_ids: + p_tuple = tuple(p_ids) + if p_tuple not in seen_prompts: + seen_prompts.add(p_tuple) + prompt_ids_unique.append(p_ids) + + return { + "prompt_ids": prompt_ids_unique, + "completion_ids": completion_ids, + "logprobs": logprobs, + "step_rewards": step_rewards, + } + + return rollout_func + diff --git a/eval_protocol/pytest/openenv_rollout_processor.py b/eval_protocol/pytest/openenv_rollout_processor.py new file mode 100644 index 00000000..fc9c14dd --- /dev/null +++ b/eval_protocol/pytest/openenv_rollout_processor.py @@ -0,0 +1,483 @@ +""" +OpenEnv Rollout 
Processor + +Generic processor for ANY OpenEnv environment using the standard HTTPEnvClient interface. +No environment-specific code - works with BrowserGym, Echo, TextArena, Atari, etc. + +Key: OpenEnv provides a standard interface across all environments: +- All environments: HTTPEnvClient[ActionType, ObservationType] +- All have: reset() → StepResult, step(action) → StepResult, state() → State +- Client handles serialization/deserialization + +This processor just calls env.reset(), env.step(), env.state() - that's it! +""" + +import asyncio +import logging +import time +from typing import List, Any, Dict, Callable, Generic, TypeVar, Optional, Type +import json + +from openai.types import CompletionUsage + +from eval_protocol.mcp.execution.policy import LiteLLMPolicy +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.rollout_processor import RolloutProcessor +from eval_protocol.pytest.types import RolloutProcessorConfig + + logger = logging.getLogger(__name__) + + +class OpenEnvRolloutProcessor(RolloutProcessor): + """ + Generic rollout processor for ANY OpenEnv environment. + + Works with any environment that follows OpenEnv's standard interface: + - HTTPEnvClient[ActionType, ObservationType] + - reset() → StepResult[ObservationType] + - step(action: ActionType) → StepResult[ObservationType] + - state() → State + + No environment-specific code - just uses the standard interface! + + Examples: + ```python + # BrowserGym + from envs.browsergym_env import BrowserGymEnv, BrowserGymAction + def make_env(): + return BrowserGymEnv.from_docker_image(...) + + # Echo + from envs.echo_env import EchoEnv, EchoAction + def make_env(): + return EchoEnv.from_docker_image(...) + + # TextArena + from envs.textarena_env import TextArenaEnv, TextArenaAction + def make_env(): + return TextArenaEnv.from_docker_image(...) + + # Same processor works for all! + processor = OpenEnvRolloutProcessor( + env_factory=make_env, + action_parser=lambda text: BrowserGymAction(action_str=text), # or EchoAction(message=text), etc. + ) + ``` + + For TRL integration, see: trl-evalp/openenv_trl_integration.py + """ + + def __init__( + self, + env_factory: Optional[Callable] = None, + prompt_builder: Callable[[Any, int, List[str]], Any] | None = None, + action_parser: Callable[[str], Any] | None = None, + *, + # Environment construction parameters (generic HTTP client or Docker) + env_client_cls: Optional[Type[Any]] = None, + tasks: Optional[List[str]] = None, + miniwob_url: Optional[str] = None, + docker_image: str = "browsergym-env:latest", + env_base_url: Optional[str] = None, + hub_repo_id: Optional[str] = None, + request_timeout_s: float = 15.0, + default_headers: Optional[Dict[str, str]] = None, + provider: Any | None = None, + docker_port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + benchmark: str = "miniwob", + headless: bool = True, + viewport_width: int = 1280, + viewport_height: int = 720, + timeout_ms: int = 10000, + num_generations: Optional[int] = None, + ): + """ + Initialize processor. + + Args: + env_factory: Optional callable that creates an OpenEnv environment (HTTPEnvClient) + Example: lambda: BrowserGymEnv.from_docker_image(...). If not provided, + the processor will build one using the parameters below. + prompt_builder: Optional function that builds the user message content from + (observation, step, history). It should return content + directly compatible with the LLM client (e.g., a string, + or OpenAI-style content list/dict). 
No additional processing + is performed by the processor. + action_parser: Function that converts LLM text → Action object + Example: lambda text: BrowserGymAction(action_str=text) + Example: lambda text: EchoAction(message=text) + env_client_cls: Optional environment HTTP client class (generic). + tasks, miniwob_url, docker_image, env_base_url, request_timeout_s, default_headers, + provider, docker_port, env_vars, benchmark, headless, viewport_*, timeout_ms: + Parameters to construct default environments if env_factory is not provided. + num_generations: Optional hint for task rotation grouping (used to mimic GRPO grouping). + """ + self.prompt_builder = prompt_builder or (lambda obs, step, history: str(obs)) + if action_parser is None: + raise ValueError("action_parser must be provided and return an Action object.") + self.action_parser = action_parser + + # Store env construction parameters + self._provided_env_factory = env_factory + self._env_client_cls = env_client_cls + self._tasks = tasks or [] + self._miniwob_url = miniwob_url + self._docker_image = docker_image + self._env_base_url = env_base_url + self._hub_repo_id = hub_repo_id + self._request_timeout_s = request_timeout_s + self._default_headers = default_headers + self._provider = provider + self._docker_port = docker_port + self._env_vars = env_vars or {} + self._benchmark = benchmark + self._headless = headless + self._viewport_width = viewport_width + self._viewport_height = viewport_height + self._timeout_ms = timeout_ms + self._num_generations = max(1, int(num_generations)) if num_generations else 1 + self._env_create_idx: int = 0 + + # Build env_factory if not provided + self.env_factory = self._build_env_factory() + + def __call__( + self, rows: List[EvaluationRow], config: RolloutProcessorConfig + ) -> List[asyncio.Task[EvaluationRow]]: + """Process evaluation rows and return async tasks.""" + + semaphore = config.semaphore + max_steps = config.steps or 8 + + async def process_row(row: EvaluationRow) -> EvaluationRow: + """Process a single row with OpenEnv rollout.""" + start_time = time.perf_counter() + + # Create environment + try: + print("[OpenEnvRolloutProcessor] Creating environment via env_factory() ...") + except Exception: + pass + env = self.env_factory() + try: + print("[OpenEnvRolloutProcessor] Environment client created.") + except Exception: + pass + + try: + # Get model config + raw_model = config.completion_params.get("model", "gpt-4o-mini") + model = raw_model + temperature = config.completion_params.get("temperature", 0.0) + max_tokens = config.completion_params.get("max_tokens", 100) + # Optional: direct routing or provider overrides (e.g., base_url, api_key, top_p, stop, etc.) 
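For illustration, a `completion_params` dict exercising these overrides might look like the sketch below (placeholder values, not defaults from this PR); `model`, `temperature`, `max_tokens`, and `base_url` are consumed directly, and every other key is forwarded to `LiteLLMPolicy` per request:

```python
# Placeholder values for illustration only.
completion_params = {
    "model": "openai/local-model",           # routed to the endpoint given by base_url
    "temperature": 0.2,
    "max_tokens": 128,
    "base_url": "http://localhost:8000/v1",  # optional direct routing
    "api_key": "EMPTY",                      # provider override, forwarded as-is
    "top_p": 0.9,                            # extra sampling param, forwarded as-is
    "stop": ["\n\n"],
}
```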
+ base_url = config.completion_params.get("base_url") + # Forward any extra completion params to LiteLLMPolicy (they will be sent per-request) + extra_params: Dict[str, Any] = dict(config.completion_params or {}) + for _k in ("model", "temperature", "max_tokens", "base_url"): + try: + extra_params.pop(_k, None) + except Exception: + pass + try: + print(f"[OpenEnvRolloutProcessor] Model='{model}' temp={temperature} max_tokens={max_tokens} base_url={base_url or '(default)'}") + except Exception: + pass + + # Create policy for generation + policy = LiteLLMPolicy( + model_id=model, + temperature=temperature, + max_tokens=max_tokens, + base_url=base_url, + **extra_params, + ) + + # Reset environment with simple transient-error retries + reset_attempts = 3 + reset_delay = 1.0 + last_exc = None + try: + print("[OpenEnvRolloutProcessor] Resetting environment ...") + except Exception: + pass + for i in range(reset_attempts): + try: + result = env.reset() + try: + print(f"[OpenEnvRolloutProcessor] reset() succeeded on attempt {i + 1}") + except Exception: + pass + break + except Exception as e: + last_exc = e + if i == reset_attempts - 1: + raise + time.sleep(reset_delay) + reset_delay *= 2.0 + observation = result.observation + + + # Initialize tracking + messages = list(row.messages) # Copy initial messages + # Inject system prompt if provided and not already present + try: + has_system = any(m.role == "system" for m in messages) + except Exception: + has_system = False + system_prompt = None + try: + system_prompt = config.completion_params.get("system_prompt") + except Exception: + system_prompt = None + if system_prompt and not has_system: + messages.insert(0, Message(role="system", content=system_prompt)) + usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + step_rewards = [] + history: List[str] = [] + + # Agent loop: model → action → env.step → repeat + for step in range(max_steps): + if result.done: + logger.info(f"Episode done after {step} steps") + try: + print(f"[OpenEnvRolloutProcessor] Episode already done at step {step}") + except Exception: + pass + break + + # Build user message content via user-provided prompt_builder + try: + user_content = self.prompt_builder(observation, step + 1, history) + except Exception as e: + logger.error(f"prompt_builder failed: {e}", exc_info=True) + user_content = str(observation) + try: + print(f"[OpenEnvRolloutProcessor] Step {step + 1}: built user prompt (len={len(str(user_content))})") + except Exception: + pass + + messages.append(Message(role="user", content=user_content)) + # Optional tracing + if getattr(config, "logger", None): + try: + # Log a snapshot with current messages so UI shows incremental turns + try: + row_for_log = row.model_copy(deep=True) # pydantic v2 + except Exception: + import copy as _copy + row_for_log = _copy.deepcopy(row) + row_for_log.messages = list(messages) + config.logger.log(row_for_log) + except Exception: + pass + + # Call model to generate action (LiteLLM handles multimodal!) 
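The call below is treated as an OpenAI-style chat completion; a minimal stand-in with just the keys this loop reads (values invented):

```python
# Invented stand-in for the response consumed below; only these keys are read.
response = {
    "choices": [{"message": {"role": "assistant", "content": "click('13')"}}],
    "usage": {"prompt_tokens": 412, "completion_tokens": 9, "total_tokens": 421},
}
assistant_message = response["choices"][0]["message"]["content"]  # -> "click('13')"
```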
+ try: + print(f"[OpenEnvRolloutProcessor] Calling model (messages={len(messages)}) ...") + except Exception: + pass + response = await policy._make_llm_call( + messages=[msg.model_dump() for msg in messages], + tools=None, # No tools - just text generation + ) + + # Update usage + usage["prompt_tokens"] += response["usage"]["prompt_tokens"] + usage["completion_tokens"] += response["usage"]["completion_tokens"] + usage["total_tokens"] += response["usage"]["total_tokens"] + + # Extract assistant message and parse into Action object + assistant_message = response["choices"][0]["message"]["content"] + try: + preview = assistant_message if isinstance(assistant_message, str) else str(assistant_message) + print(f"[OpenEnvRolloutProcessor] Model output (first 120): '{preview[:120] if preview else ''}'") + except Exception: + pass + action = self.action_parser(assistant_message) + try: + label = getattr(action, "action_str", None) + print(f"[OpenEnvRolloutProcessor] Parsed action='{(label or str(action))[:120]}'") + except Exception: + pass + + # Add assistant message (original content) + messages.append(Message(role="assistant", content=assistant_message)) + + # Execute action in environment (OpenEnv standard interface!) with transient-error retries + step_attempts = 2 + step_delay = 0.5 + for si in range(step_attempts): + try: + result = env.step(action) + break + except Exception as se: + if si == step_attempts - 1: + raise + time.sleep(step_delay) + + # Collect reward (OpenEnv standard: result.reward) + reward = float(result.reward or 0.0) + step_rewards.append(reward) + try: + print(f"[OpenEnvRolloutProcessor] Step {step + 1}: reward={reward} done={result.done}") + except Exception: + pass + _action_label = getattr(action, "action_str", None) + if not _action_label: + try: + _action_label = str(action) + except Exception: + _action_label = "" + logger.debug(f"Step {step}: action={_action_label}, reward={reward}") + + # Update observation (OpenEnv standard: result.observation) + observation = result.observation + + # Update history for next prompt + error_flag = getattr(observation, "last_action_error", False) + history_line = f"Step {step + 1}: {_action_label} -> reward {reward:+.2f}{' ERROR' if error_flag else ''}" + history.append(history_line) + # Optional tracing + if getattr(config, "logger", None): + try: + # Log a snapshot with current messages so UI shows incremental turns + try: + row_for_log = row.model_copy(deep=True) # pydantic v2 + except Exception: + import copy as _copy + row_for_log = _copy.deepcopy(row) + row_for_log.messages = list(messages) + config.logger.log(row_for_log) + except Exception: + pass + + # Update row with results + row.messages = messages + row.execution_metadata.usage = CompletionUsage( + prompt_tokens=usage["prompt_tokens"], + completion_tokens=usage["completion_tokens"], + total_tokens=usage["total_tokens"], + ) + row.execution_metadata.duration_seconds = time.perf_counter() - start_time + + # Store rewards for TRL reward functions via a system message sentinel + try: + sentinel = "__ep_step_rewards__:" + json.dumps(step_rewards) + messages.append(Message(role="system", content=sentinel)) + print(f"[OpenEnvRolloutProcessor] Total reward={sum(step_rewards):.2f} steps={len(step_rewards)}") + except Exception: + pass + + logger.info( + f"Rollout complete: {len(step_rewards)} steps, " + f"total_reward={sum(step_rewards):.2f}, " + f"duration={row.execution_metadata.duration_seconds:.2f}s" + ) + # Final log with complete message history + if getattr(config, 
"logger", None): + try: + config.logger.log(row) + except Exception: + pass + + return row + + except Exception as e: + logger.error(f"Error in rollout: {e}", exc_info=True) + try: + print(f"[OpenEnvRolloutProcessor][ERROR] {type(e).__name__}: {e}") + except Exception: + pass + raise + finally: + # Cleanup environment + try: + print("[OpenEnvRolloutProcessor] Closing environment client ...") + env.close() + print("[OpenEnvRolloutProcessor] Environment closed.") + except: + pass + + async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: + async with semaphore: + return await process_row(r) + + # Create and return tasks + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + return tasks + + def _build_prompt(self, observation_text: str, step: int) -> str: + """ + Build prompt for LLM from observation text. + + Generic prompt that works for any environment. + """ + return ( + f"Step {step + 1}\n\n" + f"Observation:\n{observation_text}\n\n" + f"What action should be taken? Respond with a single action." + ) + + # Removed _extract_action_text: action parsing handled entirely by action_parser + + def _build_env_factory(self) -> Callable[[], Any]: + """ + Create or return an environment factory based on the provided parameters. + Preference order: + 1) Use provided env_factory + 2) Use generic env_client_cls + """ + if self._provided_env_factory is not None: + return self._provided_env_factory + + # If a generic client class is provided, use it + if self._env_client_cls is not None: + def _generic_factory(): + if self._env_base_url: + try: + print(f"[OpenEnvRolloutProcessor] Using env_client_cls base_url={self._env_base_url}") + except Exception: + pass + return self._env_client_cls( # type: ignore[call-arg] + base_url=self._env_base_url, + request_timeout_s=self._request_timeout_s, + default_headers=self._default_headers, + ) + docker_kwargs: Dict[str, Any] = {} + if self._env_vars: + docker_kwargs["env_vars"] = {k: str(v) for k, v in self._env_vars.items()} + if self._docker_port is not None: + docker_kwargs["port"] = int(self._docker_port) + if self._hub_repo_id: + try: + print(f"[OpenEnvRolloutProcessor] Launching from_hub repo_id='{self._hub_repo_id}' ...") + except Exception: + pass + return self._env_client_cls.from_hub( # type: ignore[attr-defined] + self._hub_repo_id, + provider=self._provider, + **docker_kwargs, + ) + else: + try: + print(f"[OpenEnvRolloutProcessor] Launching from_docker_image image='{self._docker_image}' ...") + except Exception: + pass + return self._env_client_cls.from_docker_image( # type: ignore[attr-defined] + self._docker_image, + provider=self._provider, + **docker_kwargs, + ) + return _generic_factory + + # No fallback: require an env_factory or env_client_cls + raise RuntimeError( + "OpenEnvRolloutProcessor requires either env_factory or env_client_cls. " + "Provide one of these to construct the environment." 
+ ) diff --git a/eval_protocol/utils/logs_server.py b/eval_protocol/utils/logs_server.py index adf44c57..b3a75f05 100644 --- a/eval_protocol/utils/logs_server.py +++ b/eval_protocol/utils/logs_server.py @@ -64,7 +64,25 @@ async def connect(self, websocket: WebSocket): logger.debug("[WEBSOCKET_CONNECT] Reading logs for initialization") logs = default_logger.read() + init_limit_str = os.environ.get("EP_LOGS_INIT_LIMIT", "1000") + try: + init_limit = max(1, int(init_limit_str)) + except Exception: + init_limit = 1000 + if len(logs) > init_limit: + # logs are ordered by updated_at DESC from sqlite read(); keep the newest entries + logs = logs[:init_limit] + logger.debug(f"[WEBSOCKET_CONNECT] Found many logs, truncating to newest {init_limit} entries for init") logger.debug(f"[WEBSOCKET_CONNECT] Found {len(logs)} logs to send") + try: + # Print a small sample of rollout_ids being sent on init for debugging + sample_ids = [] + for row in logs[-5:]: + rid = getattr(getattr(row, "execution_metadata", None), "rollout_id", "unknown") + sample_ids.append(str(rid)) + logger.debug(f"[WEBSOCKET_CONNECT] Init sample rollout_ids (tail): {sample_ids}") + except Exception: + pass data = { "type": "initialize_logs", @@ -104,6 +122,16 @@ def broadcast_row_upserted(self, row: "EvaluationRow"): logger.debug( f"[WEBSOCKET_BROADCAST] Successfully serialized message (length: {len(json_message)}) for rollout_id: {rollout_id}" ) + try: + # Extra debug: how many messages and last message summary + msgs = getattr(row, "messages", None) or [] + last_role = getattr(msgs[-1], "role", None) if msgs else None + last_len = len(str(getattr(msgs[-1], "content", ""))) if msgs else 0 + logger.debug( + f"[WEBSOCKET_BROADCAST] rollout_id={rollout_id} messages={len(msgs)} last_role={last_role} last_content_len={last_len}" + ) + except Exception: + pass # Queue the message for broadcasting in the main event loop logger.debug(f"[WEBSOCKET_BROADCAST] Queuing message for broadcast for rollout_id: {rollout_id}") diff --git a/tests/pytest/conftest.py b/tests/pytest/conftest.py new file mode 100644 index 00000000..6087c88c --- /dev/null +++ b/tests/pytest/conftest.py @@ -0,0 +1,4 @@ +# Ensure the eval_protocol pytest plugin is loaded in all environments +pytest_plugins = ("eval_protocol.pytest.plugin",) + + diff --git a/tests/pytest/data/echo_dataset.jsonl b/tests/pytest/data/echo_dataset.jsonl new file mode 100644 index 00000000..8992e618 --- /dev/null +++ b/tests/pytest/data/echo_dataset.jsonl @@ -0,0 +1,3 @@ +{"id": "echo-1", "prompt": "hello"} +{"id": "echo-2", "prompt": "test message"} + diff --git a/tests/pytest/data/openenv_browsergym_dataset.jsonl b/tests/pytest/data/openenv_browsergym_dataset.jsonl new file mode 100644 index 00000000..d454e6b8 --- /dev/null +++ b/tests/pytest/data/openenv_browsergym_dataset.jsonl @@ -0,0 +1,7 @@ +{"id": "openenv-1", "prompt": "start"} +{"id": "openenv-2", "prompt": "start"} +{"id": "openenv-3", "prompt": "start"} +{"id": "openenv-4", "prompt": "start"} +{"id": "openenv-5", "prompt": "start"} +{"id": "openenv-6", "prompt": "start"} + diff --git a/tests/pytest/data/wordle_dataset.jsonl b/tests/pytest/data/wordle_dataset.jsonl new file mode 100644 index 00000000..fde65371 --- /dev/null +++ b/tests/pytest/data/wordle_dataset.jsonl @@ -0,0 +1,3 @@ +{"id": "wordle-1", "prompt": "start"} +{"id": "wordle-2", "prompt": "start"} + diff --git a/tests/pytest/test_openenv_browsergym_basic.py b/tests/pytest/test_openenv_browsergym_basic.py new file mode 100644 index 00000000..75de7643 --- /dev/null +++ 
b/tests/pytest/test_openenv_browsergym_basic.py @@ -0,0 +1,83 @@ +import asyncio +import os +import shutil +from typing import Any, Dict, List + +import pytest + +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor + +# Skip these integration-heavy tests on CI runners by default +pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI") + + +@pytest.mark.integration +def test_openenv_browsergym_basic(): + """ + Very basic integration test to ensure OpenEnv + BrowserGym can run a single-step rollout. + Skips automatically if Docker is not available. + """ + if shutil.which("docker") is None: + pytest.skip("Docker not available on PATH; skipping OpenEnv BrowserGym basic test.") + + # Build a minimal EvaluationRow (messages can be empty; processor will add user prompts) + rows: List[EvaluationRow] = [EvaluationRow(messages=[Message(role="user", content="start")])] + + # Use tasks that are known to exist; requires MiniWoB server reachable from containers. + tasks = ["click-test"] + miniwob_url = os.getenv("MINIWOB_URL", "http://172.17.0.1:8888/miniwob/") + + # Construct the processor with a trivial action_parser; the model output will still be generated + # but we parse to a safe noop action to minimize flakiness for the environment step. + from envs.browsergym_env import BrowserGymAction # type: ignore + + processor = OpenEnvRolloutProcessor( + env_factory=None, + prompt_builder=lambda obs, step, history: "Do nothing", + action_parser=lambda text: BrowserGymAction(action_str="noop()"), + tasks=tasks, + miniwob_url=miniwob_url, + docker_image="browsergym-env:latest", + benchmark="miniwob", + timeout_ms=10000, + num_generations=1, + ) + + # Completion params: rely on an available provider/model in the environment + completion_params: Dict[str, Any] = { + "model": os.getenv( + "OPENENV_TEST_MODEL", + # Default to a Fireworks public model id used elsewhere in tests; requires FIREWORKS_API_KEY + "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct", + ), + "temperature": 0.0, + "max_tokens": 16, + } + + # Limit to a single step to keep the test fast and robust + config = RolloutProcessorConfig( + completion_params=completion_params, + semaphore=asyncio.Semaphore(1), + steps=1, + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + async def _run_all(): + tasks_ = processor(rows, config) + return await asyncio.gather(*tasks_) + + completed_rows = loop.run_until_complete(_run_all()) + finally: + loop.close() + + assert len(completed_rows) == 1 + # Basic sanity checks that a rollout happened and usage is populated + row = completed_rows[0] + assert row is not None + assert row.execution_metadata is not None + assert getattr(row.execution_metadata, "duration_seconds", 0.0) >= 0.0 + diff --git a/tests/pytest/test_openenv_browsergym_eval.py b/tests/pytest/test_openenv_browsergym_eval.py new file mode 100644 index 00000000..0de81e20 --- /dev/null +++ b/tests/pytest/test_openenv_browsergym_eval.py @@ -0,0 +1,288 @@ +from typing import Any, Dict, List +import os +import re + +import pytest +from eval_protocol.models import EvaluationRow, Message, EvaluateResult +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor +import pytest + +# Skip these integration-heavy tests on CI runners by default +pytestmark 
= pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI") + + +def openenv_dataset_to_rows(data: List[Dict[str, Any]]) -> List[EvaluationRow]: + """ + Adapter: convert simple {"id": "...", "prompt": "..."} rows into EvaluationRows. + """ + rows: List[EvaluationRow] = [] + for row in data: + prompt = str(row.get("prompt", "start")) + rows.append(EvaluationRow(messages=[Message(role="user", content=prompt)])) + return rows + + +# ---- prompt_builder and action_parser modeled after browsergym_grpo_evalp.py ---- + +ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL) + + +def _as_scalar(x: Any) -> Any: + try: + return x.item() + except Exception: + return x + + +def _extract_goal_url_title(observation: Any) -> tuple[str, str, str]: + goal = getattr(observation, "goal", "") or "" + url = getattr(observation, "url", "") or "" + title = "" + metadata = getattr(observation, "metadata", {}) or {} + obs_dict = metadata.get("browsergym_obs", {}) or {} + if not goal: + goal = obs_dict.get("goal") or "" + if not url: + url = obs_dict.get("url") or "" + titles = obs_dict.get("open_pages_titles") or () + active_idx = _as_scalar(obs_dict.get("active_page_index")) + try: + active_idx = int(active_idx) + except Exception: + active_idx = 0 + if isinstance(titles, (list, tuple)) and 0 <= active_idx < len(titles): + title = titles[active_idx] or "" + return goal, url, title + + +def _extract_clickable_elements_lines(observation: Any) -> List[str]: + metadata = getattr(observation, "metadata", {}) or {} + obs_dict = metadata.get("browsergym_obs", {}) or {} + extra_props = obs_dict.get("extra_element_properties", {}) or {} + axtree_object = obs_dict.get("axtree_object") or {} + focused_bid = obs_dict.get("focused_element_bid") + bid_to_desc: Dict[str, tuple[str, str]] = {} + try: + nodes = axtree_object.get("nodes") or [] + for node in nodes: + bid = node.get("browsergym_id") + if bid is None: + continue + role = "" + name = "" + rf = node.get("role") or {} + if isinstance(rf, dict): + role = str(rf.get("value", "")).strip() + nf = node.get("name") or {} + if isinstance(nf, dict): + name = str(nf.get("value", "")).strip() + bid_to_desc[str(bid)] = (role, name) + except Exception: + pass + lines: List[str] = [] + for bid in sorted(extra_props.keys(), key=lambda x: str(x)): + props = extra_props[bid] or {} + if not props.get("clickable"): + continue + bbox = props.get("bbox") or [] + bbox_str = ", ".join(str(v) for v in bbox) if bbox else "?" + role, name = bid_to_desc.get(str(bid), ("", "")) + focus_tag = " [FOCUSED]" if (str(bid) == str(focused_bid)) else "" + rn = (role or "-") + if name: + rn = f"{rn} | {name}" + vis = props.get("visibility") + vis_str = f"{vis:.2f}" if isinstance(vis, (int, float)) else str(vis) if vis is not None else "?" 
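For reference, one `extra_element_properties` entry in the shape this loop consumes; the keys are the ones read here, while the values (including the bbox ordering) are assumptions for illustration:

```python
# Assumed example entry; only "clickable", "bbox", and "visibility" are read in this loop.
props_example = {
    "clickable": True,
    "bbox": [102.0, 240.0, 88.0, 28.0],  # assumed [x, y, width, height]
    "visibility": 1.0,
}
```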
+ lines.append(f"- BID {bid}{focus_tag}: {rn} | bbox({bbox_str}) | visibility={vis_str}") + return lines + + +def _rank_clickables_lines(observation: Any, goal: str, top_n: int = 8) -> tuple[List[str], str | None]: + metadata = getattr(observation, "metadata", {}) or {} + obs_dict = metadata.get("browsergym_obs", {}) or {} + goal_lc = (goal or "").lower().strip() + extra_props = obs_dict.get("extra_element_properties", {}) or {} + axtree_object = obs_dict.get("axtree_object") or {} + focused_bid = str(obs_dict.get("focused_element_bid") or "") + bid_to_desc: Dict[str, tuple[str, str]] = {} + try: + nodes = axtree_object.get("nodes") or [] + for node in nodes: + bid = node.get("browsergym_id") + if bid is None: + continue + role = "" + name = "" + rf = node.get("role") or {} + if isinstance(rf, dict): + role = str(rf.get("value", "")).strip() + nf = node.get("name") or {} + if isinstance(nf, dict): + name = str(nf.get("value", "")).strip() + bid_to_desc[str(bid)] = (role, name) + except Exception: + pass + scored: List[tuple[float, str, str, str, str]] = [] + for bid_key in sorted(extra_props.keys(), key=lambda x: str(x)): + props = extra_props[bid_key] or {} + if not props.get("clickable"): + continue + role, name = bid_to_desc.get(str(bid_key), ("", "")) + name_lc = (name or "").lower() + score = 0.0 + if goal_lc and name_lc and (goal_lc in name_lc or name_lc in goal_lc): + score += 2.0 + if (role or "").lower() == "button": + score += 1.0 + if str(bid_key) == focused_bid: + score += 0.5 + vis = props.get("visibility") + try: + vis_f = float(vis) + score += max(0.0, min(1.0, vis_f)) + except Exception: + pass + bbox = props.get("bbox") or [] + bbox_str = ", ".join(str(v) for v in bbox) if bbox else "?" + rn = (role or "-") + if name: + rn = f"{rn} | {name}" + vis_str = f"{vis:.2f}" if isinstance(vis, (int, float)) else str(vis) if vis is not None else "?" + scored.append((score, str(bid_key), rn, bbox_str, vis_str)) + scored.sort(key=lambda t: t[0], reverse=True) + lines: List[str] = [] + recommended = scored[0][1] if scored else None + for idx, (score, bid, rn, bbox_str, vis_str) in enumerate(scored[:top_n], start=1): + lines.append(f"{idx}. 
BID {bid}: score={score:.2f} | {rn} | bbox({bbox_str}) | visibility={vis_str}") + return lines, recommended + + +def prompt_builder(observation: Any, step: int, history: List[str]) -> str: + goal, url, title = _extract_goal_url_title(observation) + url = url or "(unknown)" + error_note = "Yes" if getattr(observation, "last_action_error", False) else "No" + clickables_block = "\n".join(_extract_clickable_elements_lines(observation)) or "(none detected)" + ranked_lines, rec = _rank_clickables_lines(observation, goal, top_n=10) + ranked_block = "\n".join(ranked_lines) or "(none)" + text = getattr(observation, "text", "") or "" + text = text[:2048] + metadata = getattr(observation, "metadata", {}) or {} + obs_dict = metadata.get("browsergym_obs", {}) or {} + focused_bid = obs_dict.get("focused_element_bid") or "" + last_action = obs_dict.get("last_action") or "" + return ( + f"Step: {step}\n" + f"Goal: {goal}\n" + f"Current URL: {url}\n" + f"Title: {title}\n" + f"Previous steps:\n" + ("\n".join(history[-4:]) if history else "None") + "\n" + f"Last action: {last_action}\n" + f"Last action error: {error_note}\n" + f"Focused BID: {focused_bid}\n\n" + f"Clickable elements (BID: role | name | bbox | visibility):\n{clickables_block}\n\n" + f"Ranked clickable candidates (best first):\n{ranked_block}\n" + f"Recommended BID: {rec or '(none)'}\n\n" + "Instructions:\n" + "- Choose the most relevant clickable BID to achieve the goal.\n" + "- Prefer role=button or elements whose name matches the goal.\n" + "- Reply with a single action, e.g., click('13') or noop().\n\n" + f"Page excerpt:\n{text}\n\n" + "Reply with exactly one BrowserGym action string." + ).strip() + + +def action_parser(response_text: str): + try: + from envs.browsergym_env import BrowserGymAction # type: ignore + except Exception: + pytest.skip("OpenEnv (envs.browsergym_env) is not installed; skipping BrowserGym test.") + raise + if not response_text: + return BrowserGymAction(action_str="noop()") + for raw in response_text.splitlines(): + line = raw.strip() + if not line: + continue + m = ACTION_PATTERN.search(line) + if m: + parsed = re.sub(r"\s+", " ", m.group(0)) + return BrowserGymAction(action_str=parsed) + m = ACTION_PATTERN.search(response_text) + if m: + parsed = re.sub(r"\s+", " ", m.group(0)) + return BrowserGymAction(action_str=parsed) + return BrowserGymAction(action_str="noop()") + + +try: + from envs.browsergym_env import BrowserGymEnv # type: ignore + _HAS_BG = True +except Exception: + _HAS_BG = False + + +@evaluation_test( # type: ignore[misc] + input_dataset=["tests/pytest/data/openenv_browsergym_dataset.jsonl"], + dataset_adapter=openenv_dataset_to_rows, + completion_params=[ + { + "temperature": 0.0, + "max_tokens": 32, + "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct", + } + ], + # Keep concurrency and steps low for a quick health-check + num_runs=1, + max_concurrent_rollouts=1, + mode="pointwise", + rollout_processor=( + OpenEnvRolloutProcessor( + env_client_cls=BrowserGymEnv if _HAS_BG else None, + prompt_builder=prompt_builder, + action_parser=action_parser, + tasks=[ + "click-test", + "click-button", + "click-button-sequence", + "click-checkboxes", + "click-checkboxes-soft", + "click-checkboxes-large", + "click-checkboxes-transfer", + ], + miniwob_url=os.getenv("MINIWOB_URL", "http://172.17.0.1:8888/miniwob/"), + docker_image="browsergym-env:latest", + benchmark="miniwob", + timeout_ms=10000, + num_generations=1, + ) + if _HAS_BG + else None + ), +) +def test_openenv_browsergym_eval(row: 
EvaluationRow) -> EvaluationRow: + """ + Smoke test to ensure OpenEnv + BrowserGym MiniWoB runs and returns a row. + The evaluation harness will assert basic invariants (no exceptions, etc.). + """ + if not _HAS_BG: + pytest.skip("OpenEnv (envs.browsergym_env) is not installed; skipping BrowserGym test.") + # Extract step rewards from the sentinel system message injected by the rollout processor + step_rewards: List[float] = [] + try: + for msg in row.messages or []: + if msg.role == "system" and isinstance(msg.content, str) and msg.content.startswith("__ep_step_rewards__:"): + import json as _json + payload = msg.content.split(":", 1)[1] + step_rewards = _json.loads(payload) or [] + break + except Exception: + step_rewards = [] + + total = float(sum(step_rewards)) if step_rewards else 0.0 + # Map total reward to a score in [0,1]; MiniWoB rewards are typically 0/1 or -1/1 + score = max(0.0, min(1.0, total)) + reason = f"Total reward={total:.2f} across {len(step_rewards)} steps" + row.evaluation_result = EvaluateResult(score=score, reason=reason) + return row + diff --git a/tests/pytest/test_openenv_echo_baseurl.py b/tests/pytest/test_openenv_echo_baseurl.py new file mode 100644 index 00000000..f8c87bd3 --- /dev/null +++ b/tests/pytest/test_openenv_echo_baseurl.py @@ -0,0 +1,137 @@ +from typing import Any, Dict, List +import os + +import pytest + +from eval_protocol.models import EvaluationRow, Message, EvaluateResult +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor + +# Skip these integration-heavy tests on CI runners by default +pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI") + + +def echo_dataset_to_rows(data: List[Dict[str, Any]]) -> List[EvaluationRow]: + rows: List[EvaluationRow] = [] + for row in data: + prompt = str(row.get("prompt", "hello")) + rows.append(EvaluationRow(messages=[Message(role="user", content=prompt)])) + return rows + + +def prompt_builder(observation: Any, step: int, history: List[str]) -> str: + return "Please repeat back the next message exactly." 
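+
+
+# NOTE (assumption): the rollout processor is expected to inject a sentinel system
+# message such as `__ep_step_rewards__:[1.0, 0.0]` (the fixed prefix followed by a
+# JSON list of per-step rewards); _score_from_system_rewards below parses exactly that shape.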
+ + +def action_parser(response_text: str): + try: + from envs.echo_env import EchoAction # type: ignore + except Exception: + pytest.skip("OpenEnv (envs.echo_env) is not installed; skipping Echo base_url test.") + raise + text = response_text.strip() if isinstance(response_text, str) else "" + return EchoAction(message=text or "hello") + + +def _score_from_system_rewards(row: EvaluationRow) -> float: + step_rewards: List[float] = [] + try: + for msg in row.messages or []: + if msg.role == "system" and isinstance(msg.content, str) and msg.content.startswith("__ep_step_rewards__:"): + import json as _json + payload = msg.content.split(":", 1)[1] + step_rewards = _json.loads(payload) or [] + break + except Exception: + step_rewards = [] + total = float(sum(step_rewards)) if step_rewards else 0.0 + return max(0.0, min(1.0, total)) + + +try: + from envs.echo_env import EchoEnv # type: ignore + _HAS_ECHO = True +except Exception: + _HAS_ECHO = False + + +@evaluation_test( # type: ignore[misc] + input_dataset=["tests/pytest/data/echo_dataset.jsonl"], + dataset_adapter=echo_dataset_to_rows, + completion_params=[ + { + "temperature": 0.0, + "max_tokens": 16, + "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct", + } + ], + num_runs=1, + max_concurrent_rollouts=2, + mode="pointwise", + rollout_processor=( + OpenEnvRolloutProcessor( + env_client_cls=EchoEnv, # type: ignore + env_base_url=os.getenv("OPENENV_ECHO_BASE_URL"), # e.g., http://0.0.0.0:8001 (docker or local) + prompt_builder=prompt_builder, + action_parser=action_parser, + timeout_ms=5000, + num_generations=1, + ) + if _HAS_ECHO + else None + ), +) +def test_openenv_echo_baseurl_local_or_docker(row: EvaluationRow) -> EvaluationRow: + """ + Base URL connectivity test for Echo env (local Python server or Docker). + Requires OPENENV_ECHO_BASE_URL to be set; otherwise, test is skipped. + """ + if not os.getenv("OPENENV_ECHO_BASE_URL"): + pytest.skip("OPENENV_ECHO_BASE_URL not set; skipping local/docker echo test.") + if not _HAS_ECHO: + pytest.skip("OpenEnv (envs.echo_env) is not installed; skipping Echo base_url test.") + score = _score_from_system_rewards(row) + row.evaluation_result = EvaluateResult(score=score, reason=f"Echo (base_url) score={score:.2f}") + return row + + +@evaluation_test( # type: ignore[misc] + input_dataset=["tests/pytest/data/echo_dataset.jsonl"], + dataset_adapter=echo_dataset_to_rows, + completion_params=[ + { + "temperature": 0.0, + "max_tokens": 16, + "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct", + } + ], + num_runs=1, + max_concurrent_rollouts=2, + mode="pointwise", + rollout_processor=( + OpenEnvRolloutProcessor( + env_client_cls=EchoEnv, # type: ignore + env_base_url=os.getenv("OPENENV_ECHO_SPACE_URL"), # e.g., https://openenv-echo-env.hf.space + prompt_builder=prompt_builder, + action_parser=action_parser, + timeout_ms=5000, + num_generations=1, + ) + if _HAS_ECHO + else None + ), +) +def test_openenv_echo_baseurl_space(row: EvaluationRow) -> EvaluationRow: + """ + Space URL connectivity test for Echo env (remote HF Space). + Requires OPENENV_ECHO_SPACE_URL to be set; otherwise, test is skipped. 
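+
+    Example (assumed invocation):
+        OPENENV_ECHO_SPACE_URL=https://openenv-echo-env.hf.space pytest -k test_openenv_echo_baseurl_space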
+    """
+    if not os.getenv("OPENENV_ECHO_SPACE_URL"):
+        pytest.skip("OPENENV_ECHO_SPACE_URL not set; skipping space echo test.")
+    if not _HAS_ECHO:
+        pytest.skip("OpenEnv (envs.echo_env) is not installed; skipping Echo space test.")
+    score = _score_from_system_rewards(row)
+    row.evaluation_result = EvaluateResult(score=score, reason=f"Echo (space) score={score:.2f}")
+    return row
+
+
diff --git a/tests/pytest/test_openenv_echo_hub.py b/tests/pytest/test_openenv_echo_hub.py
new file mode 100644
index 00000000..7ddd2b8c
--- /dev/null
+++ b/tests/pytest/test_openenv_echo_hub.py
@@ -0,0 +1,109 @@
+from typing import Any, Dict, List
+import os
+
+from eval_protocol.models import EvaluationRow, Message, EvaluateResult
+from eval_protocol.pytest import evaluation_test
+from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
+import pytest
+
+# Skip these integration-heavy tests on CI runners by default
+pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI")
+
+
+def echo_dataset_to_rows(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
+    """
+    Adapter: convert simple {"id": "...", "prompt": "..."} rows into EvaluationRows.
+    """
+    rows: List[EvaluationRow] = []
+    for row in data:
+        prompt = str(row.get("prompt", "hello"))
+        rows.append(EvaluationRow(messages=[Message(role="user", content=prompt)]))
+    return rows
+
+
+def prompt_builder(observation: Any, step: int, history: List[str]) -> str:
+    """
+    Echo env is very simple; we just send a short instruction.
+    """
+    return "Please repeat back the next message exactly."
+
+
+def action_parser(response_text: str):
+    """
+    Convert the raw model response to an EchoAction.
+    """
+    try:
+        from envs.echo_env import EchoAction  # type: ignore
+    except Exception:
+        pytest.skip("OpenEnv (envs.echo_env) is not installed; skipping Echo hub test.")
+        raise
+    text = response_text.strip() if isinstance(response_text, str) else ""
+    return EchoAction(message=text or "hello")
+
+
+try:
+    from envs.echo_env import EchoEnv  # type: ignore
+    _HAS_ECHO = True
+except Exception:
+    _HAS_ECHO = False
+
+
+@evaluation_test(  # type: ignore[misc]
+    input_dataset=["tests/pytest/data/echo_dataset.jsonl"],
+    dataset_adapter=echo_dataset_to_rows,
+    completion_params=[
+        {
+            "temperature": 0.0,
+            "max_tokens": 16,
+            # Any working model with your API key; match other tests' default
+            "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct",
+        }
+    ],
+    num_runs=1,
+    max_concurrent_rollouts=2,
+    mode="pointwise",
+    rollout_processor=(
+        OpenEnvRolloutProcessor(
+            # Use HF Hub to launch the environment container automatically
+            env_client_cls=EchoEnv,  # type: ignore
+            hub_repo_id=os.getenv("OPENENV_ECHO_REPO", "openenv/echo-env"),
+            # Simple prompt+parser above
+            prompt_builder=prompt_builder,
+            action_parser=action_parser,
+            # Keep defaults for timeouts/viewport/etc. (not relevant for echo)
+            timeout_ms=5000,
+            num_generations=1,
+        )
+        if _HAS_ECHO
+        else None
+    ),
+)
+def test_openenv_echo_hub(row: EvaluationRow) -> EvaluationRow:
+    """
+    Smoke test for Echo env via Hugging Face Hub (registry.hf.space/openenv-echo-env).
+    Extracts env rewards (from rollout policy extras) and sets evaluation_result.
+    """
+    if not _HAS_ECHO:
+        pytest.skip("OpenEnv (envs.echo_env) is not installed; skipping Echo hub test.")
+    # Try to read rewards/usage left in execution metadata extra or system messages.
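+    # Assumed sentinel format (illustrative): a system Message whose content is
+    #   "__ep_step_rewards__:[1.0]" -- the prefix followed by a JSON list of rewards.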
+    total_reward = 0.0
+    try:
+        # Preferred path: system sentinel "__ep_step_rewards__"
+        step_rewards: List[float] = []
+        for msg in row.messages or []:
+            if msg.role == "system" and isinstance(msg.content, str) and msg.content.startswith("__ep_step_rewards__:"):
+                import json as _json
+                payload = msg.content.split(":", 1)[1]
+                step_rewards = _json.loads(payload) or []
+                break
+        total_reward = float(sum(step_rewards)) if step_rewards else 0.0
+    except Exception:
+        total_reward = 0.0
+
+    score = max(0.0, min(1.0, total_reward))
+    row.evaluation_result = EvaluateResult(score=score, reason=f"Echo total reward={total_reward:.2f}")
+    return row
+
+
diff --git a/tests/pytest/test_openenv_textarena_wordle.py b/tests/pytest/test_openenv_textarena_wordle.py
new file mode 100644
index 00000000..e7f39d8b
--- /dev/null
+++ b/tests/pytest/test_openenv_textarena_wordle.py
@@ -0,0 +1,116 @@
+from typing import Any, Dict, List
+import os
+import re
+
+from eval_protocol.models import EvaluationRow, Message, EvaluateResult
+from eval_protocol.pytest import evaluation_test
+from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
+import pytest
+
+# Skip these integration-heavy tests on CI runners by default
+pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI")
+
+
+def wordle_dataset_to_rows(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
+    """
+    Adapter: convert simple {"id": "...", "prompt": "..."} rows into EvaluationRows.
+    Prompts are ignored by the environment; they just seed the conversation.
+    """
+    rows: List[EvaluationRow] = []
+    for row in data:
+        prompt = str(row.get("prompt", "start"))
+        rows.append(EvaluationRow(messages=[Message(role="user", content=prompt)]))
+    return rows
+
+
+def prompt_builder(observation: Any, step: int, history: List[str]) -> str:
+    """
+    Build a minimal instruction for Wordle turns.
+    """
+    prompt = getattr(observation, "prompt", "") or ""
+    return f"You are playing Wordle. Based on previous feedback, choose a valid 5-letter word.\nContext:\n{prompt}\nReply with only the guess."
+
+
+def action_parser(response_text: str):
+    """
+    Convert model response to TextArenaAction (message).
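+    Example (illustrative): "CRANE is my guess" parses to TextArenaAction(message="crane").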
+ """ + try: + from envs.textarena_env import TextArenaAction # type: ignore + except Exception: + pytest.skip("OpenEnv (envs.textarena_env) is not installed; skipping TextArena test.") + raise + text = (response_text or "").strip() + # Keep only the first word-like token + guess = re.split(r"[^A-Za-z]+", text)[0] if text else "" + guess = guess[:5] if guess else "crane" + return TextArenaAction(message=guess.lower()) + + +def _score_from_rewards(row: EvaluationRow) -> float: + step_rewards: List[float] = [] + try: + for msg in row.messages or []: + if msg.role == "system" and isinstance(msg.content, str) and msg.content.startswith("__ep_step_rewards__:"): + import json as _json + payload = msg.content.split(":", 1)[1] + step_rewards = _json.loads(payload) or [] + break + except Exception: + step_rewards = [] + total = float(sum(step_rewards)) if step_rewards else 0.0 + # Clamp to [0,1] for dashboard score + return max(0.0, min(1.0, total)) + + +try: + from envs.textarena_env import TextArenaEnv # type: ignore + _HAS_TEXTARENA = True +except Exception: + _HAS_TEXTARENA = False + + +@evaluation_test( # type: ignore[misc] + input_dataset=["tests/pytest/data/wordle_dataset.jsonl"], + dataset_adapter=wordle_dataset_to_rows, + completion_params=[ + { + "temperature": 0.0, + "max_tokens": 8, + "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct", + } + ], + num_runs=1, + max_concurrent_rollouts=1, + mode="pointwise", + rollout_processor=( + OpenEnvRolloutProcessor( + env_client_cls=TextArenaEnv, # type: ignore + hub_repo_id=os.getenv("OPENENV_TEXTARENA_REPO", "burtenshaw/textarena"), + # Pass Wordle settings to the container + env_vars={ + "TEXTARENA_ENV_ID": os.getenv("TEXTARENA_ENV_ID", "Wordle-v0"), + "TEXTARENA_NUM_PLAYERS": os.getenv("TEXTARENA_NUM_PLAYERS", "1"), + }, + prompt_builder=prompt_builder, + action_parser=action_parser, + timeout_ms=10000, + num_generations=1, + ) + if _HAS_TEXTARENA + else None + ), +) +def test_openenv_textarena_wordle_hub(row: EvaluationRow) -> EvaluationRow: + """ + Smoke test for TextArena Wordle via HF Hub (registry.hf.space/burtenshaw-textarena). + Requires Docker available to start the Space container. + """ + if not _HAS_TEXTARENA: + pytest.skip("OpenEnv (envs.textarena_env) is not installed; skipping TextArena test.") + score = _score_from_rewards(row) + row.evaluation_result = EvaluateResult(score=score, reason=f"Wordle total score={score:.2f}") + return row + +
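+# The JSONL dataset referenced above is assumed to contain simple rows such as
+#   {"id": "wordle-1", "prompt": "start"}
+# (illustrative); only "prompt" is read here, and it merely seeds the conversation.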