"""End-to-end validation of the judge pipeline against a hand-crafted pair.
Builds one ``(response_a, response_b)`` pair where response_a is clearly
better on every dimension (a coherent medical answer vs. a nonsense
reply), runs ``judge_pairs`` against all three rubrics with
Sonnet 4.6, and asserts that the verdict is ``A`` for each. Reasoning
traces and cache stats are printed so a human can eyeball the calls.
Run with::
pixi run -e default python -m olmo_tap.final_evals.elo.scripts.validate_judge
The script must use Sonnet (cheap); Opus is reserved for the headline
run of the tournament. Each invocation costs about $0.01.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from dotenv import load_dotenv
from olmo_tap.final_evals.elo.judge import (
DIMENSIONS,
Dimension,
JudgeConfig,
PairToJudge,
Rubric,
judge_pairs,
)
# Repository root, assuming this file lives five levels below it
# (olmo_tap/final_evals/elo/scripts/<file>) -- TODO confirm if the file moves.
REPO_ROOT: Path = Path(__file__).resolve().parents[4]
# One rubric prompt file per judged dimension; keys must match DIMENSIONS.
RUBRIC_PATHS: dict[Dimension, Path] = {
    "factuality": REPO_ROOT / "olmo_tap/final_evals/elo/configs/rubrics/factuality.txt",
    "calibration": REPO_ROOT
    / "olmo_tap/final_evals/elo/configs/rubrics/calibration.txt",
    "clinical_utility": REPO_ROOT
    / "olmo_tap/final_evals/elo/configs/rubrics/clinical_utility.txt",
}
# Default directory for this validation script's judgment cache JSONLs
# (a dedicated dir so it does not mix with other runs' caches).
DEFAULT_CACHE_DIR: Path = REPO_ROOT / "olmo_tap/final_evals/elo/caches/judgments_validation"
# Coherent, clinically sound answer -- expected to win on every rubric.
GOOD_RESPONSE: str = (
    "The first-line treatment for primary hypothyroidism is levothyroxine "
    "(synthetic T4), taken once daily on an empty stomach. Typical starting "
    "doses are 1.6 mcg/kg/day in healthy adults, with lower starting doses "
    "(25–50 mcg/day) in older adults or patients with cardiovascular disease "
    "to avoid precipitating angina or arrhythmia. Dose adjustments are "
    "guided by TSH measured 6–8 weeks after each change. Levothyroxine is "
    "preferred over T3 because of its long half-life and smoother "
    "physiologic profile. If the patient remains symptomatic despite a "
    "TSH in target range, evaluate for adherence, malabsorption, or "
    "interfering medications before considering combination T4/T3 therapy."
)
# Deliberate nonsense (wrong anatomy, absurd advice) -- expected to lose.
BAD_RESPONSE: str = (
    "Hypothyroidism is when the thyroid is sleepy. Banana smoothies and "
    "sunlight are the primary treatment. If the symptoms persist for more "
    "than a fortnight, ask the patient to whisper their TSH value into a "
    "glass of warm water and discard. Aspirin may also help. The thyroid "
    "is located in the abdomen and produces insulin."
)
def _build_pair() -> PairToJudge:
    """Construct the single hand-crafted validation pair.

    ``response_a`` is a coherent clinical answer and ``response_b`` is
    deliberate nonsense, so a working judge should prefer entrant A on
    every dimension.
    """
    prompt = (
        "What is the first-line pharmacologic treatment for primary "
        "hypothyroidism in an otherwise healthy adult, and how is it "
        "dosed and monitored?"
    )
    gold = (
        "Levothyroxine (synthetic T4), once-daily oral; titrated to a "
        "TSH within the laboratory reference range, with TSH rechecked "
        "6–8 weeks after dose changes."
    )
    # source="medqa" is non-curated so all three rubrics process the pair.
    return PairToJudge(
        prompt_id="validation_hypothyroidism",
        source="medqa",
        prompt_text=prompt,
        entrant_a="coherent_response",
        entrant_b="nonsense_response",
        response_a=GOOD_RESPONSE,
        response_b=BAD_RESPONSE,
        gold_answer=gold,
    )
def _print_separator() -> None:
print("-" * 78)
def _summarise(result, dimension: Dimension) -> bool:
if not result.judgments:
print(f"[{dimension}] No judgments returned (rubric likely filtered).")
return False
judgment = result.judgments[0]
raw_forward, raw_swapped = judgment.raw
print(
f"[{dimension}] winner = {judgment.winner!r}, inconsistent = {judgment.inconsistent}"
)
print(
f"[{dimension}] forward verdict = {raw_forward.verdict}, "
f"swapped verdict = {raw_swapped.verdict}"
)
print(f"[{dimension}] cache_stats = {result.cache_stats}")
print(f"[{dimension}] forward reasoning (truncated):")
print(raw_forward.reasoning[:600])
if len(raw_forward.reasoning) > 600:
print("... (truncated)")
print(f"[{dimension}] swapped reasoning (truncated):")
print(raw_swapped.reasoning[:600])
if len(raw_swapped.reasoning) > 600:
print("... (truncated)")
return judgment.winner == "coherent_response"
def main() -> int:
    """Run the hand-crafted validation pair through all three rubrics.

    Returns a process exit code: 0 when every rubric picked the coherent
    response, 1 when at least one rubric disagreed, 2 when an Opus model
    was requested (refused -- Opus is reserved for the headline run).
    """
    # NOTE: a stray "[docs]" Sphinx artifact preceded this definition in a
    # previous revision; it referenced an undefined name and broke import.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--judge-model",
        default="claude-sonnet-4-6",
        help=(
            "Anthropic model id for the judge. Validation should use Sonnet; "
            "Opus is reserved for the headline tournament run."
        ),
    )
    parser.add_argument(
        "--cache-dir",
        type=Path,
        default=DEFAULT_CACHE_DIR,
        help="Directory to write the validation judgment cache JSONLs.",
    )
    args = parser.parse_args()
    # Cost guardrail: never burn Opus tokens on a validation run.
    if "opus" in args.judge_model.lower():
        print(
            "Refusing to run validation with an Opus model. Use Sonnet for "
            "validation; Opus is reserved for the headline run.",
            file=sys.stderr,
        )
        return 2
    load_dotenv()
    logging.basicConfig(
        level=logging.INFO, format="%(levelname)s %(name)s: %(message)s"
    )
    pair = _build_pair()
    config = JudgeConfig(
        judge_model=args.judge_model,
        cache_dir=args.cache_dir,
    )
    all_passed = True
    # Token/call counters summed across the three rubric runs; keys mirror
    # the attribute names on result.cache_stats.
    total_stats = {
        "cache_creation_input_tokens": 0,
        "cache_read_input_tokens": 0,
        "input_tokens": 0,
        "output_tokens": 0,
        "fresh_calls": 0,
        "cache_hits": 0,
    }
    for dimension in DIMENSIONS:
        _print_separator()
        rubric = Rubric.load(dimension, RUBRIC_PATHS[dimension])
        result = judge_pairs(pairs=[pair], rubric=rubric, config=config)
        passed = _summarise(result, dimension)
        all_passed = all_passed and passed
        # Accumulate by attribute name instead of six copy-pasted lines.
        for key in total_stats:
            total_stats[key] += getattr(result.cache_stats, key)
    _print_separator()
    print("Aggregate cache stats:")
    for key, value in total_stats.items():
        print(f" {key}: {value}")
    if all_passed:
        print("\nAll three rubrics returned the expected winner ('coherent_response').")
        return 0
    print("\nAt least one rubric did not return the expected winner.", file=sys.stderr)
    return 1
if __name__ == "__main__":
    # Propagate main()'s exit code (0 ok, 1 verdict mismatch, 2 refused model).
    sys.exit(main())