diff --git a/src/adf_core_python/core/agent/info/world_info.py b/src/adf_core_python/core/agent/info/world_info.py index 0fce3a7..cc74d83 100644 --- a/src/adf_core_python/core/agent/info/world_info.py +++ b/src/adf_core_python/core/agent/info/world_info.py @@ -5,6 +5,7 @@ from rcrscore.entities.blockade import Blockade from rcrscore.entities.entity import Entity from rcrscore.entities.human import Human +from rcrscore.urn import EntityURN from rcrscore.worldmodel import ChangeSet, WorldModel @@ -54,6 +55,28 @@ def get_entity(self, entity_id: EntityID) -> Optional[Entity]: """ return self._world_model.get_entity(entity_id) + def get_entities(self) -> list[Entity]: + """ + Get all entities + + Returns + ------- + list[Entity] + Entities + """ + return self._world_model.get_entities() + + def get_entity_ids(self) -> list[EntityID]: + """ + Get all entity IDs + + Returns + ------- + list[EntityID] + Entity IDs + """ + return [entity.get_entity_id() for entity in self._world_model.get_entities()] + def get_entity_ids_of_types(self, entity_types: list[type[Entity]]) -> list[EntityID]: """ Get the entity IDs of the specified types @@ -96,6 +119,45 @@ def get_entities_of_types(self, entity_types: list[type[Entity]]) -> list[Entity return entities + def get_entity_ids_of_urns(self, urns: list[EntityURN]) -> list[EntityID]: + """ + Get the entity IDs of the specified URNs + + Parameters + ---------- + urns : list[EntityURN] + List of entity URNs + + Returns + ------- + list[EntityID] + Entity IDs + """ + entity_ids: list[EntityID] = [] + for entity in self._world_model.get_entities(): + if entity.get_urn() in urns: + entity_ids.append(entity.get_entity_id()) + + return entity_ids + + def get_entities_of_urns(self, urns: list[EntityURN]) -> list[Entity]: + """ + Get the entities of the specified URNs + Parameters + ---------- + urns : list[EntityURN] + List of entity URNs + Returns + ------- + list[Entity] + Entities + """ + entities: list[Entity] = [] + for entity in self._world_model.get_entities(): + if entity.get_urn() in urns: + entities.append(entity) + return entities + def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float: """ Get the distance between two entities