diff --git a/requirements.txt b/requirements.txt
index 949524b00..078246218 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,4 +34,4 @@ unbabel-comet==2.2.1
 nltk>=3.7,<4
 evaluate
 spacy>=3.4.0,<4
-fastchat
+fastchat
\ No newline at end of file
diff --git a/scripts/polygraph_eval b/scripts/polygraph_eval
index 23f6db69e..d1df37cb6 100755
--- a/scripts/polygraph_eval
+++ b/scripts/polygraph_eval
@@ -40,6 +40,7 @@ def main(args):
     os.chdir(hydra.utils.get_original_cwd())
 
     save_path = args.save_path if "save_path" in args else save_path
+
     if args.seed is None or len(args.seed) == 0:
         args.seed = [1]
 
@@ -183,6 +184,7 @@ def main(args):
     generation_metrics = get_generation_metrics(args)
     ue_metrics = get_ue_metrics(args)
+    output_prr_curves = getattr(args, "output_prr_curves", False)
 
     man = UEManager(
         dataset,
@@ -201,8 +203,11 @@ def main(args):
         ensemble_model=ensemble_model,
         cache_path=args.cache_path,
         language=getattr(args, 'language', 'en'),
+        output_prr_curves=output_prr_curves,
+        save_path=save_path,
     )
+
     man()
 
     man.save(save_path + f"/ue_manager_seed{seed}")
diff --git a/src/lm_polygraph/ue_metrics/pred_rej_area.py b/src/lm_polygraph/ue_metrics/pred_rej_area.py
index bcbfbc802..9457faf20 100644
--- a/src/lm_polygraph/ue_metrics/pred_rej_area.py
+++ b/src/lm_polygraph/ue_metrics/pred_rej_area.py
@@ -24,7 +24,7 @@ def __str__(self):
             return "prr"
         return f"prr_{self.max_rejection}"
 
-    def __call__(self, estimator: List[float], target: List[float]) -> float:
+    def __call__(self, estimator: List[float], target: List[float], return_scores: bool = False) -> float:
         """
         Measures the area under the Prediction-Rejection curve between `estimator` and `target`.
 
@@ -33,6 +33,7 @@ def __call__(self, estimator: List[float], target: List[float]) -> float:
                 Higher values indicate more uncertainty.
             target (List[int]): a batch of ground-truth uncertainty estimations.
                 Higher values indicate less uncertainty.
+            return_scores (bool): if True, additionally return the per-step rejection accuracy scores.
 
         Returns:
             float: area under the Prediction-Rejection curve.
                 Higher values indicate better uncertainty estimations.
@@ -50,4 +51,7 @@ def __call__(self, estimator: List[float], target: List[float]) -> float:
         cumsum = np.cumsum(sorted_metrics)[-num_rej:]
         scores = (cumsum / np.arange((num_obs - num_rej) + 1, num_obs + 1))[::-1]
         prr_score = np.sum(scores) / num_rej
+
+        if return_scores:
+            return prr_score, scores
         return prr_score
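As a quick sanity check of the extended `__call__`, something along these lines should work (a minimal sketch with made-up scores, assuming the metric class in `pred_rej_area.py` is `PredictionRejectionArea` with a no-argument constructor):

```python
import numpy as np
from lm_polygraph.ue_metrics.pred_rej_area import PredictionRejectionArea

# Made-up batch: higher estimator value = more uncertainty,
# higher target value = better generation quality.
estimator = np.array([0.9, 0.1, 0.5, 0.3])
target = np.array([0.2, 0.95, 0.6, 0.8])

prr = PredictionRejectionArea()
prr_score = prr(estimator, target)                               # scalar PRR, unchanged behaviour
prr_score, scores = prr(estimator, target, return_scores=True)   # PRR plus the rejection curve
print(prr_score, scores.shape)
```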
diff --git a/src/lm_polygraph/ue_metrics/ue_metric.py b/src/lm_polygraph/ue_metrics/ue_metric.py
index 4e9f4b04c..c3a414420 100644
--- a/src/lm_polygraph/ue_metrics/ue_metric.py
+++ b/src/lm_polygraph/ue_metrics/ue_metric.py
@@ -3,6 +3,9 @@
 from typing import List
 from abc import ABC, abstractmethod
+import matplotlib.pyplot as plt
+import os
+import uuid
 
 
 def normalize(target: List[float]):
     min_t, max_t = np.min(target), np.max(target)
@@ -57,19 +60,99 @@ def __call__(self, estimator: List[float], target: List[float]) -> float:
         raise Exception("Not implemented")
 
 
-def get_random_scores(function, metrics, num_iter=1000, seed=42):
+def get_random_scores(function, metrics, num_iter=1000, seed=42, return_scores: bool = False):
     np.random.seed(seed)
+
     rand_scores = np.arange(len(metrics))
-    value = []
-    for i in range(num_iter):
+
+    ue_metric_scores = []  # UE metric value (e.g. PRR) for each random permutation
+    score_values = []  # per-permutation rejection curves, collected only if requested
+
+    for _ in range(num_iter):
         np.random.shuffle(rand_scores)
-        rand_val = function(rand_scores, metrics)
-        value.append(rand_val)
-    return np.mean(value)
+
+        if return_scores:
+            # Compute both the metric value and the rejection curve
+            metric_score, detailed_scores = function(rand_scores, metrics, return_scores=True)
+            score_values.append(detailed_scores)
+        else:
+            metric_score = function(rand_scores, metrics)
+        ue_metric_scores.append(metric_score)
+
+    # Mean metric value across all random permutations
+    ue_metric_score = np.mean(ue_metric_scores)
+
+    if return_scores:
+        # Also return the element-wise mean of the rejection curves
+        return ue_metric_score, np.mean(score_values, axis=0)
+
+    return ue_metric_score
 
 
 def normalize_metric(target_score, oracle_score, random_score):
     if not (oracle_score == random_score):
         target_score = (target_score - random_score) / (oracle_score - random_score)
     return target_score
+
+
+def generate_prr_curve(
+    ue_rejected_accuracy,
+    oracle_rejected_accuracy,
+    random_rejected_accuracy,
+    e_level: str,
+    e_name: str,
+    gen_name: str,
+    path: str,
+):
+    """
+    Generates and saves a PRR curve plot using matplotlib.
+
+    Parameters:
+        ue_rejected_accuracy (np.ndarray): rejection curve of the uncertainty estimator (UE).
+        oracle_rejected_accuracy (np.ndarray): rejection curve of the Oracle (ideal) baseline.
+        random_rejected_accuracy (np.ndarray): rejection curve of the Random baseline.
+        e_level (str): estimation level, used in the plot title.
+        e_name (str): estimator name, used in the plot title and filename.
+        gen_name (str): generation metric name, used as the y-axis label.
+        path (str): directory in which the plot is saved.
+
+    Returns:
+        str: the path where the plot is saved.
+    """
+    # Directory to save plots
+    plots_dir = path
+    os.makedirs(plots_dir, exist_ok=True)
+
+    # Number of points on the rejection curve
+    n_examples = len(ue_rejected_accuracy)
+
+    # Rejection rates (x-axis)
+    rejection_rates = np.linspace(0, 1, n_examples)
+
+    # Create plot
+    plt.figure(figsize=(8, 6))
+
+    # Plot each curve (UE, Oracle, Random)
+    plt.plot(rejection_rates, ue_rejected_accuracy, label='UE', linestyle='-')
+    plt.plot(rejection_rates, oracle_rejected_accuracy, label='Oracle', linestyle='-')
+    plt.plot(rejection_rates, random_rejected_accuracy, label='Random', linestyle='-')
+
+    # Add labels and title
+    plt.xlabel('Rejection Rate')
+    plt.ylabel(gen_name)
+    plt.title(f'PRR curve: {e_level}, {e_name}')
+
+    # Add grid and legend
+    plt.grid(True)
+    plt.legend()
+
+    # Use a random UUID so repeated runs do not overwrite each other
+    base_filename = 'prr_curve'
+    extension = 'png'
+    unique_id = uuid.uuid4()
+    new_filename = f"{base_filename}_{e_name}_{gen_name}_{unique_id}.{extension}"
+    save_path = os.path.join(plots_dir, new_filename)
+
+    # Save the plot and return its location
+    plt.savefig(save_path)
+    plt.close()
+    return save_path
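The new helpers can be exercised in isolation roughly as follows (a sketch with synthetic data; the estimator and metric names passed to `generate_prr_curve` are placeholders, and `/tmp` is just an example output directory):

```python
import numpy as np
from lm_polygraph.ue_metrics.pred_rej_area import PredictionRejectionArea
from lm_polygraph.ue_metrics.ue_metric import get_random_scores, generate_prr_curve

# Synthetic quality scores and (negatively) correlated uncertainty estimates
metric = np.random.rand(100)
ue = -metric + 0.1 * np.random.randn(100)

prr = PredictionRejectionArea()
ue_score, ue_curve = prr(ue, metric, return_scores=True)
oracle_score, oracle_curve = prr(-metric, metric, return_scores=True)
random_score, random_curve = get_random_scores(prr, metric, return_scores=True)

# Writes prr_curve_MyEstimator_Quality_<uuid>.png into /tmp
generate_prr_curve(ue_curve, oracle_curve, random_curve,
                   e_level="sequence", e_name="MyEstimator",
                   gen_name="Quality", path="/tmp")
```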
+ """ + # Directory to save plots + + plots_dir = path + # os.makedirs(plots_dir, exist_ok=True) + + # Number of examples + N_EXAMPLES = len(ue_rejected_accuracy) + + # Rejection rates (x-axis) + rejection_rates = np.linspace(0, 1, N_EXAMPLES) + + # Create plot + plt.figure(figsize=(8, 6)) + + # Plot each line (UE, Oracle, Random) + plt.plot(rejection_rates, ue_rejected_accuracy, label='UE', linestyle='-') + plt.plot(rejection_rates, oracle_rejected_accuracy, label='Oracle', linestyle='-') + plt.plot(rejection_rates, random_rejected_accuracy, label='Random', linestyle='-') + + # Add labels and title + plt.xlabel('Rejection Rate') + plt.ylabel(f'{gen_name}') + plt.title(f'PRR curve: {e_level}, {e_name}') + + # Add grid and legend + plt.grid(True) + plt.legend() + + # Generate a random UUID for the filename + base_filename = 'prr_curve' + extension = 'png' + unique_id = uuid.uuid4() + new_filename = f"{base_filename}_{e_name}_{gen_name}_{unique_id}.{extension}" + save_path = os.path.join(plots_dir, new_filename) + + # Save the plot + plt.savefig(save_path) + plt.close() \ No newline at end of file diff --git a/src/lm_polygraph/utils/manager.py b/src/lm_polygraph/utils/manager.py index 263034002..01a6c112f 100644 --- a/src/lm_polygraph/utils/manager.py +++ b/src/lm_polygraph/utils/manager.py @@ -19,6 +19,7 @@ UEMetric, get_random_scores, normalize_metric, + generate_prr_curve ) from lm_polygraph.estimators.estimator import Estimator from lm_polygraph.stat_calculators.stat_calculator import StatCalculator @@ -256,6 +257,8 @@ def __init__( max_new_tokens: int = 100, background_train_dataset_max_new_tokens: int = 100, cache_path=os.path.expanduser("~") + "/.cache", + output_prr_curves=False, + save_path = '' ): """ Parameters: @@ -275,8 +278,41 @@ def __init__( deberta_device (Optional[str]): The device to run deberta on. If None, will use 'cuda:0' if available, 'cpu' otherwise. Default: None. language (str): Language to test in claim-level benchmark, one of 'en', 'zh', 'ar', 'ru'. Default: 'en'. - verbose (bool): If set, will print useful info during batch processing. Default: True. + for (gen_level, gen_name), generation_metric in self.gen_metrics.items(): + for ue_metric in self.ue_metrics: + if gen_level != e_level: + continue + if len(estimator_values) != len(generation_metric): + raise Exception( + f"Got different number of metrics for {e_name} and {gen_name}: " + f"{len(estimator_values)} and {len(generation_metric)}" + ) + # TODO: Report how many nans! + # This is important to know for a user + ue, metric = _delete_nans(estimator_values, generation_metric) + if len(ue) == 0: + self.metrics[e_level, e_name, gen_name, str(ue_metric)] = np.nan + else: + # For prr, generate plot + if str(ue_metric) == 'prr': + oracle_score, oracle_scores = ue_metric(-metric, metric, True) + random_score, random_scores = get_random_scores(ue_metric, metric, True) + ue_metric_val , ue_scores = ue_metric(ue, metric, True) + generate_prr_curve(ue_scores, oracle_scores, random_scores, e_level, e_name, gen_name) + else: + oracle_score= ue_metric(-metric, metric) + random_score = get_random_scores(ue_metric, metric) + ue_metric_val = ue_metric(ue, metric) + + self.metrics[e_level, e_name, gen_name, str(ue_metric)] = ( + ue_metric_val + ) + self.metrics[ + e_level, e_name, gen_name, str(ue_metric) + "_normalized" + ] = normalize_metric(ue_metric_val, oracle_score, random_score) (bool): If set, will print useful info during batch processing. Default: True. 
@@ -284,12 +320,11 @@ def __init__(
             deberta_device=deberta_device,
             language=language,
             cache_path=cache_path,
-            model=model,
         )
         self.stat_calculators_dict = stat_calculators_dict
 
-
-        self.model: Model = model
+        self.save_path = save_path
+        self.model: WhiteboxModel = model
         self.train_data: Dataset = train_data
         self.background_train_data: Dataset = background_train_data
         self.ensemble_model = ensemble_model
@@ -297,13 +332,15 @@ def __init__(
         self.estimators: List[Estimator] = estimators
         self.generation_metrics: List[GenerationMetric] = generation_metrics
         self.ue_metrics: List[UEMetric] = ue_metrics
+        self.output_prr_curves = output_prr_curves
         _check_unique_names(generation_metrics)
         _check_unique_names(estimators)
         _check_unique_names(ue_metrics)
 
-        greedy = ["greedy_texts"]
-        if not isinstance(self.model, BlackboxModel):
-            greedy += ["greedy_tokens"]
+        if isinstance(model, BlackboxModel):
+            greedy = ["blackbox_greedy_texts"]
+        else:
+            greedy = ["greedy_tokens", "greedy_texts"]
 
         stats = (
             [s for e in self.estimators for s in e.stats_dependencies]
@@ -511,6 +547,7 @@ def __call__(self) -> Dict[Tuple[str, str, str, str], float]:
             torch.cuda.empty_cache()
             gc.collect()
 
+
         for (e_level, e_name), estimator_values in self.estimations.items():
             for (gen_level, gen_name), generation_metric in self.gen_metrics.items():
                 for ue_metric in self.ue_metrics:
@@ -527,9 +564,17 @@ def __call__(self) -> Dict[Tuple[str, str, str, str], float]:
                     if len(ue) == 0:
                         self.metrics[e_level, e_name, gen_name, str(ue_metric)] = np.nan
                     else:
-                        oracle_score = ue_metric(-metric, metric)
-                        random_score = get_random_scores(ue_metric, metric)
-                        ue_metric_val = ue_metric(ue, metric)
+                        # For PRR, also collect the rejection curves and save a plot
+                        if str(ue_metric) == 'prr' and self.output_prr_curves:
+                            oracle_score, oracle_scores = ue_metric(-metric, metric, return_scores=True)
+                            random_score, random_scores = get_random_scores(ue_metric, metric, return_scores=True)
+                            ue_metric_val, ue_scores = ue_metric(ue, metric, return_scores=True)
+                            generate_prr_curve(ue_scores, oracle_scores, random_scores, e_level, e_name, gen_name, self.save_path)
+                        else:
+                            oracle_score = ue_metric(-metric, metric)
+                            random_score = get_random_scores(ue_metric, metric)
+                            ue_metric_val = ue_metric(ue, metric)
+
                         self.metrics[e_level, e_name, gen_name, str(ue_metric)] = (
                             ue_metric_val
                         )
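For reference, the `prr_normalized` value stored alongside the raw PRR relates the three scores computed above as follows (numbers are illustrative only):

```python
from lm_polygraph.ue_metrics.ue_metric import normalize_metric

# Illustrative raw scores; in the manager these come from ue_metric(ue, metric),
# ue_metric(-metric, metric) and get_random_scores(ue_metric, metric).
ue_metric_val, oracle_score, random_score = 0.62, 0.80, 0.50

# 0 means no better than random rejection ordering, 1 means oracle-level ordering.
prr_normalized = normalize_metric(ue_metric_val, oracle_score, random_score)
print(prr_normalized)  # (0.62 - 0.50) / (0.80 - 0.50) = 0.4
```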