Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion AIDojoCoordinator/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por
self._agent_steps = {}
# reset request per agent_addr (bool)
self._reset_requests = {}
self._randomize_topology_requests = {}
self._agent_status = {}
self._episode_ends = {}
self._agent_observations = {}
Expand Down Expand Up @@ -547,8 +548,10 @@ async def _process_reset_game_action(self, agent_addr: tuple, reset_action:Actio
"""
self.logger.debug("Beginning the _process_reset_game_action.")
async with self._reset_lock:
# add reset request for this agent
# add reset request for this agent
self._reset_requests[agent_addr] = True
# register if the agent wants to randomize the topology
self._randomize_topology_requests[agent_addr] = reset_action.parameters.get("randomize_topology", False)
if all(self._reset_requests.values()):
# all agents want reset - reset the world
self.logger.debug(f"All agents requested reset, setting the event")
Expand Down Expand Up @@ -724,6 +727,7 @@ async def _reset_game(self):
self._agent_observations[agent] = new_observation
self._episode_ends[agent] = False
self._reset_requests[agent] = False
self._randomize_topology_requests[agent] = False
self._agent_rewards[agent] = 0
self._agent_steps[agent] = 0
self._agent_false_positives[agent] = 0
Expand Down Expand Up @@ -788,6 +792,9 @@ async def _remove_agent_from_game(self, agent_addr):
agent_info["agent_status"] = self._agent_status.pop(agent_addr)
agent_info["false_positives"] = self._agent_false_positives.pop(agent_addr)
async with self._reset_lock:
# remove agent from topology reset requests
agent_info["topology_reset_request"] = self._randomize_topology_requests.pop(agent_addr, False)
# remove agent from reset requests
agent_info["reset_request"] = self._reset_requests.pop(agent_addr)
# check if this agent was not preventing reset
if any(self._reset_requests.values()):
Expand Down
9 changes: 7 additions & 2 deletions AIDojoCoordinator/game_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ def as_dict(self) -> Dict[str, Any]:
for k, v in self.parameters.items():
if hasattr(v, '__dict__'): # Handle custom objects like Service, Data, AgentInfo
params[k] = asdict(v)
elif isinstance(v, bool): # Handle boolean values
params[k] = v
else:
params[k] = str(v)
return {"action_type": str(self.action_type), "parameters": params}
Expand Down Expand Up @@ -448,8 +450,11 @@ def from_dict(cls, data_dict: Dict[str, Any]) -> "Action":
params[k] = Data.from_dict(v)
case "agent_info":
params[k] = AgentInfo.from_dict(v)
case "request_trajectory":
params[k] = ast.literal_eval(v)
case "request_trajectory" | "randomize_topology":
if isinstance(v, bool):
params[k] = v
else:
params[k] = ast.literal_eval(v)
case _:
raise ValueError(f"Unsupported value in {k}: {v}")
return cls(action_type=action_type, parameters=params)
Expand Down
2 changes: 1 addition & 1 deletion AIDojoCoordinator/netsecenv_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ env:
# random_seed: 42
scenario: 'scenario1'
use_global_defender: False
use_dynamic_addresses: False
use_dynamic_addresses: True
use_firewall: True
save_trajectories: False
required_players: 1
Expand Down
54 changes: 48 additions & 6 deletions AIDojoCoordinator/worlds/NSEGameCoordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class NSGCoordinator(GameCoordinator):

def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attacker", "Defender", "Benign"], seed=42):
def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attacker", "Defender", "Benign"], seed=None):
super().__init__(game_host, game_port, service_host=None, service_port=None, allowed_roles=allowed_roles, task_config_file=task_config)

# Internal data structure of the NSG
Expand All @@ -44,7 +44,17 @@ def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attack
self._seed = seed
self.logger.info(f'Setting env seed to {seed}')

def _initialize(self)->None:
def _initialize(self) -> None:
"""
Initializes the NetSecGame environment.

Loads the CYST configuration, sets up dynamic IP and network address generation if enabled,
and stores original copies of environment data structures for later resets. Also seeds the
random number generator for reproducibility and logs the completion of initialization.

Returns:
None
"""
# Load CYST configuration
self._process_cyst_config(self._cyst_objects)
# Check if dynamic network and ip adddresses are required
Expand Down Expand Up @@ -84,7 +94,16 @@ def _get_controlled_hosts_from_view(self, view_controlled_hosts:Iterable)->set:
return controlled_hosts

def _get_services_from_view(self, view_known_services:dict)->dict:
known_services ={}
"""
Parses view and translates all keywords. Produces dict of known services {IP: set(Service)}

Args:
view_known_services (dict): The view containing known services information.

Returns:
dict: A dictionary mapping IP addresses to sets of known services.
"""
known_services = {}
for ip, service_list in view_known_services.items():
if self._ip_mapping[ip] not in known_services:
known_services[self._ip_mapping[ip]] = set()
Expand All @@ -101,6 +120,15 @@ def _get_services_from_view(self, view_known_services:dict)->dict:
return known_services

def _get_data_from_view(self, view_known_data:dict)->dict:
"""
Parses view and translates all keywords. Produces dict of known data {IP: set(Data)}

Args:
view_known_data (dict): The view containing known data information.

Returns:
dict: A dictionary mapping IP addresses to sets of known data.
"""
known_data = {}
for ip, data_list in view_known_data.items():
if self._ip_mapping[ip] not in known_data:
Expand Down Expand Up @@ -920,7 +948,11 @@ async def reset(self)->bool:
self.logger.info('--- Reseting NSG Environment to its initial state ---')
# change IPs if needed
if self.task_config.get_use_dynamic_addresses():
self._create_new_network_mapping()
if all(self._randomize_topology_requests.values()):
self.logger.info("All agents requested reset with randomized topology.")
self._create_new_network_mapping()
else:
self.logger.info("Not all agents requested a topology randomization. Keeping the current one.")
# reset self._data to orignal state
self._data = copy.deepcopy(self._data_original)
# reset self._data_content to orignal state
Expand Down Expand Up @@ -977,6 +1009,16 @@ async def reset(self)->bool:
default="netsecenv_conf.yaml",
)

parser.add_argument(
"-s",
"--seed",
help="Random seed for the environment",
action="store",
required=False,
type=int,
default=42,
)

args = parser.parse_args()
print(args)
# Set the logging
Expand All @@ -994,7 +1036,7 @@ async def reset(self)->bool:
datefmt="%Y-%m-%d %H:%M:%S",
level=pass_level,
)
game_server = NSGCoordinator(args.game_host, args.game_port, args.task_config)

game_server = NSGCoordinator(args.game_host, args.game_port, args.task_config, seed=args.seed)
# Run it!
game_server.run()
13 changes: 12 additions & 1 deletion tests/components/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,18 @@ def test_action_to_dict_reset_game(self):
assert action == new_action
assert action_dict["action_type"] == str(action.type)
assert len(action_dict["parameters"]) == 0

action = Action(
action_type=ActionType.ResetGame,
parameters={"request_trajectory": True, "randomize_topology": False}
)
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
assert action_dict["action_type"] == str(action.type)
assert len(action_dict["parameters"]) == 2
assert action_dict["parameters"]["request_trajectory"] is True
assert action_dict["parameters"]["randomize_topology"] is False

def test_action_to_dict_quit_game(self):
action = Action(
action_type=ActionType.QuitGame,
Expand Down