diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index 7d478401..aec9a14e 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.12"] + python-version: ["3.12.10"] steps: - uses: actions/checkout@v3 diff --git a/AIDojoCoordinator/coordinator.py b/AIDojoCoordinator/coordinator.py index 504059c3..96fbd3f4 100644 --- a/AIDojoCoordinator/coordinator.py +++ b/AIDojoCoordinator/coordinator.py @@ -255,6 +255,7 @@ async def _fetch_initialization_objects(self): self.logger.error(f"Error fetching initialization objects: {e}") # Temporary fix self.task_config = ConfigParser(self._task_config_file) + def _load_initialization_objects(self)->None: """ Loads task configuration from a local file. @@ -383,6 +384,9 @@ async def start_tasks(self): self.logger.info(f"Rewards set to:{self._rewards}") self._min_required_players = self.task_config.get_required_num_players() self.logger.info(f"Min player requirement set to:{self._min_required_players}") + # run self initialization + self._initialize() + # start server for agent communication self._spawn_task(self.start_tcp_server) @@ -484,6 +488,9 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No "configuration_hash": self._CONFIG_FILE_HASH }, } + if hasattr(self, "_registration_info"): + for key, value in self._registration_info.items(): + output_message_dict["message"][key] = value await self._agent_response_queues[agent_addr].put(self.convert_msg_dict_to_json(output_message_dict)) else: self.logger.info( @@ -750,6 +757,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState self._agent_status[agent_addr] = AgentStatus.Playing self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr) self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}") + # create initial observation return Observation(self._agent_states[agent_addr], 0, False, {}) async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState: @@ -804,6 +812,12 @@ async def step(self, agent_id:tuple, agent_state:GameState, action:Action): async def reset(self): return NotImplemented + def _initialize(self): + """ + Initialize the game state and other necessary components. This is called at the start of the game after the configuration is loaded. + """ + return NotImplemented + def goal_check(self, agent_addr:tuple)->bool: """ Check if the goal conditons were satisfied in a given game state diff --git a/AIDojoCoordinator/worlds/NSEGameCoordinator.py b/AIDojoCoordinator/worlds/NSEGameCoordinator.py index d95aa40f..954ad0b5 100644 --- a/AIDojoCoordinator/worlds/NSEGameCoordinator.py +++ b/AIDojoCoordinator/worlds/NSEGameCoordinator.py @@ -889,8 +889,6 @@ def update_log_file(self, known_data:set, action, target_host:IP): self._data[hostaname].add(Data(owner="system", id="logfile", type="log", size=len(new_content) , content= new_content)) async def register_agent(self, agent_id, agent_role, agent_initial_view)->GameState: - if len(self._networks) == 0: - self._initialize() game_state = self._create_state_from_view(agent_initial_view) return game_state diff --git a/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py b/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py new file mode 100644 index 00000000..852c1134 --- /dev/null +++ b/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py @@ -0,0 +1,185 @@ +import itertools +import argparse +import logging +import os +import json +from pathlib import Path +from AIDojoCoordinator.utils.utils import get_logging_level +from AIDojoCoordinator.game_components import Action, ActionType +from AIDojoCoordinator.worlds.NSEGameCoordinator import NSGCoordinator + + + + +class WhiteBoxNSGCoordinator(NSGCoordinator): + """ + WhiteBoxNSGCoordinator is an extension for the NetSecGame environment + that provides list of all possible actions to each agent that registers in the game. + """ + def __init__(self, game_host, game_port, task_config, allowed_roles=["Attacker", "Defender", "Benign"], seed=42, include_block_action=False): + super().__init__(game_host, game_port, task_config, allowed_roles, seed) + self._all_actions = None + self._include_block_action = include_block_action + + def _initialize(self): + # First do the parent initialization + super()._initialize() + # All components are initialized, now we can set the action mapping + self.logger.debug("Creating action mapping for the game.") + self._generate_all_actions() + self._registration_info = { + "all_actions": json.dumps([v.as_dict for v in self._all_actions]), + } + + + def _generate_all_actions(self)-> list: + """ + Generate a list of all possible actions for the game. + """ + actions = [] + all_ips = [self._ip_mapping[ip] for ip in self._ip_to_hostname.keys()] + all_networks = self._networks.keys() + all_data = set() + ip_with_services = {} + for ip in all_ips: + if ip in self._ip_to_hostname: + hostname = self._ip_to_hostname[ip] + if hostname in self._services: + ip_with_services[ip] = self._services[hostname] + + # Collect all data from all hosts + for data in self._data.values(): + all_data.update(data) + + # Network Scans + for source_host, target_network in itertools.product(all_ips, all_networks): + actions.append(Action( + ActionType.ScanNetwork, + parameters={ + "source_host": source_host, + "target_network": target_network + } + )) + + # Service Scans + for source_host, target_host in itertools.product(all_ips, all_ips): + actions.append(Action( + ActionType.FindServices, + parameters={ + "source_host": source_host, + "target_host": target_host + } + )) + # Service Exploits + for source_host, target_host in itertools.product(all_ips, ip_with_services.keys()): + for service in ip_with_services[target_host]: + actions.append(Action( + ActionType.ExploitService, + parameters={ + "source_host": source_host, + "target_host": target_host, + "target_service": service + } + )) + # Data Scans + for source_host, target_host in itertools.product(all_ips, all_ips): + actions.append(Action( + ActionType.FindData, + parameters={ + "source_host": source_host, + "target_host": target_host + } + )) + # Data transfers + for (source_host, target_host), datum in itertools.product(itertools.product(all_ips, all_ips), all_data): + actions.append(Action( + ActionType.ExfiltrateData, + parameters={ + "source_host": source_host, + "target_host": target_host, + "data": datum + } + )) + # Blocks + if self._include_block_action: + for (source_host, target_host), blocked_ip in itertools.product(itertools.product(all_ips, all_ips), all_ips): + actions.append(Action( + ActionType.BlockIP, + parameters={ + "source_host": source_host, + "target_host": target_host, + "blocked_ip": blocked_ip + } + )) + self.logger.info(f"Created action mapping with {len(actions)} actions.") + for action in actions: + self.logger.debug(action) + self._all_actions = actions + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NetSecGame Coordinator Server Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz", + usage="%(prog)s [options]", + ) + + parser.add_argument( + "-l", + "--debug_level", + help="Define the debug level for the logs. DEBUG, INFO, WARNING, ERROR, CRITICAL", + action="store", + required=False, + type=str, + default="INFO", + ) + + parser.add_argument( + "-gh", + "--game_host", + help="host where to run the game server", + action="store", + required=False, + type=str, + default="127.0.0.1", + ) + + parser.add_argument( + "-gp", + "--game_port", + help="Port where to run the game server", + action="store", + required=False, + type=int, + default="9000", + ) + + parser.add_argument( + "-c", + "--task_config", + help="File with the task configuration", + action="store", + required=True, + type=str, + default="netsecenv_conf.yaml", + ) + + args = parser.parse_args() + print(args) + # Set the logging + log_filename = Path("logs/WhiteBox_NSG_coordinator.log") + if not log_filename.parent.exists(): + os.makedirs(log_filename.parent) + + # Convert the logging level in the args to the level to use + pass_level = get_logging_level(args.debug_level) + + logging.basicConfig( + filename=log_filename, + filemode="w", + format="%(asctime)s %(name)s %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=pass_level, + ) + + game_server = WhiteBoxNSGCoordinator(args.game_host, args.game_port, args.task_config) + # Run it! + game_server.run() \ No newline at end of file