337 changes: 280 additions & 57 deletions notebooks/tutorials/02_parallel_execution_on_ray.ipynb

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions src/orcapod/core/pods.py
@@ -10,6 +10,8 @@
ArrowPacket,
DictPacket,
)
from functools import wraps

from orcapod.utils.git_utils import get_git_info_for_python_object
from orcapod.core.kernels import KernelStream, TrackedKernelBase
from orcapod.core.operators import Join
@@ -252,9 +254,14 @@ def function_pod(
"""

def decorator(func: Callable) -> CallableWithPod:

if func.__name__ == "<lambda>":
raise ValueError("Lambda functions cannot be used with function_pod")

@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

# Store the original function in the module for pickling purposes
# and make sure to change the name of the function

@@ -267,9 +274,8 @@ def decorator(func: Callable) -> CallableWithPod:
label=label,
**kwargs,
)
setattr(func, "pod", pod)
return cast(CallableWithPod, func)

setattr(wrapper, "pod", pod)
return cast(CallableWithPod, wrapper)
return decorator


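Returning a @wraps-decorated wrapper (instead of mutating and returning func itself) preserves the original function's __name__, __doc__, and __module__, which matters for pickling and introspection, while still exposing the attached pod object. Below is a minimal standalone sketch of the same decorator pattern; a plain dict stands in for the real FunctionPod, and the output_keys/label arguments are illustrative rather than orcapod's exact signature.

from functools import wraps
from typing import Any, Callable


def function_pod(output_keys: list[str] | None = None, label: str | None = None):
    """Decorator sketch: wrap a plain function and attach a pod-like object to it."""

    def decorator(func: Callable) -> Callable:
        # Lambdas have no stable importable name, so they cannot be pickled by reference.
        if func.__name__ == "<lambda>":
            raise ValueError("Lambda functions cannot be used with function_pod")

        @wraps(func)  # keeps func.__name__, __doc__, __module__ on the wrapper
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        # Placeholder for the real FunctionPod construction; here just metadata.
        wrapper.pod = {"function": func, "output_keys": output_keys, "label": label}
        return wrapper

    return decorator


@function_pod(output_keys=["doubled"])
def double(x: int) -> int:
    """Double the input."""
    return 2 * x


print(double(3))                   # 6 -- the wrapper still calls the original function
print(double.__name__)             # 'double', preserved by functools.wraps
print(double.pod["output_keys"])   # ['doubled']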
1 change: 0 additions & 1 deletion src/orcapod/core/sources/base.py
@@ -1,5 +1,4 @@
from abc import abstractmethod
from ast import Not
from collections.abc import Collection, Iterator
from typing import TYPE_CHECKING, Any

4 changes: 4 additions & 0 deletions src/orcapod/core/streams/base.py
@@ -1,3 +1,4 @@
from calendar import c
import logging
from abc import abstractmethod
from collections.abc import Collection, Iterator, Mapping
@@ -475,6 +476,9 @@ def flow(

def _repr_html_(self) -> str:
df = self.as_polars_df()
# reorder columns
new_column_order = [c for c in df.columns if c in self.tag_keys()] + [c for c in df.columns if c not in self.tag_keys()]
df = df[new_column_order]
tag_map = {t: f"*{t}" for t in self.tag_keys()}
# TODO: construct repr html better
df = df.rename(tag_map)
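The _repr_html_ change moves tag columns to the front of the rendered table and marks them with a leading '*'. A small polars sketch of the same reordering, using made-up column names and a made-up tag key (the diff indexes with df[new_column_order]; select() below is equivalent):

import polars as pl

# Hypothetical stream table: "subject" is a tag column, the rest are packet columns.
df = pl.DataFrame(
    {"score": [0.1, 0.9], "subject": ["a", "b"], "path": ["x.txt", "y.txt"]}
)
tag_keys = ("subject",)

# Put tag columns first, keep all other columns in their original relative order.
new_column_order = [c for c in df.columns if c in tag_keys] + [
    c for c in df.columns if c not in tag_keys
]
df = df.select(new_column_order)

# Mark tag columns with a leading '*' in the rendered header.
df = df.rename({t: f"*{t}" for t in tag_keys})

html = df._repr_html_()  # what Jupyter would display
print(df)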
206 changes: 37 additions & 169 deletions src/orcapod/core/streams/pod_node_stream.py
@@ -57,70 +57,24 @@ def mode(self) -> str:

async def run_async(
self,
*args: Any,
execution_engine: cp.ExecutionEngine | None = None,
execution_engine_opts: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""
Runs the stream, processing the input stream and preparing the output stream.
This is typically called before iterating over the packets.
"""
if self._cached_output_packets is None:
cached_results = []

# identify all entries in the input stream for which we still have not computed packets
target_entries = self.input_stream.as_table(
include_content_hash=constants.INPUT_PACKET_HASH,
include_source=True,
include_system_tags=True,
)
existing_entries = self.pod_node.get_all_cached_outputs(
include_system_columns=True
cached_results, missing = self._identify_existing_and_missing_entries(*args,
execution_engine=execution_engine,
execution_engine_opts=execution_engine_opts,
**kwargs,
)
if existing_entries is None or existing_entries.num_rows == 0:
missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH])
existing = None
else:
all_results = target_entries.join(
existing_entries.append_column(
"_exists", pa.array([True] * len(existing_entries))
),
keys=[constants.INPUT_PACKET_HASH],
join_type="left outer",
right_suffix="_right",
)
# grab all columns from target_entries first
missing = (
all_results.filter(pc.is_null(pc.field("_exists")))
.select(target_entries.column_names)
.drop_columns([constants.INPUT_PACKET_HASH])
)

existing = all_results.filter(
pc.is_valid(pc.field("_exists"))
).drop_columns(
[
"_exists",
constants.INPUT_PACKET_HASH,
constants.PACKET_RECORD_ID,
*self.input_stream.keys()[1], # remove the input packet keys
]
# TODO: look into NOT fetching back the record ID
)

renamed = [
c.removesuffix("_right") if c.endswith("_right") else c
for c in existing.column_names
]
existing = existing.rename_columns(renamed)

tag_keys = self.input_stream.keys()[0]

if existing is not None and existing.num_rows > 0:
# If there are existing entries, we can cache them
existing_stream = TableStream(existing, tag_columns=tag_keys)
for tag, packet in existing_stream.iter_packets():
cached_results.append((tag, packet))

pending_calls = []
if missing is not None and missing.num_rows > 0:
for tag, packet in TableStream(missing, tag_columns=tag_keys):
@@ -134,23 +88,23 @@ async def run_async(
or self._execution_engine_opts,
)
pending_calls.append(pending)
import asyncio

import asyncio
completed_calls = await asyncio.gather(*pending_calls)
for result in completed_calls:
cached_results.append(result)

self.clear_cache()
self._cached_output_packets = cached_results
self._set_modified_time()
self.pod_node.flush()

def run(
self,
*args: Any,
def _identify_existing_and_missing_entries(self,
*args: Any,
execution_engine: cp.ExecutionEngine | None = None,
execution_engine_opts: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
cached_results = []
**kwargs: Any) -> tuple[list[tuple[cp.Tag, cp.Packet|None]], pa.Table | None]:
cached_results: list[tuple[cp.Tag, cp.Packet|None]] = []

# identify all entries in the input stream for which we still have not computed packets
if len(args) > 0 or len(kwargs) > 0:
@@ -223,6 +177,25 @@ def run(
for tag, packet in existing_stream.iter_packets():
cached_results.append((tag, packet))



return cached_results, missing

def run(
self,
*args: Any,
execution_engine: cp.ExecutionEngine | None = None,
execution_engine_opts: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
tag_keys = self.input_stream.keys()[0]
cached_results, missing = self._identify_existing_and_missing_entries(
*args,
execution_engine=execution_engine,
execution_engine_opts=execution_engine_opts,
**kwargs,
)

if missing is not None and missing.num_rows > 0:
packet_record_to_output_lut: dict[str, cp.Packet | None] = {}
execution_engine_hash = (
@@ -257,11 +230,14 @@
)
cached_results.append((tag, output_packet))


# reset the cache and set new results
self.clear_cache()
self._cached_output_packets = cached_results
self._set_modified_time()
self.pod_node.flush()
# TODO: evaluate proper handling of cache here
self.clear_cache()
# self.clear_cache()

def clear_cache(self) -> None:
self._cached_output_packets = None
@@ -300,115 +276,7 @@ def iter_packets(
self._cached_output_packets = cached_results
self._set_modified_time()

# if self._cached_output_packets is None:
# cached_results = []

# # identify all entries in the input stream for which we still have not computed packets
# target_entries = self.input_stream.as_table(
# include_system_tags=True,
# include_source=True,
# include_content_hash=constants.INPUT_PACKET_HASH,
# execution_engine=execution_engine,
# )
# existing_entries = self.pod_node.get_all_cached_outputs(
# include_system_columns=True
# )
# if existing_entries is None or existing_entries.num_rows == 0:
# missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH])
# existing = None
# else:
# # missing = target_entries.join(
# # existing_entries,
# # keys=[constants.INPUT_PACKET_HASH],
# # join_type="left anti",
# # )
# # Single join that gives you both missing and existing
# # More efficient - only bring the key column from existing_entries
# # .select([constants.INPUT_PACKET_HASH]).append_column(
# # "_exists", pa.array([True] * len(existing_entries))
# # ),

# # TODO: do more proper replacement operation
# target_df = pl.DataFrame(target_entries)
# existing_df = pl.DataFrame(
# existing_entries.append_column(
# "_exists", pa.array([True] * len(existing_entries))
# )
# )
# all_results_df = target_df.join(
# existing_df,
# on=constants.INPUT_PACKET_HASH,
# how="left",
# suffix="_right",
# )
# all_results = all_results_df.to_arrow()
# # all_results = target_entries.join(
# # existing_entries.append_column(
# # "_exists", pa.array([True] * len(existing_entries))
# # ),
# # keys=[constants.INPUT_PACKET_HASH],
# # join_type="left outer",
# # right_suffix="_right", # rename the existing records in case of collision of output packet keys with input packet keys
# # )
# # grab all columns from target_entries first
# missing = (
# all_results.filter(pc.is_null(pc.field("_exists")))
# .select(target_entries.column_names)
# .drop_columns([constants.INPUT_PACKET_HASH])
# )

# existing = all_results.filter(
# pc.is_valid(pc.field("_exists"))
# ).drop_columns(
# [
# "_exists",
# constants.INPUT_PACKET_HASH,
# constants.PACKET_RECORD_ID,
# *self.input_stream.keys()[1], # remove the input packet keys
# ]
# # TODO: look into NOT fetching back the record ID
# )
# renamed = [
# c.removesuffix("_right") if c.endswith("_right") else c
# for c in existing.column_names
# ]
# existing = existing.rename_columns(renamed)

# tag_keys = self.input_stream.keys()[0]

# if existing is not None and existing.num_rows > 0:
# # If there are existing entries, we can cache them
# existing_stream = TableStream(existing, tag_columns=tag_keys)
# for tag, packet in existing_stream.iter_packets():
# cached_results.append((tag, packet))
# yield tag, packet

# if missing is not None and missing.num_rows > 0:
# hash_to_output_lut: dict[str, cp.Packet | None] = {}
# for tag, packet in TableStream(missing, tag_columns=tag_keys):
# # Since these packets are known to be missing, skip the cache lookup
# packet_hash = packet.content_hash().to_string()
# if packet_hash in hash_to_output_lut:
# output_packet = hash_to_output_lut[packet_hash]
# else:
# tag, output_packet = self.pod_node.call(
# tag,
# packet,
# skip_cache_lookup=True,
# execution_engine=execution_engine,
# )
# hash_to_output_lut[packet_hash] = output_packet
# cached_results.append((tag, output_packet))
# if output_packet is not None:
# yield tag, output_packet

# self._cached_output_packets = cached_results
# self._set_modified_time()
# else:
# for tag, packet in self._cached_output_packets:
# if packet is not None:
# yield tag, packet


def keys(
self, include_system_tags: bool = False
) -> tuple[tuple[str, ...], tuple[str, ...]]:
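The run/run_async refactor factors the cache lookup into _identify_existing_and_missing_entries, so both paths first split the input into already-computed results and still-missing entries, and run_async then awaits the missing ones concurrently with asyncio.gather. A simplified, self-contained sketch of that split-then-gather pattern; the cache, the compute coroutine, and the (hash, payload) item type are stand-ins, not orcapod classes:

import asyncio

# Stand-in cache of previously computed outputs, keyed by a content hash.
CACHE: dict[str, str] = {"h1": "cached-result-1"}


def identify_existing_and_missing(items: list[tuple[str, str]]):
    """Split (hash, payload) items into cached results and entries still to compute."""
    cached: list[tuple[str, str]] = []
    missing: list[tuple[str, str]] = []
    for content_hash, payload in items:
        if content_hash in CACHE:
            cached.append((content_hash, CACHE[content_hash]))
        else:
            missing.append((content_hash, payload))
    return cached, missing


async def compute(content_hash: str, payload: str) -> tuple[str, str]:
    """Pretend pod call; in orcapod this would go through the execution engine."""
    await asyncio.sleep(0.01)
    result = f"computed({payload})"
    CACHE[content_hash] = result
    return content_hash, result


async def run_async(items: list[tuple[str, str]]) -> list[tuple[str, str]]:
    cached, missing = identify_existing_and_missing(items)
    # Dispatch every missing entry at once and await them all together.
    computed = await asyncio.gather(*(compute(h, p) for h, p in missing))
    return cached + list(computed)


results = asyncio.run(run_async([("h1", "a"), ("h2", "b"), ("h3", "c")]))
print(results)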
5 changes: 5 additions & 0 deletions src/orcapod/execution_engines/ray_execution_engine.py
@@ -26,6 +26,11 @@ class RayEngine:
3. No polling needed - Ray handles async integration
"""

@property
def supports_async(self) -> bool:
"""Indicate that this engine supports async execution."""
return True

def __init__(self, ray_address: str | None = None, **ray_init_kwargs):
"""Initialize Ray with native async support."""

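The new supports_async property lets callers ask an engine whether it can be awaited natively instead of polled. A hedged sketch of how a caller might branch on it; the submit/submit_async method names below are hypothetical illustrations, not RayEngine's actual interface:

import asyncio
from typing import Any


class ThreadedEngine:
    """Hypothetical synchronous engine: no native async support."""

    @property
    def supports_async(self) -> bool:
        return False

    def submit(self, fn, *args: Any) -> Any:  # hypothetical method name
        return fn(*args)


async def dispatch(engine, fn, *args: Any) -> Any:
    """Use native async submission when the engine advertises it; otherwise
    run the synchronous submit() in a worker thread."""
    if getattr(engine, "supports_async", False):
        # Hypothetical: an async-capable engine would expose an awaitable submit.
        return await engine.submit_async(fn, *args)
    return await asyncio.to_thread(engine.submit, fn, *args)


print(asyncio.run(dispatch(ThreadedEngine(), lambda x: x * 2, 21)))  # 42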