def insert(
    self,
    event: "Event",
    position: Optional[int] = None,
    before_event_id: Optional[str] = None,
    after_event_id: Optional[str] = None,
    timestamp: Optional[timedelta] = None,
    scoring_function: Optional[
        Callable[["Event", "EventDataset"], float]
    ] = None,
):
    """Insert an event into the dataset and relink the adjacent records.

    The insertion index is resolved from the first positioning argument
    that is not None, checked in this order: `position`,
    `before_event_id`, `after_event_id`, `timestamp`, `scoring_function`.
    Arguments later in that order are ignored.

    Args:
        event (Event): The event to insert.
        position (Optional[int]): Index at which to insert. Must satisfy
            `-len(records) <= position <= len(records)`; a negative value
            counts from the end, landing where `list.insert` would put it.
        before_event_id (Optional[str]): Insert directly before the first
            record whose `event_id` equals this value.
        after_event_id (Optional[str]): Insert directly after the first
            record whose `event_id` equals this value.
        timestamp (Optional[timedelta]): Insert before the first record
            whose `timestamp` is strictly greater than this value, or at
            the end when there is none. Only `timestamp` is compared.
        scoring_function (Optional[Callable[[Event, EventDataset], float]]):
            Called once per existing record as
            `scoring_function(record, dataset)`. The record with the
            largest absolute score is the anchor (the earliest one wins a
            tie): a positive score inserts after it, a negative score
            inserts before it. With no records at all the event is
            inserted at index 0.

    Raises:
        ValueError: If `position` is out of range, if no record has the
            requested `before_event_id`/`after_event_id`, if the best
            score returned by `scoring_function` is zero, or if no
            positioning argument was given. The dataset is left unchanged
            in all of these cases.
    """
    record_count = len(self.records)

    if position is not None:
        # Resolve negative positions up front so that the index used for
        # relinking below is the index the event really ends up at.
        insert_position = (
            position + record_count if position < 0 else position
        )
        if not 0 <= insert_position <= record_count:
            raise ValueError(
                f"Position {position} is out of range for a dataset "
                f"with {record_count} events."
            )

    elif before_event_id is not None:
        insert_position = next(
            (
                index
                for index, record in enumerate(self.records)
                if record.event_id == before_event_id
            ),
            None,
        )
        if insert_position is None:
            raise ValueError(f"No event found with ID {before_event_id}.")

    elif after_event_id is not None:
        insert_position = next(
            (
                index + 1
                for index, record in enumerate(self.records)
                if record.event_id == after_event_id
            ),
            None,
        )
        if insert_position is None:
            raise ValueError(f"No event found with ID {after_event_id}.")

    elif timestamp is not None:
        # NOTE(review): this relies on the records already being ordered
        # by `timestamp`; nothing here checks that - confirm with callers.
        insert_position = next(
            (
                index
                for index, record in enumerate(self.records)
                if record.timestamp > timestamp
            ),
            record_count,
        )

    elif scoring_function is not None:
        scores = [
            (index, scoring_function(candidate, self))
            for index, candidate in enumerate(self.records)
        ]
        # The magnitude of the score selects the anchor record; the
        # default only applies to an empty dataset and yields index 0.
        best_index, best_score = max(
            scores, key=lambda item: abs(item[1]), default=(0, -1)
        )
        if best_score == 0:
            raise ValueError(
                "No valid insertion position found based on the provided scoring function."
            )

        # The sign of the score selects the side of the anchor record.
        insert_position = best_index + 1 if best_score > 0 else best_index

    else:
        raise ValueError(
            "Unable to determine insertion position for the event."
        )

    self.records.insert(insert_position, event)
    event.dataset = self

    # Only the new record and its direct neighbours have changed links.
    total = len(self.records)
    first = max(0, insert_position - 1)
    last = min(insert_position + 2, total)
    for index in range(first, last):
        record = self.records[index]
        record.prev_record = self.records[index - 1] if index > 0 else None
        record.next_record = (
            self.records[index + 1] if index + 1 < total else None
        )
