olmo_tap.experiments.uncertainty.single_head_eval

NOTE: This file tests the uncertainty head on a single LLM head. For the equivalent evaluation on the PoE Hydra aggregation, see olmo_tap/final_evals/uncertainty_sweep.py.

Reliability-diagram eval for the uncertainty head.

For each robustness shard (0 through 8), run the uncertainty head over the MedMCQA validation fold using the two-pass procedure from engine.py::train, group the predicted Q into equal-width bins, compute the empirical accuracy P in each bin, and plot P against Q together with the y = x diagonal. One PNG is written per shard.
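The binning step can be illustrated with a minimal sketch. It is not the module's own code: the helper name `reliability_bins`, the assumption that Q lies in [0, 1], and the choice to skip empty bins are all illustrative assumptions, and plain Python lists stand in for the tensors the real functions use.

```python
def reliability_bins(q, is_correct, n_bins=10):
    """Hypothetical helper: summarise predictions into equal-width bins over [0, 1].

    Returns one (mean_q, accuracy, count) tuple per non-empty bin, ordered by bin.
    Assumes every value in q lies in [0, 1].
    """
    sum_q = [0.0] * n_bins
    sum_correct = [0] * n_bins
    count = [0] * n_bins
    for qi, ci in zip(q, is_correct):
        # Clamp the index so that q == 1.0 lands in the last bin.
        b = min(int(qi * n_bins), n_bins - 1)
        sum_q[b] += qi
        sum_correct[b] += int(ci)
        count[b] += 1
    return [
        (sum_q[b] / count[b], sum_correct[b] / count[b], count[b])
        for b in range(n_bins)
        if count[b] > 0
    ]


# Toy data, not real evaluation output.
bins = reliability_bins([0.05, 0.15, 0.95, 1.0], [0, 1, 1, 1], n_bins=10)
for mean_q, accuracy, n in bins:
    print(f"Q={mean_q:.3f}  P={accuracy:.3f}  n={n}")
```

Plotting the accuracy against the mean Q of each bin gives the reliability curve; points on the y = x diagonal correspond to a perfectly calibrated head.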

Intended Usage::
    pixi run -e cuda python -m olmo_tap.experiments.uncertainty.single_head_eval \
        --checkpoint olmo_tap/weights/uncertainty/checkpoint_final.pt

Functions

check_checkpoint(path)

check_shard_weights()

collect_predictions_for_shard(model, ...)

Mirror engine.py::train lines 61-126 under no_grad.

get_letter_token_ids(tokenizer)

load_validation_set(exp_config, max_examples)

MedMCQA validation fold with the same two-pass tokenization as training.

main()

parse_args()

plot_reliability(Q_all, is_correct_all, ...)

olmo_tap.experiments.uncertainty.single_head_eval.check_checkpoint(path: str) -> None
olmo_tap.experiments.uncertainty.single_head_eval.check_shard_weights() -> None
olmo_tap.experiments.uncertainty.single_head_eval.collect_predictions_for_shard(model: HydraTransformer, dataloader: DataLoader, target_token_ids: Tensor, t_config: TrainingConfig, device: str) -> tuple[Tensor, Tensor, float]

Mirror engine.py::train lines 61-126 under no_grad. Canonical ref: engine.py.

olmo_tap.experiments.uncertainty.single_head_eval.get_letter_token_ids(tokenizer) -> list[int]
olmo_tap.experiments.uncertainty.single_head_eval.load_validation_set(exp_config: ExperimentConfig, max_examples: int | None) -> tuple[DataLoader, list[int]]

MedMCQA validation fold with the same two-pass tokenization as training.

olmo_tap.experiments.uncertainty.single_head_eval.main()
olmo_tap.experiments.uncertainty.single_head_eval.parse_args() -> Namespace
olmo_tap.experiments.uncertainty.single_head_eval.plot_reliability(Q_all: Tensor, is_correct_all: Tensor, valid_rate: float, shard_id: int, n_bins: int, out_path: Path) -> None