Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 187 additions & 99 deletions synth/validator/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import typing
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing import shared_memory
import time

# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
Expand All @@ -37,11 +38,14 @@

# Module level - must be picklable
def _crps_worker(args):
"""Standalone worker - no database, no complex objects"""
"""Standalone worker - no database, no complex objects.
Uses shared memory for real_prices to reduce memory duplication across processes.
"""
(
miner_uid,
prediction_array,
real_prices,
shm_name,
prices_shape,
time_increment,
scoring_intervals,
format_validation,
Expand Down Expand Up @@ -72,7 +76,7 @@ def _crps_worker(args):
process_time,
)

if len(real_prices) == 0:
if prices_shape[0] == 0:
return (
miner_uid,
-1,
Expand All @@ -83,48 +87,57 @@ def _crps_worker(args):
process_time,
)

prediction_array = adjust_predictions(list(prediction_array))

# Attach to shared memory for real_prices
existing_shm = shared_memory.SharedMemory(name=shm_name)
try:
simulation_runs = np.array(prediction_array).astype(float)
score, detailed_crps_data = calculate_crps_for_miner(
simulation_runs,
np.array(real_prices),
int(time_increment),
scoring_intervals,
real_prices = np.ndarray(
prices_shape, dtype=np.float64, buffer=existing_shm.buf
)

if np.isnan(score):
prediction_array = adjust_predictions(list(prediction_array))

try:
simulation_runs = np.array(prediction_array).astype(float)
score, detailed_crps_data = calculate_crps_for_miner(
simulation_runs,
real_prices, # Already a numpy array from shared memory
int(time_increment),
scoring_intervals,
)

if np.isnan(score):
return (
miner_uid,
-1,
detailed_crps_data,
f"Error calculating CRPS for miner {miner_uid}",
format_validation,
prediction_id,
process_time,
)

return (
miner_uid,
-1,
score,
detailed_crps_data,
f"Error calculating CRPS for miner {miner_uid}",
None,
format_validation,
prediction_id,
process_time,
)

return (
miner_uid,
score,
detailed_crps_data,
None,
format_validation,
prediction_id,
process_time,
)

except Exception as e:
return (
miner_uid,
-1,
[],
str(e),
format_validation,
prediction_id,
process_time,
)
except Exception as e:
return (
miner_uid,
-1,
[],
str(e),
format_validation,
prediction_id,
process_time,
)
finally:
existing_shm.close()


# Global executor - create once
Expand All @@ -139,6 +152,103 @@ def get_process_executor(nprocs: int = 2) -> ProcessPoolExecutor:
return _PROCESS_EXECUTOR


def _get_scoring_intervals(validator_request: ValidatorRequest) -> dict:
"""Determine scoring intervals based on time length."""
is_high_freq = (
validator_request.time_length
== prompt_config.HIGH_FREQUENCY.time_length
)
if is_high_freq:
return prompt_config.HIGH_FREQUENCY.scoring_intervals
return prompt_config.LOW_FREQUENCY.scoring_intervals


def _prepare_work_items(
predictions: list[MinerPrediction],
shm_name: str,
prices_shape: tuple,
validator_request: ValidatorRequest,
scoring_intervals: dict,
) -> list[tuple]:
"""Prepare picklable work items for multiprocess CRPS calculation."""
work_items = []

for pred in predictions:
# Convert to picklable types
format_val = pred.format_validation
# Convert enum to string if needed
if hasattr(format_val, "value"):
format_val = format_val.value
elif format_val == response_validation_v2.CORRECT:
format_val = "CORRECT"
else:
format_val = str(format_val)

work_items.append(
(
pred.miner_uid,
list(pred.prediction),
shm_name,
prices_shape,
int(validator_request.time_increment),
scoring_intervals,
format_val,
int(pred.id),
(
float(pred.process_time)
if pred.process_time is not None
else 0.0
),
)
)

return work_items


def _build_detailed_info(
predictions: list[MinerPrediction],
scores: list,
detailed_crps_data_list: list,
prompt_scores: np.ndarray,
miner_prediction_format_list: list,
miner_prediction_id_list: list,
miner_prediction_process_time: list,
percentile90: float,
lowest_score: float,
) -> list[dict]:
"""Build detailed information dict from processing results."""
return [
{
"miner_uid": pred.miner_uid,
"prompt_score_v3": float(prompt_score),
"percentile90": float(percentile90),
"lowest_score": float(lowest_score),
"miner_prediction_id": prediction_id,
"format_validation": format,
"process_time": process_time,
"total_crps": float(score),
"crps_data": clean_numpy_in_crps_data(crps_data),
}
for (
pred,
score,
crps_data,
prompt_score,
format,
prediction_id,
process_time,
) in zip(
predictions,
scores,
detailed_crps_data_list,
prompt_scores,
miner_prediction_format_list,
miner_prediction_id_list,
miner_prediction_process_time,
)
]


@print_execution_time
def get_rewards_multiprocess(
miner_data_handler: MinerDataHandler,
Expand All @@ -148,6 +258,7 @@ def get_rewards_multiprocess(
) -> tuple[typing.Optional[np.ndarray], list, list[dict]]:
"""
Returns an array of rewards for the given query and responses.
Uses shared memory for real_prices to reduce memory duplication across worker processes.

Args:
- query (int): The query sent to the miner.
Expand All @@ -173,50 +284,35 @@ def get_rewards_multiprocess(
int(validator_request.id)
)

# Prepare picklable work items
scoring_intervals = (
prompt_config.HIGH_FREQUENCY.scoring_intervals
if validator_request.time_length
== prompt_config.HIGH_FREQUENCY.time_length
else prompt_config.LOW_FREQUENCY.scoring_intervals
# Create shared memory for real_prices to avoid duplicating across workers
prices_array = np.array(real_prices, dtype=np.float64)
shm = shared_memory.SharedMemory(create=True, size=prices_array.nbytes)
shared_prices = np.ndarray(
prices_array.shape, dtype=np.float64, buffer=shm.buf
)
shared_prices[:] = prices_array[:]

# Prepare work items
scoring_intervals = _get_scoring_intervals(validator_request)
work_items = _prepare_work_items(
predictions,
shm.name,
prices_array.shape,
validator_request,
scoring_intervals,
)

work_items = []

for pred in predictions:
# Convert to picklable types
format_val = pred.format_validation
# Convert enum to string if needed
if hasattr(format_val, "value"):
format_val = format_val.value
elif format_val == response_validation_v2.CORRECT:
format_val = "CORRECT"
else:
format_val = str(format_val)

work_items.append(
(
pred.miner_uid,
list(pred.prediction),
real_prices,
int(validator_request.time_increment),
scoring_intervals,
format_val,
int(pred.id),
(
float(pred.process_time)
if pred.process_time is not None
else 0.0
),
)
)

# Process in parallel (CPU bound - use ProcessPool)
bt.logging.info(f"Starting CRPS calculation for {len(work_items)} miners")
t0 = time.time()

executor = get_process_executor(nprocs)
results = list(executor.map(_crps_worker, work_items))
try:
executor = get_process_executor(nprocs)
results = list(executor.map(_crps_worker, work_items))
finally:
# Clean up shared memory
shm.close()
shm.unlink()

bt.logging.info(f"CRPS done in {time.time() - t0:.2f}s")

Expand All @@ -227,7 +323,6 @@ def get_rewards_multiprocess(
miner_prediction_id_list = []
miner_prediction_process_time = []

# Create lookup for original prediction objects
for (
miner_uid,
score,
Expand All @@ -254,28 +349,17 @@ def get_rewards_multiprocess(
if prompt_scores is None:
return None, [], []

detailed_info = [
{
"miner_uid": pred.miner_uid,
"prompt_score_v3": float(prompt_score),
"percentile90": float(percentile90),
"lowest_score": float(lowest_score),
"miner_prediction_id": prediction_id,
"format_validation": format,
"process_time": process_time,
"total_crps": float(score),
"crps_data": clean_numpy_in_crps_data(crps_data),
}
for pred, score, crps_data, prompt_score, format, prediction_id, process_time in zip(
predictions,
scores,
detailed_crps_data_list,
prompt_scores,
miner_prediction_format_list,
miner_prediction_id_list,
miner_prediction_process_time,
)
]
detailed_info = _build_detailed_info(
predictions,
scores,
detailed_crps_data_list,
prompt_scores,
miner_prediction_format_list,
miner_prediction_id_list,
miner_prediction_process_time,
percentile90,
lowest_score,
)

return prompt_scores, detailed_info, real_prices

Expand Down Expand Up @@ -325,7 +409,8 @@ def reward(
t3 = time.time()
except Exception:
bt.logging.exception(
f"Error calculating CRPS for miner {miner_uid} with prediction_id {miner_prediction.id}"
f"Error calculating CRPS for miner {miner_uid} "
f"with prediction_id {miner_prediction.id}"
)
return -1, [], miner_prediction

Expand All @@ -337,7 +422,8 @@ def reward(

if np.isnan(score):
bt.logger.warning(
f"CRPS calculation returned NaN for miner {miner_uid} with prediction_id {miner_prediction.id}"
f"CRPS calculation returned NaN for miner {miner_uid} "
f"with prediction_id {miner_prediction.id}"
)
return -1, detailed_crps_data, miner_prediction

Expand Down Expand Up @@ -370,7 +456,8 @@ def get_rewards(
real_prices = price_data_provider.fetch_data(validator_request)
except Exception as e:
bt.logging.warning(
f"Error fetching data for validator request {validator_request.id}: {e}"
f"Error fetching data for validator request "
f"{validator_request.id}: {e}"
)
return None, [], []

Expand Down Expand Up @@ -460,7 +547,8 @@ def get_rewards_threading(
real_prices = price_data_provider.fetch_data(validator_request)
except Exception as e:
bt.logging.warning(
f"Error fetching data for validator request {validator_request.id}: {e}"
f"Error fetching data for validator request "
f"{validator_request.id}: {e}"
)
return None, [], []

Expand Down