def test_insert(self, dataset: EventDataset):
    """Check every positioning strategy of `EventDataset.insert`.

    The inserted event is removed again after each step so that it is
    never present in the dataset more than once and the hard-coded
    indices below keep referring to the original records.
    """
    inserted_id = "test-insert-1234"
    team = dataset.metadata.teams[0]
    new_event = EventFactory().build_carry(
        qualifiers=None,
        timestamp=timedelta(seconds=700),
        end_timestamp=timedelta(seconds=701),
        result=CarryResult.COMPLETE,
        period=dataset.metadata.periods[0],
        ball_owning_team=team,
        ball_state=BallState.ALIVE,
        event_id=inserted_id,
        team=team,
        player=team.players[0],
        coordinates=(0.2, 0.3),
        end_coordinates=(0.22, 0.33),
        raw_event=None,
    )

    # insert by position
    dataset.insert(new_event, position=3)
    assert dataset.events[3].event_id == inserted_id
    del dataset.events[3]  # Remove by index to restore the dataset

    # insert by before_event_id
    dataset.insert(new_event, before_event_id=dataset.events[100].event_id)
    assert dataset.events[100].event_id == inserted_id
    del dataset.events[100]  # Remove by index to restore the dataset

    # insert by after_event_id
    dataset.insert(new_event, after_event_id=dataset.events[305].event_id)
    assert dataset.events[306].event_id == inserted_id
    del dataset.events[306]  # Remove by index to restore the dataset

    # insert by timestamp
    dataset.insert(new_event, timestamp=new_event.timestamp)
    assert dataset.events[609].event_id == inserted_id
    del dataset.events[609]  # Remove by index to restore the dataset

    # insert using a scoring function: positive scores insert after
    def insert_after_scoring_function(event: Event, dataset: EventDataset):
        if event.ball_owning_team != dataset.metadata.teams[0]:
            return 0
        if event.period != new_event.period:
            return 0
        return 1 / abs(
            event.timestamp.total_seconds()
            - new_event.timestamp.total_seconds()
        )

    dataset.insert(
        new_event, scoring_function=insert_after_scoring_function
    )
    assert dataset.events[608].event_id == inserted_id
    del dataset.events[608]  # Remove by index to restore the dataset

    # insert using a scoring function: negative scores insert before
    def insert_before_scoring_function(
        event: Event, dataset: EventDataset
    ):
        # Same ranking as above with the sign flipped.
        return -insert_after_scoring_function(event, dataset)

    dataset.insert(
        new_event, scoring_function=insert_before_scoring_function
    )
    assert dataset.events[607].event_id == inserted_id
    del dataset.events[607]  # Remove by index to restore the dataset

    # a scoring function that never matches must be rejected
    def no_match_scoring_function(event: Event, dataset: EventDataset):
        return 0

    with pytest.raises(ValueError):
        dataset.insert(
            new_event, scoring_function=no_match_scoring_function
        )

    # update references: insertion in the middle
    dataset.insert(new_event, position=1)
    assert dataset.events[0].next_record.event_id == inserted_id
    assert (
        dataset.events[1].prev_record.event_id
        == dataset.events[0].event_id
    )
    assert dataset.events[1].event_id == inserted_id
    assert (
        dataset.events[1].next_record.event_id
        == dataset.events[2].event_id
    )
    assert dataset.events[2].prev_record.event_id == inserted_id
    del dataset.events[1]  # keep the event from appearing twice below

    # update references: insertion at the start
    dataset.insert(new_event, position=0)
    assert dataset.events[0].prev_record is None
    assert dataset.events[0].event_id == inserted_id
    assert (
        dataset.events[0].next_record.event_id
        == dataset.events[1].event_id
    )
    assert dataset.events[1].prev_record.event_id == inserted_id
    del dataset.events[0]  # keep the event from appearing twice below

    # update references: insertion at the end
    dataset.insert(new_event, position=len(dataset))
    assert dataset.events[-2].next_record.event_id == inserted_id
    assert (
        dataset.events[-1].prev_record.event_id
        == dataset.events[-2].event_id
    )
    assert dataset.events[-1].event_id == inserted_id
    assert dataset.events[-1].next_record is None
    # NOTE(review): `del` presumably does not relink the neighbours'
    # prev/next references - confirm if the fixture is shared.
    del dataset.events[-1]