Source code for olmo_tap.final_evals.uncertainty_sweep

"""
Evaluate the calibration of the uncertainty head using PoE.

The full Hydra model is loaded with all 10 heads (9 LLM + 1 Uncertainty). 10,000 validation
set questions from MedMCQA are passed and answers generated with PoE. We take only the first
generated token (answer A, B, C or D). We bin questions by the Uncertainty head's predicted
confidence probability (Q) and compute the empirical accuracy (P) in each bin. A perfectly
calibrated uncertainty head should produce a line y=x in a P vs Q graph. This corresponds to
an ECE (Expected Calibration Error) of zero.
"""

from datasets import load_dataset
from transformers import AutoTokenizer

from olmo_tap.constants import WEIGHTS_DIR, MCQ_LETTERS
from olmo_tap.experiments.robustness.data import format_example
from olmo_tap.inference.loading_weights import load_ensemble
from olmo_tap.inference.poe import PoE


[docs] def main(): tokenizer = AutoTokenizer.from_pretrained(WEIGHTS_DIR) assert tokenizer is not None model, n_heads = load_ensemble() poe = PoE(model, tokenizer, n_llm_heads=n_heads - 1, max_new_tokens=1) ds = load_dataset("openlifescienceai/medmcqa", split="validation") subset_size = 10000 # bins for predicted confidence bin_boundaries = [0.1 * i for i in range(11)] bins = { i: {"correct": 0, "total": 0, "conf_sum": 0.0} for i in range(len(bin_boundaries) - 1) } print(f"Gathering uncertainty scores across {subset_size} samples...") for idx in range(min(subset_size, len(ds))): row = ds[idx] opts = [str(row["opa"]), str(row["opb"]), str(row["opc"]), str(row["opd"])] prompt_text = format_example(str(row["question"]), opts) label = MCQ_LETTERS[int(row["cop"])] # PoE gives us the uncertainty score (p_correct) on is_mcq=True poe_out = poe.generate_with_cache(prompt_text, is_mcq=True) conf_score = poe_out.uncertainty assert conf_score is not None # optional return, pyrefly... generated_answer = poe_out.output_parts[1] is_correct = 1 if generated_answer == label else 0 # place in appropriate bin for i in range(len(bin_boundaries) - 1): if bin_boundaries[i] <= conf_score < bin_boundaries[i + 1]: bins[i]["correct"] += is_correct bins[i]["total"] += 1 bins[i]["conf_sum"] += conf_score break print("\n--- Calibration Results ---") print( f"{'Bin Range':<15} | {'Mean Predicted Conf':<20} | {'Empirical Acc':<15} | {'Samples':<8}" ) print("-" * 65) for i in range(len(bin_boundaries) - 1): total = bins[i]["total"] if total > 0: mean_conf = bins[i]["conf_sum"] / total emp_acc = bins[i]["correct"] / total bin_str = f"[{bin_boundaries[i]:.1f}, {bin_boundaries[i + 1]:.1f})" print(f"{bin_str:<15} | {mean_conf:<20.4f} | {emp_acc:<15.4f} | {total:<8}")
if __name__ == "__main__": main()