diff --git a/synth/validator/reward.py b/synth/validator/reward.py index 5fd0f1e..f980887 100644 --- a/synth/validator/reward.py +++ b/synth/validator/reward.py @@ -13,6 +13,7 @@ import typing from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from multiprocessing import shared_memory import time # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO @@ -37,11 +38,14 @@ # Module level - must be picklable def _crps_worker(args): - """Standalone worker - no database, no complex objects""" + """Standalone worker - no database, no complex objects. + Uses shared memory for real_prices to reduce memory duplication across processes. + """ ( miner_uid, prediction_array, - real_prices, + shm_name, + prices_shape, time_increment, scoring_intervals, format_validation, @@ -72,7 +76,7 @@ def _crps_worker(args): process_time, ) - if len(real_prices) == 0: + if prices_shape[0] == 0: return ( miner_uid, -1, @@ -83,48 +87,57 @@ def _crps_worker(args): process_time, ) - prediction_array = adjust_predictions(list(prediction_array)) - + # Attach to shared memory for real_prices + existing_shm = shared_memory.SharedMemory(name=shm_name) try: - simulation_runs = np.array(prediction_array).astype(float) - score, detailed_crps_data = calculate_crps_for_miner( - simulation_runs, - np.array(real_prices), - int(time_increment), - scoring_intervals, + real_prices = np.ndarray( + prices_shape, dtype=np.float64, buffer=existing_shm.buf ) - if np.isnan(score): + prediction_array = adjust_predictions(list(prediction_array)) + + try: + simulation_runs = np.array(prediction_array).astype(float) + score, detailed_crps_data = calculate_crps_for_miner( + simulation_runs, + real_prices, # Already a numpy array from shared memory + int(time_increment), + scoring_intervals, + ) + + if np.isnan(score): + return ( + miner_uid, + -1, + detailed_crps_data, + f"Error calculating CRPS for miner {miner_uid}", + format_validation, + 
def _get_scoring_intervals(validator_request: "ValidatorRequest") -> dict:
    """Return the scoring intervals matching the request's time length.

    A request whose ``time_length`` equals
    ``prompt_config.HIGH_FREQUENCY.time_length`` gets the high-frequency
    scoring intervals; any other request gets the low-frequency ones.
    """
    high_frequency = prompt_config.HIGH_FREQUENCY
    if validator_request.time_length == high_frequency.time_length:
        return high_frequency.scoring_intervals
    return prompt_config.LOW_FREQUENCY.scoring_intervals


def _picklable_format_validation(format_validation):
    """Reduce a format-validation marker to a plain, picklable value.

    Objects exposing ``.value`` (e.g. enum members) are replaced by that
    value, the ``response_validation_v2.CORRECT`` marker becomes the literal
    ``"CORRECT"``, and anything else is stringified.
    """
    if hasattr(format_validation, "value"):
        return format_validation.value
    if format_validation == response_validation_v2.CORRECT:
        return "CORRECT"
    return str(format_validation)


def _prepare_work_items(
    predictions: "list[MinerPrediction]",
    shm_name: str,
    prices_shape: tuple,
    validator_request: "ValidatorRequest",
    scoring_intervals: dict,
) -> list[tuple]:
    """Build one picklable argument tuple per prediction for the CRPS worker.

    Each tuple is laid out in the order the worker unpacks it:
    ``(miner_uid, prediction, shm_name, prices_shape, time_increment,
    scoring_intervals, format_validation, prediction_id, process_time)``.

    Args:
        predictions: Miner predictions to score.
        shm_name: Name of the shared-memory segment holding the real prices.
        prices_shape: Shape of the price array stored in that segment.
        validator_request: Request supplying ``time_increment``.
        scoring_intervals: Intervals forwarded unchanged to every worker.

    Returns:
        A list of tuples, one per prediction, in the order of ``predictions``.
    """
    # Nothing to score: do not touch the request at all, exactly as a loop
    # over an empty list would not.
    if not predictions:
        return []

    # Same for every work item, so convert it once instead of per prediction.
    time_increment = int(validator_request.time_increment)

    return [
        (
            pred.miner_uid,
            list(pred.prediction),
            shm_name,
            prices_shape,
            time_increment,
            scoring_intervals,
            _picklable_format_validation(pred.format_validation),
            int(pred.id),
            (
                float(pred.process_time)
                if pred.process_time is not None
                else 0.0
            ),
        )
        for pred in predictions
    ]


def _build_detailed_info(
    predictions: "list[MinerPrediction]",
    scores: list,
    detailed_crps_data_list: list,
    prompt_scores: np.ndarray,
    miner_prediction_format_list: list,
    miner_prediction_id_list: list,
    miner_prediction_process_time: list,
    percentile90: float,
    lowest_score: float,
) -> list[dict]:
    """Assemble the per-miner detail records from the scoring results.

    The sequence arguments are walked in lockstep with ``zip``, so the
    result has as many entries as the shortest of them. ``percentile90``
    and ``lowest_score`` are repeated in every record.

    Returns:
        One dict per miner with the keys ``miner_uid``, ``prompt_score_v3``,
        ``percentile90``, ``lowest_score``, ``miner_prediction_id``,
        ``format_validation``, ``process_time``, ``total_crps`` and
        ``crps_data``.
    """
    rows = zip(
        predictions,
        scores,
        detailed_crps_data_list,
        prompt_scores,
        miner_prediction_format_list,
        miner_prediction_id_list,
        miner_prediction_process_time,
    )
    return [
        {
            "miner_uid": pred.miner_uid,
            "prompt_score_v3": float(prompt_score),
            "percentile90": float(percentile90),
            "lowest_score": float(lowest_score),
            "miner_prediction_id": prediction_id,
            # Named format_validation (not ``format``) to avoid shadowing
            # the builtin.
            "format_validation": format_validation,
            "process_time": process_time,
            "total_crps": float(score),
            "crps_data": clean_numpy_in_crps_data(crps_data),
        }
        for (
            pred,
            score,
            crps_data,
            prompt_score,
            format_validation,
            prediction_id,
            process_time,
        ) in rows
    ]
