Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -21,11 +21,12 @@
import signal
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional
from unittest.mock import Mock

import toml
import yaml
from pydantic import ValidationError

from cloudai.core import (
BaseInstaller,
Expand Down Expand Up @@ -145,7 +146,24 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
continue

env = CloudAIGymEnv(test_run=test_run, runner=runner.runner)
agent = agent_class(env)

try:
agent_overrides = (
validate_agent_overrides(agent_type, test_run.test.agent_config)
if test_run.test.agent_config is not None
else None
)
except ValidationError as e:
logging.error(f"Invalid agent_config for agent '{agent_type}': ")
for error in e.errors():
logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}")
logging.error("Valid overrides: ")
for item, desc in validate_agent_overrides(agent_type).items():
logging.error(f" - {item}: {desc}")
err = 1
continue
agent = agent_class(env, **agent_overrides) if agent_overrides is not None else agent_class(env)

for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
Expand All @@ -166,6 +184,37 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
return err


def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]] = None) -> dict[str, Any]:
"""
Validate and process agent configuration overrides.

If agent_config is empty, returns the available configuration fields for the agent type.
"""
registry = Registry()
config_class_map = {}
for agent_name, agent_class in registry.agents_map.items():
if agent_class.config:
config_class_map[agent_name] = agent_class.config

config_class = config_class_map.get(agent_type)
if not config_class:
valid_types = ", ".join(f"'{agent_name}'" for agent_name in config_class_map)
raise ValueError(
f"Agent type '{agent_type}' does not support configuration overrides. "
f"Valid agent types are: {valid_types}. "
)

if agent_config:
validated_config = config_class.model_validate(agent_config)
agent_kwargs = validated_config.model_dump(exclude_none=True)
logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}")
else:
agent_kwargs = {}
for field_name, field_info in config_class.model_fields.items():
agent_kwargs[field_name] = field_info.description
return agent_kwargs


def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None:
registry = Registry()

Expand Down
8 changes: 6 additions & 2 deletions src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -15,7 +15,9 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple

from cloudai.models.agent_config import AgentConfig

from .base_gym import BaseGym

Expand All @@ -28,6 +30,8 @@ class BaseAgent(ABC):
Automatically infers parameter types from TestRun's cmd_args.
"""

config: Optional[AgentConfig] = None

def __init__(self, env: BaseGym):
"""
Initialize the agent with the environment.
Expand Down
27 changes: 27 additions & 0 deletions src/cloudai/models/agent_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class AgentConfig(BaseModel, ABC):
"""Base configuration for agent overrides."""

model_config = ConfigDict(extra="forbid")
random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
3 changes: 2 additions & 1 deletion src/cloudai/models/workload.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -107,6 +107,7 @@ class TestDefinition(BaseModel, ABC):
agent_steps: int = 1
agent_metrics: list[str] = Field(default=["default"])
agent_reward_function: str = "inverse"
agent_config: Optional[dict[str, Any]] = None

@property
def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]:
Expand Down