2 changes: 1 addition & 1 deletion requirements.txt
@@ -34,4 +34,4 @@ unbabel-comet==2.2.1
nltk>=3.7,<4
evaluate
spacy>=3.4.0,<4
fastchat
fastchat
5 changes: 5 additions & 0 deletions scripts/polygraph_eval
@@ -40,6 +40,7 @@ def main(args):
os.chdir(hydra.utils.get_original_cwd())

save_path = args.save_path if "save_path" in args else save_path


if args.seed is None or len(args.seed) == 0:
args.seed = [1]
@@ -183,6 +184,7 @@ def main(args):
generation_metrics = get_generation_metrics(args)

ue_metrics = get_ue_metrics(args)
output_prr_curves = getattr(args, "output_prr_curves", False)

man = UEManager(
dataset,
@@ -201,8 +203,11 @@ def main(args):
ensemble_model=ensemble_model,
cache_path=args.cache_path,
language=getattr(args, 'language', 'en'),
output_prr_curves=output_prr_curves,
save_path=save_path
)


man()

man.save(save_path + f"/ue_manager_seed{seed}")
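Reviewer note: enabling the curves is a config-only change, since both new values are read straight off the Hydra config object. A minimal sketch of how the lookups behave, using OmegaConf purely for illustration; everything beyond the two keys shown in this diff (the config layout, the output directory) is an assumption.

    # Sketch: how polygraph_eval picks up the new keys (config contents are assumed).
    from omegaconf import OmegaConf

    args = OmegaConf.create({
        "save_path": "./prr_outputs",   # directory for ue_manager_seed* files and PRR plots
        "output_prr_curves": True,      # new optional flag, read via getattr(..., False) above
        "seed": [1],
    })

    output_prr_curves = getattr(args, "output_prr_curves", False)  # mirrors the added line
    save_path = args.save_path if "save_path" in args else None    # simplified fallback for the sketch
    print(output_prr_curves, save_path)                            # True ./prr_outputs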
8 changes: 7 additions & 1 deletion src/lm_polygraph/ue_metrics/pred_rej_area.py
@@ -24,7 +24,7 @@ def __str__(self):
return "prr"
return f"prr_{self.max_rejection}"

def __call__(self, estimator: List[float], target: List[float]) -> float:
def __call__(self, estimator: List[float], target: List[float], return_scores: bool = False) -> float:
"""
Measures the area under the Prediction-Rejection curve between `estimator` and `target`.

@@ -33,6 +33,7 @@ def __call__(self, estimator: List[float], target: List[float]) -> float:
Higher values indicate more uncertainty.
target (List[int]): a batch of ground-truth uncertainty estimations.
Higher values indicate less uncertainty.
return_scores (bool): if True, also return the per-rejection-rate scores that form the PRR curve. Default: False.
Returns:
float: area under the Prediction-Rejection curve.
Higher values indicate better uncertainty estimations.
If return_scores is True, a tuple (PRR score, per-rejection-rate scores) is returned instead.
@@ -50,4 +51,9 @@ def __call__(self, estimator: List[float], target: List[float]) -> float:
cumsum = np.cumsum(sorted_metrics)[-num_rej:]
scores = (cumsum / np.arange((num_obs - num_rej) + 1, num_obs + 1))[::-1]
prr_score = np.sum(scores) / num_rej


if return_scores:
return prr_score, scores

return prr_score
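Reviewer note: the extension is backwards compatible, since the scalar return is unchanged unless return_scores=True is passed. A minimal sketch under the assumption that the class defined in this file is PredictionRejectionArea and is importable from this module path; the synthetic arrays exist only to make the snippet self-contained.

    # Sketch: scalar PRR vs. the new (score, per-rejection-rate curve) return.
    import numpy as np
    from lm_polygraph.ue_metrics.pred_rej_area import PredictionRejectionArea

    rng = np.random.default_rng(0)
    quality = rng.random(200)                        # generation quality, higher is better
    ue = -quality + 0.2 * rng.standard_normal(200)   # uncertainty estimates, higher means more uncertain

    prr = PredictionRejectionArea()
    score = prr(ue, quality)                         # unchanged scalar behaviour
    score, curve = prr(ue, quality, return_scores=True)
    print(round(float(score), 3), len(curve))        # curve holds the per-rejection-rate values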
98 changes: 92 additions & 6 deletions src/lm_polygraph/ue_metrics/ue_metric.py
@@ -3,6 +3,9 @@
from typing import List
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import os
import uuid

def normalize(target: List[float]):
min_t, max_t = np.min(target), np.max(target)
@@ -57,19 +60,102 @@ def __call__(self, estimator: List[float], target: List[float]) -> float:
raise Exception("Not implemented")


def get_random_scores(function, metrics, num_iter=1000, seed=42):

def get_random_scores(function, metrics, return_scores: bool = False, num_iter=1000, seed=42):
np.random.seed(seed)

rand_scores = np.arange(len(metrics))

value = []
for i in range(num_iter):
ue_metric_scores = []  # metric value from each shuffled iteration
score_values = []  # per-rejection-rate curves, collected only when return_scores=True

for _ in range(num_iter):
np.random.shuffle(rand_scores)
rand_val = function(rand_scores, metrics)
value.append(rand_val)
return np.mean(value)

# Evaluate the metric on the shuffled scores, optionally collecting the per-rejection-rate curve
if return_scores:
prr_score, detailed_scores = function(rand_scores, metrics, return_scores=True)
ue_metric_scores.append(prr_score)
score_values.append(detailed_scores)
else:
prr_score = function(rand_scores, metrics)
ue_metric_scores.append(prr_score)

# Compute the mean PRR score across all iterations
ue_metric_score = np.mean(ue_metric_scores)

# If return_scores is True, also compute the mean of the detailed scores (or rejection accuracies)
if return_scores:
mean_scores = np.mean(score_values, axis=0)
return ue_metric_score, mean_scores

# Otherwise, just return the PRR score (no detailed scores)
return ue_metric_score



def normalize_metric(target_score, oracle_score, random_score):
if not (oracle_score == random_score):
target_score = (target_score - random_score) / (oracle_score - random_score)
return target_score


def generate_prr_curve(ue_rejected_accuracy, oracle_rejected_accuracy, random_rejected_accuracy, e_level: str, e_name: str, gen_name: str, path: str):
"""
Generates and saves a PRR curve plot using only matplotlib.

Parameters:
ue_rejected_accuracy (np.array): Rejection curve for uncertainty estimation (UE).
oracle_rejected_accuracy (np.array): Rejection curve for Oracle (ideal).
random_rejected_accuracy (np.array): Rejection curve for Random baseline.
e_level (str): Estimator level.
e_name (str): Estimator (uncertainty method) name.
gen_name (str): Generation-quality metric name, used as the y-axis label.
path (str): Directory where the plot file is written.

Returns:
str: The path where the plot is saved.
"""
# Directory to save plots (created if it does not exist yet)
plots_dir = path
os.makedirs(plots_dir, exist_ok=True)

# Number of examples
N_EXAMPLES = len(ue_rejected_accuracy)

# Rejection rates (x-axis)
rejection_rates = np.linspace(0, 1, N_EXAMPLES)

# Create plot
plt.figure(figsize=(8, 6))

# Plot each line (UE, Oracle, Random)
plt.plot(rejection_rates, ue_rejected_accuracy, label='UE', linestyle='-')
plt.plot(rejection_rates, oracle_rejected_accuracy, label='Oracle', linestyle='-')
plt.plot(rejection_rates, random_rejected_accuracy, label='Random', linestyle='-')

# Add labels and title
plt.xlabel('Rejection Rate')
plt.ylabel(f'{gen_name}')
plt.title(f'PRR curve: {e_level}, {e_name}')

# Add grid and legend
plt.grid(True)
plt.legend()

# Generate a random UUID for the filename
base_filename = 'prr_curve'
extension = 'png'
unique_id = uuid.uuid4()
new_filename = f"{base_filename}_{e_name}_{gen_name}_{unique_id}.{extension}"
save_path = os.path.join(plots_dir, new_filename)

# Save the plot and return its location
plt.savefig(save_path)
plt.close()

return save_path
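Reviewer note: generate_prr_curve only needs the three rejection curves plus label strings, so it can be exercised outside UEManager. A sketch under the assumption that the helpers are importable from the modules touched in this PR; the estimator/metric label strings are illustrative, not real configuration values.

    # Sketch: build UE / Oracle / Random curves and write a single PRR plot.
    import os
    import numpy as np
    from lm_polygraph.ue_metrics.pred_rej_area import PredictionRejectionArea
    from lm_polygraph.ue_metrics.ue_metric import get_random_scores, generate_prr_curve

    rng = np.random.default_rng(0)
    metric = rng.random(300)                         # per-sample generation quality
    ue = -metric + 0.3 * rng.standard_normal(300)    # imperfect uncertainty estimates

    prr = PredictionRejectionArea()
    _, ue_curve = prr(ue, metric, return_scores=True)
    _, oracle_curve = prr(-metric, metric, return_scores=True)            # reject by true quality
    _, random_curve = get_random_scores(prr, metric, return_scores=True)  # averaged random baseline

    out_dir = "./prr_plots"
    os.makedirs(out_dir, exist_ok=True)              # make sure the output directory exists
    generate_prr_curve(ue_curve, oracle_curve, random_curve,
                       e_level="sequence", e_name="SomeEstimator",        # illustrative labels
                       gen_name="SomeGenMetric", path=out_dir)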
67 changes: 56 additions & 11 deletions src/lm_polygraph/utils/manager.py
@@ -19,6 +19,7 @@
UEMetric,
get_random_scores,
normalize_metric,
generate_prr_curve
)
from lm_polygraph.estimators.estimator import Estimator
from lm_polygraph.stat_calculators.stat_calculator import StatCalculator
@@ -256,6 +257,8 @@ def __init__(
max_new_tokens: int = 100,
background_train_dataset_max_new_tokens: int = 100,
cache_path=os.path.expanduser("~") + "/.cache",
output_prr_curves=False,
save_path=''
):
"""
Parameters:
@@ -275,35 +278,69 @@ def __init__(
deberta_device (Optional[str]): The device to run deberta on. If None, will use 'cuda:0' if available,
'cpu' otherwise. Default: None.
language (str): Language to test in claim-level benchmark, one of 'en', 'zh', 'ar', 'ru'. Default: 'en'.
verbose (bool): If set, will print useful info during batch processing. Default: True.
max_new_tokens (int): Maximum new tokens to use in generation. Default: 100.
output_prr_curves (bool): If set, PRR curve plots are generated for each estimator / generation-metric pair and saved to save_path. Default: False.
save_path (str): Output directory from the config, used to store PRR curve plots. Default: ''.
"""

stat_calculators_dict, stat_dependencies_dict = register_stat_calculators(
deberta_batch_size=deberta_batch_size,
deberta_device=deberta_device,
language=language,
cache_path=cache_path,
model=model,
)

self.stat_calculators_dict = stat_calculators_dict

self.model: Model = model
self.save_path = save_path
self.model: WhiteboxModel = model
self.train_data: Dataset = train_data
self.background_train_data: Dataset = background_train_data
self.ensemble_model = ensemble_model
self.data: Dataset = data
self.estimators: List[Estimator] = estimators
self.generation_metrics: List[GenerationMetric] = generation_metrics
self.ue_metrics: List[UEMetric] = ue_metrics
self.output_prr_curves = output_prr_curves
_check_unique_names(generation_metrics)
_check_unique_names(estimators)
_check_unique_names(ue_metrics)

greedy = ["greedy_texts"]
if not isinstance(self.model, BlackboxModel):
greedy += ["greedy_tokens"]
if isinstance(model, BlackboxModel):
greedy = ["blackbox_greedy_texts"]
else:
greedy = ["greedy_tokens", "greedy_texts"]

stats = (
[s for e in self.estimators for s in e.stats_dependencies]
@@ -383,7 +420,6 @@ def __init__(
ensemble_stats = [
s
for e in self.ensemble_estimators
for s in e.stats_dependencies
if s.startswith("ensemble")
]
ensemble_stats, _ = _order_calculators(
@@ -511,6 +547,7 @@ def __call__(self) -> Dict[Tuple[str, str, str, str], float]:
torch.cuda.empty_cache()
gc.collect()


for (e_level, e_name), estimator_values in self.estimations.items():
for (gen_level, gen_name), generation_metric in self.gen_metrics.items():
for ue_metric in self.ue_metrics:
@@ -527,9 +564,17 @@ def __call__(self) -> Dict[Tuple[str, str, str, str], float]:
if len(ue) == 0:
self.metrics[e_level, e_name, gen_name, str(ue_metric)] = np.nan
else:
oracle_score = ue_metric(-metric, metric)
random_score = get_random_scores(ue_metric, metric)
ue_metric_val = ue_metric(ue, metric)
# For PRR (when output_prr_curves is set), also collect the rejection curves and plot them
if str(ue_metric) == 'prr' and self.output_prr_curves:
oracle_score, oracle_scores = ue_metric(-metric, metric, return_scores=True)
random_score, random_scores = get_random_scores(ue_metric, metric, return_scores=True)
ue_metric_val, ue_scores = ue_metric(ue, metric, return_scores=True)
generate_prr_curve(ue_scores, oracle_scores, random_scores, e_level, e_name, gen_name, self.save_path)
else:
oracle_score = ue_metric(-metric, metric)
random_score = get_random_scores(ue_metric, metric)
ue_metric_val = ue_metric(ue, metric)

self.metrics[e_level, e_name, gen_name, str(ue_metric)] = (
ue_metric_val
)
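Reviewer note: running UEManager end to end needs a model and dataset, so here is a standalone sketch of only the bookkeeping this hunk adds: the PRR value, its oracle and random baselines, and the normalized entry stored in self.metrics. The estimator and metric names in the dict key are illustrative, and the snippet reproduces the arithmetic only, not the benchmark itself.

    # Sketch: per-(estimator, generation-metric) bookkeeping for 'prr'.
    import numpy as np
    from lm_polygraph.ue_metrics.pred_rej_area import PredictionRejectionArea
    from lm_polygraph.ue_metrics.ue_metric import get_random_scores, normalize_metric

    rng = np.random.default_rng(1)
    metric = rng.random(250)                         # generation-quality values
    ue = -metric + 0.3 * rng.standard_normal(250)    # uncertainty estimates

    ue_metric = PredictionRejectionArea()
    oracle_score = ue_metric(-metric, metric)        # best possible rejection order
    random_score = get_random_scores(ue_metric, metric)
    ue_metric_val = ue_metric(ue, metric)

    metrics = {}
    key = ("sequence", "SomeEstimator", "SomeGenMetric")        # illustrative key parts
    metrics[(*key, str(ue_metric))] = ue_metric_val             # stored under 'prr'
    metrics[(*key, str(ue_metric) + "_normalized")] = normalize_metric(
        ue_metric_val, oracle_score, random_score
    )
    print(metrics)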