@@ -173,50 +284,35 @@ def get_rewards_multiprocess( int(validator_request.id) ) - # Prepare picklable work items - scoring_intervals = ( - prompt_config.HIGH_FREQUENCY.scoring_intervals - if validator_request.time_length - == prompt_config.HIGH_FREQUENCY.time_length - else prompt_config.LOW_FREQUENCY.scoring_intervals + # Create shared memory for real_prices to avoid duplicating across workers + prices_array = np.array(real_prices, dtype=np.float64) + shm = shared_memory.SharedMemory(create=True, size=prices_array.nbytes) + shared_prices = np.ndarray( + prices_array.shape, dtype=np.float64, buffer=shm.buf + ) + shared_prices[:] = prices_array[:] + + # Prepare work items + scoring_intervals = _get_scoring_intervals(validator_request) + work_items = _prepare_work_items( + predictions, + shm.name, + prices_array.shape, + validator_request, + scoring_intervals, ) - - work_items = [] - - for pred in predictions: - # Convert to picklable types - format_val = pred.format_validation - # Convert enum to string if needed - if hasattr(format_val, "value"): - format_val = format_val.value - elif format_val == response_validation_v2.CORRECT: - format_val = "CORRECT" - else: - format_val = str(format_val) - - work_items.append( - ( - pred.miner_uid, - list(pred.prediction), - real_prices, - int(validator_request.time_increment), - scoring_intervals, - format_val, - int(pred.id), - ( - float(pred.process_time) - if pred.process_time is not None - else 0.0 - ), - ) - ) # Process in parallel (CPU bound - use ProcessPool) bt.logging.info(f"Starting CRPS calculation for {len(work_items)} miners") t0 = time.time() - executor = get_process_executor(nprocs) - results = list(executor.map(_crps_worker, work_items)) + try: + executor = get_process_executor(nprocs) + results = list(executor.map(_crps_worker, work_items)) + finally: + # Clean up shared memory + shm.close() + shm.unlink() bt.logging.info(f"CRPS done in {time.time() - t0:.2f}s") @@ -227,7 +323,6 @@ def 
get_rewards_multiprocess( miner_prediction_id_list = [] miner_prediction_process_time = [] - # Create lookup for original prediction objects for ( miner_uid, score, @@ -254,28 +349,17 @@ def get_rewards_multiprocess( if prompt_scores is None: return None, [], [] - detailed_info = [ - { - "miner_uid": pred.miner_uid, - "prompt_score_v3": float(prompt_score), - "percentile90": float(percentile90), - "lowest_score": float(lowest_score), - "miner_prediction_id": prediction_id, - "format_validation": format, - "process_time": process_time, - "total_crps": float(score), - "crps_data": clean_numpy_in_crps_data(crps_data), - } - for pred, score, crps_data, prompt_score, format, prediction_id, process_time in zip( - predictions, - scores, - detailed_crps_data_list, - prompt_scores, - miner_prediction_format_list, - miner_prediction_id_list, - miner_prediction_process_time, - ) - ] + detailed_info = _build_detailed_info( + predictions, + scores, + detailed_crps_data_list, + prompt_scores, + miner_prediction_format_list, + miner_prediction_id_list, + miner_prediction_process_time, + percentile90, + lowest_score, + ) return prompt_scores, detailed_info, real_prices @@ -325,7 +409,8 @@ def reward( t3 = time.time() except Exception: bt.logging.exception( - f"Error calculating CRPS for miner {miner_uid} with prediction_id {miner_prediction.id}" + f"Error calculating CRPS for miner {miner_uid} " + f"with prediction_id {miner_prediction.id}" ) return -1, [], miner_prediction @@ -337,7 +422,8 @@ def reward( if np.isnan(score): bt.logger.warning( - f"CRPS calculation returned NaN for miner {miner_uid} with prediction_id {miner_prediction.id}" + f"CRPS calculation returned NaN for miner {miner_uid} " + f"with prediction_id {miner_prediction.id}" ) return -1, detailed_crps_data, miner_prediction @@ -370,7 +456,8 @@ def get_rewards( real_prices = price_data_provider.fetch_data(validator_request) except Exception as e: bt.logging.warning( - f"Error fetching data for validator request 
{validator_request.id}: {e}" + f"Error fetching data for validator request " + f"{validator_request.id}: {e}" ) return None, [], [] @@ -460,7 +547,8 @@ def get_rewards_threading( real_prices = price_data_provider.fetch_data(validator_request) except Exception as e: bt.logging.warning( - f"Error fetching data for validator request {validator_request.id}: {e}" + f"Error fetching data for validator request " + f"{validator_request.id}: {e}" ) return None, [], []