Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/adf_core_python/core/agent/info/world_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from rcrscore.entities.blockade import Blockade
from rcrscore.entities.entity import Entity
from rcrscore.entities.human import Human
from rcrscore.urn import EntityURN
from rcrscore.worldmodel import ChangeSet, WorldModel


Expand Down Expand Up @@ -54,6 +55,28 @@ def get_entity(self, entity_id: EntityID) -> Optional[Entity]:
"""
return self._world_model.get_entity(entity_id)

def get_entities(self) -> list[Entity]:
"""
Get all entities

Returns
-------
list[Entity]
Entities
"""
return self._world_model.get_entities()

def get_entity_ids(self) -> list[EntityID]:
"""
Get all entity IDs

Returns
-------
list[EntityID]
Entity IDs
"""
return [entity.get_entity_id() for entity in self._world_model.get_entities()]

def get_entity_ids_of_types(self, entity_types: list[type[Entity]]) -> list[EntityID]:
"""
Get the entity IDs of the specified types
Expand Down Expand Up @@ -96,6 +119,45 @@ def get_entities_of_types(self, entity_types: list[type[Entity]]) -> list[Entity

return entities

def get_entity_ids_of_urns(self, urns: list[EntityURN]) -> list[EntityID]:
"""
Get the entity IDs of the specified URNs

Parameters
----------
urns : list[EntityURN]
List of entity URNs

Returns
-------
list[EntityID]
Entity IDs
"""
entity_ids: list[EntityID] = []
for entity in self._world_model.get_entities():
if entity.get_urn() in urns:
entity_ids.append(entity.get_entity_id())

return entity_ids

def get_entities_of_urns(self, urns: list[EntityURN]) -> list[Entity]:
"""
Get the entities of the specified URNs
Parameters
----------
urns : list[EntityURN]
List of entity URNs
Returns
-------
list[Entity]
Entities
"""
entities: list[Entity] = []
for entity in self._world_model.get_entities():
if entity.get_urn() in urns:
entities.append(entity)
return entities

def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float:
"""
Get the distance between two entities
Expand Down
Loading