Source code for olmo_tap.experiments.robustness.eval

"""Evaluate robustness: replay the attack bank against a model and compare to
the security baseline recorded at bank-construction time.

Usage::

    # raw OLMo-7B, no LoRA (sanity only -- base is an always-A classifier)
    pixi run -e cuda python -m olmo_tap.experiments.robustness.eval --base

    # Prod security LoRA only -- with --shard-id 0 this round-trips the bank's
    # stored security_flip_rate; with --shard-id N != 0 it probes cross-shard
    # transfer.
    pixi run -e cuda python -m olmo_tap.experiments.robustness.eval \\
        --security --shard-id 1

    # Full stack: prod security + robustness checkpoint
    pixi run -e cuda python -m olmo_tap.experiments.robustness.eval \\
        --checkpoint path/to/checkpoint_final.pt --shard-id 0
"""

import argparse
import json
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from olmo_tap.constants import (
    ATTACK_BANK_DIR,
    LORA_ALPHA_RATIO,
    LORA_TARGETS,
    MCQ_LETTERS,
    PROD_WEIGHTS_DIR,
    WEIGHTS_DIR,
)
from olmo_tap.experiments.robustness.data import format_example
from olmo_tap.experiments.utils.config import HydraLoRAConfig
from olmo_tap.experiments.utils.model_builder import (
    build_base_model,
    inject_lora,
    load_and_merge_lora_weights,
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the robustness evaluation script.

    Exactly one model source must be given: ``--base``, ``--security`` or
    ``--checkpoint``. The remaining options select the security shard,
    batching / sequence limits, the attack-bank directory and an optional
    path for dumping per-pair decisions.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    # The three model sources are mutually exclusive and one is mandatory.
    model_source = cli.add_mutually_exclusive_group(required=True)
    model_source.add_argument(
        "--base", action="store_true", help="raw OLMo-7B, no LoRA"
    )
    model_source.add_argument(
        "--security",
        action="store_true",
        help="OLMo-7B + prod security LoRA for --shard-id (no robustness)",
    )
    model_source.add_argument(
        "--checkpoint",
        type=str,
        help="path to robustness checkpoint (stacked on prod security)",
    )

    # Plain integer options, registered in the order they appear in --help.
    int_options = (
        ("--shard-id", 0),
        ("--batch-size", 32),
        ("--lora-r", 16),
        ("--max-attacks", None),
        ("--max-seq-len", 512),
    )
    for flag, default in int_options:
        cli.add_argument(flag, type=int, default=default)

    cli.add_argument(
        "--bank-dir",
        type=str,
        default=str(ATTACK_BANK_DIR),
        help="directory containing bank.json + metadata.json",
    )
    cli.add_argument(
        "--dump-decisions",
        type=str,
        default=None,
        help="if set, write per-pair (clean_pred, poison_pred, flipped) to this JSON path",
    )
    return cli.parse_args()
def _read_prod_manifest_config() -> dict:
    """Return the ``config`` section of the prod weights ``manifest.json``."""
    with open(PROD_WEIGHTS_DIR / "manifest.json", encoding="utf-8") as f:
        return json.load(f)["config"]


def _manifest_lora_config(manifest_cfg: dict, lora_r: int):
    """Build a one-training-head Hydra LoRA config of rank ``lora_r``.

    The number of final heads and the head depth are taken from the prod
    manifest so that the config matches the layout of the prod weights.
    """
    return HydraLoRAConfig(
        n_heads_final=manifest_cfg["num_shards"],
        n_heads_training=1,
        heads_depth=manifest_cfg["heads_depth"],
        target_modules=LORA_TARGETS,
        lora_r=lora_r,
        lora_alpha=lora_r * LORA_ALPHA_RATIO,
    )


def _load_base_model():
    """Build the base model with a single head and no LoRA weights loaded.

    Returns:
        The model returned by ``build_base_model``, switched to eval mode.
    """
    cfg = HydraLoRAConfig(
        n_heads_final=1,
        n_heads_training=1,
        heads_depth=3,
        target_modules=LORA_TARGETS,
    )
    model = build_base_model(cfg)
    model.eval()
    return model


def _load_security_model(shard_id: int):
    """OLMo-7B + prod security LoRA for `shard_id`; no robustness head.

    Args:
        shard_id: Index of the prod shard whose ``shard_{id}_lora.pt`` file
            is loaded and merged.

    Returns:
        The merged model in eval mode.
    """
    manifest_cfg = _read_prod_manifest_config()
    cfg = _manifest_lora_config(manifest_cfg, manifest_cfg["lora_r"])
    model = build_base_model(cfg)
    load_and_merge_lora_weights(
        model, cfg, PROD_WEIGHTS_DIR / f"shard_{shard_id}_lora.pt"
    )
    model.eval()
    return model


def _load_checkpoint_model(checkpoint_path: str, shard_id: int, lora_r: int):
    """Load prod security LoRA for `shard_id`, then stack a robustness checkpoint.

    The prod security LoRA is merged first; a fresh LoRA of rank ``lora_r`` is
    then injected, populated from the checkpoint into head 0, and merged too.

    Args:
        checkpoint_path: Path to the robustness checkpoint. Either a bare
            state dict or a dict holding it under ``"head_state_dict"``.
        shard_id: Index of the prod security shard to merge first.
        lora_r: LoRA rank the robustness checkpoint was trained with.

    Returns:
        The fully merged model, cast to bfloat16 and in eval mode.
    """
    manifest_cfg = _read_prod_manifest_config()

    prod_cfg = _manifest_lora_config(manifest_cfg, manifest_cfg["lora_r"])
    model = build_base_model(prod_cfg)
    prod_path = PROD_WEIGHTS_DIR / f"shard_{shard_id}_lora.pt"
    load_and_merge_lora_weights(model, prod_cfg, prod_path)

    rob_cfg = _manifest_lora_config(manifest_cfg, lora_r)
    inject_lora(model, rob_cfg)

    # NOTE(security): weights_only=False unpickles arbitrary objects, so only
    # point this at checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cuda", weights_only=False)
    state = ckpt["head_state_dict"] if "head_state_dict" in ckpt else ckpt
    model.heads[0].load_state_dict(state)
    model.heads[0].merge_and_unload()  # type: ignore[attr-defined]
    model.to(dtype=torch.bfloat16)
    model.eval()
    return model


def _encode_batch(
    tokenizer, formatted_list: list[str], max_seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Chat-template + tokenise a batch, pad only to longest-in-batch.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        formatted_list: User-turn contents, one per batch row.
        max_seq_len: Truncation length passed to the tokenizer.

    Returns:
        ``(input_ids, attention_mask)`` as returned by the tokenizer with
        ``return_tensors="pt"``.
    """
    chats = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": f}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for f in formatted_list
    ]
    enc = tokenizer(
        chats,
        padding=True,
        truncation=True,
        max_length=max_seq_len,
        return_tensors="pt",
    )
    return enc["input_ids"], enc["attention_mask"]


@torch.no_grad()
def _predict_letter_batch(
    model,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    mcq_token_ids: list[int],
    device: str,
) -> list[str]:
    """MCQ-restricted argmax at each row's real last token (head 0 on Hydra).

    Args:
        model: Callable as ``model(input_ids, return_logits=True)``; a 4-dim
            result is indexed at ``[0]`` to select head 0.
        input_ids: Token ids, one row per prompt.
        attention_mask: 1 for real tokens, 0 for padding, same shape as
            ``input_ids``.
        mcq_token_ids: Vocabulary ids of the answer letters, ordered like
            ``MCQ_LETTERS``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        One predicted letter per row.
    """
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    logits = model(input_ids, return_logits=True)
    head0 = logits[0] if logits.dim() == 4 else logits

    # Locate the last real token from the mask itself rather than assuming
    # right padding (``mask.sum() - 1`` is wrong for left-padded rows). The
    # first 1 in the reversed mask is the last 1 in the original.
    # NOTE(review): the mask is not passed to the model, so right padding is
    # still what the forward pass appears to expect -- confirm.
    seq_len = attention_mask.size(-1)
    trailing_pad = attention_mask.long().flip(dims=[-1]).argmax(dim=-1)
    last_idx = seq_len - 1 - trailing_pad

    b_idx = torch.arange(input_ids.size(0), device=device)
    last_logits = head0[b_idx, last_idx, :]
    preds = last_logits[:, mcq_token_ids].argmax(dim=-1).tolist()
    return [MCQ_LETTERS[p] for p in preds]
def _format_clean_prompt(row) -> str:
    """Render one validation row as the clean (un-attacked) MCQ prompt."""
    options = [str(row[key]) for key in ("opa", "opb", "opc", "opd")]
    return format_example(str(row["question"]), options)


def evaluate(model, bank: dict, val_rows: dict, tokenizer, args) -> dict:
    """Replay every attack in the bank against ``model`` and score the flips.

    For each attack, the clean prompt and the prompt with the attack suffix
    appended are both classified; a "flip" is any change of predicted letter.
    The flip flags stored in the bank are reported alongside as the security
    baseline.

    Args:
        model: Model accepted by ``_predict_letter_batch``.
        bank: Attack bank; ``bank["attacks"]`` is a list of dicts with
            ``attack_id``, ``tier``, ``suffix`` and ``pairs`` (each pair has
            ``val_idx`` and ``flipped``).
        val_rows: Mapping from ``val_idx`` to a row with ``question``,
            ``opa``..``opd`` and ``cop`` (index of the correct option).
        tokenizer: Tokenizer used for both letter ids and prompt encoding.
        args: Namespace providing ``batch_size`` and ``max_seq_len``.

    Returns:
        ``{"per_attack": [...], "decisions": [...]}``: per-attack rates and
        the per-pair clean/poison predictions. Attacks with no pairs are
        skipped, since their rates are undefined.
    """
    mcq_token_ids = [
        tokenizer.encode(letter, add_special_tokens=False)[0] for letter in MCQ_LETTERS
    ]
    device = "cuda"
    batch_size = args.batch_size
    per_attack: list[dict] = []
    decisions: list[dict] = []

    # Format each clean prompt once; it is reused for the clean prediction and
    # as the prefix of every poisoned prompt.
    unique_val_idxs = sorted(
        {p["val_idx"] for a in bank["attacks"] for p in a["pairs"]}
    )
    clean_prompts: dict[int, str] = {
        v: _format_clean_prompt(val_rows[v]) for v in unique_val_idxs
    }

    # Compute clean prediction per unique val_idx once. Used as the clean side
    # of every attack.
    clean_cache: dict[int, str] = {}
    for batch_start in range(0, len(unique_val_idxs), batch_size):
        batch_vals = unique_val_idxs[batch_start : batch_start + batch_size]
        ids, mask = _encode_batch(
            tokenizer, [clean_prompts[v] for v in batch_vals], args.max_seq_len
        )
        preds = _predict_letter_batch(model, ids, mask, mcq_token_ids, device)
        clean_cache.update(zip(batch_vals, preds))

    for attack in bank["attacks"]:
        pairs = attack["pairs"]
        n = len(pairs)
        if n == 0:
            # Every rate below divides by n; an empty attack has no defined
            # rate, so skip it instead of raising ZeroDivisionError.
            print(
                f" {attack['attack_id']} (tier {attack['tier']}): no pairs, skipped"
            )
            continue

        suffix = attack["suffix"]
        clean_preds: list[str] = [clean_cache[p["val_idx"]] for p in pairs]
        gold: list[str] = [
            MCQ_LETTERS[int(val_rows[p["val_idx"]]["cop"])] for p in pairs
        ]

        poison_preds: list[str] = []
        for batch_start in range(0, n, batch_size):
            batch_pairs = pairs[batch_start : batch_start + batch_size]
            poison_formatted = [
                clean_prompts[p["val_idx"]] + suffix for p in batch_pairs
            ]
            poison_ids, poison_mask = _encode_batch(
                tokenizer, poison_formatted, args.max_seq_len
            )
            poison_preds.extend(
                _predict_letter_batch(
                    model, poison_ids, poison_mask, mcq_token_ids, device
                )
            )

        flips = [c != p for c, p in zip(clean_preds, poison_preds)]
        correct_clean = [g == c for g, c in zip(gold, clean_preds)]
        correct_poison = [g == p for g, p in zip(gold, poison_preds)]
        # Harmful flip: the clean answer was right and the attack broke it.
        harmful = [c and not p for c, p in zip(correct_clean, correct_poison)]

        # Per clean-predicted letter: (number flipped, number seen).
        per_letter: dict[str, tuple[int, int]] = {L: (0, 0) for L in MCQ_LETTERS}
        for clean_letter, flipped in zip(clean_preds, flips):
            done, total = per_letter[clean_letter]
            per_letter[clean_letter] = (done + int(flipped), total + 1)

        sec_flips = [p["flipped"] for p in pairs]

        for p, clean_letter, poison_letter, flipped in zip(
            pairs, clean_preds, poison_preds, flips
        ):
            decisions.append(
                {
                    "attack_id": attack["attack_id"],
                    "val_idx": int(p["val_idx"]),
                    "clean_pred": clean_letter,
                    "poison_pred": poison_letter,
                    "flipped": bool(flipped),
                }
            )

        per_attack.append(
            {
                "attack_id": attack["attack_id"],
                "tier": attack["tier"],
                "flip_rate": sum(flips) / n,
                "harmful_flip_rate": sum(harmful) / n,
                "clean_acc": sum(correct_clean) / n,
                "poison_acc": sum(correct_poison) / n,
                "per_letter": {
                    L: (d / t if t > 0 else 0.0) for L, (d, t) in per_letter.items()
                },
                "security_flip_rate": sum(sec_flips) / n,
            }
        )
        print(
            f" {attack['attack_id']} (tier {attack['tier']}): "
            f"flip={per_attack[-1]['flip_rate']:.2f} "
            f"sec_flip={per_attack[-1]['security_flip_rate']:.2f}"
        )

    return {"per_attack": per_attack, "decisions": decisions}
def _print_report(label: str, results: dict, bank: dict) -> None:
    """Print per-tier and overall aggregates of an evaluation run.

    Args:
        label: Human-readable name of the evaluated model.
        results: Output of ``evaluate``; only ``results["per_attack"]`` is read.
        bank: Attack bank. Accepted for interface compatibility; this function
            does not read it.
    """
    per_attack = results["per_attack"]
    metric_keys = (
        "flip_rate",
        "harmful_flip_rate",
        "clean_acc",
        "poison_acc",
        "security_flip_rate",
    )

    def _agg(items: list[dict]) -> dict:
        """Unweighted mean of every metric (and per-letter rate) over attacks."""
        if not items:
            return {}
        count = len(items)
        means = {}
        for key in metric_keys:
            means[key] = sum(entry[key] for entry in items) / count
        for letter in MCQ_LETTERS:
            means[f"letter_{letter}"] = (
                sum(entry["per_letter"][letter] for entry in items) / count
            )
        return means

    def _emit(title: str, items: list[dict]) -> None:
        """Print one report section, or a placeholder when the tier is empty."""
        if not items:
            print(f"\n===== {title}: no attacks in this tier =====")
            return
        stats = _agg(items)
        delta = stats["flip_rate"] - stats["security_flip_rate"]
        print(f"\n===== {title} (n={len(items)}) =====")
        print(" security baseline evaluated model delta")
        print(
            f"flip_rate: {stats['security_flip_rate']:.3f} "
            f"{stats['flip_rate']:.3f} {delta:+.3f}"
        )
        print(f"harmful_flip_rate: - {stats['harmful_flip_rate']:.3f}")
        print(f"clean_acc: - {stats['clean_acc']:.3f}")
        print(f"poison_acc: - {stats['poison_acc']:.3f}")
        print(
            "per-letter flip (evaluated): "
            f"A={stats['letter_A']:.2f} B={stats['letter_B']:.2f} "
            f"C={stats['letter_C']:.2f} D={stats['letter_D']:.2f}"
        )

    print(f"\nRobustness eval -- model: {label}")

    # T4 = weak attacks; robust model should NOT regress (positive delta = bad).
    tier_titles = (
        (1, "Tier 1 (>75% transfer)"),
        (2, "Tier 2 (50-75% transfer)"),
        (3, "Tier 3 (25-50% transfer)"),
        (4, "Tier 4 (<=25% transfer, regression check)"),
    )
    for tier, title in tier_titles:
        _emit(title, [entry for entry in per_attack if entry["tier"] == tier])
    _emit("Overall", per_attack)
def main() -> None:
    """Run the robustness evaluation end to end.

    Loads the attack bank, fetches the referenced MedMCQA validation rows,
    builds the model selected on the command line, evaluates it, prints the
    report and optionally dumps the per-pair decisions as JSON.
    """
    args = parse_args()

    bank_path = Path(args.bank_dir) / "bank.json"
    with open(bank_path) as f:
        bank = json.load(f)

    limit = args.max_attacks
    if limit is not None:
        bank["attacks"] = bank["attacks"][:limit]
        print(f"Truncated to first {limit} attacks")

    # Only the validation rows referenced by the (possibly truncated) bank
    # are materialised.
    referenced = set()
    for attack in bank["attacks"]:
        for pair in attack["pairs"]:
            referenced.add(pair["val_idx"])
    all_val_idxs = sorted(referenced)

    ds = load_dataset("openlifescienceai/medmcqa", split="validation")
    val_rows = {}
    for v in all_val_idxs:
        val_rows[v] = ds[v]

    tokenizer = AutoTokenizer.from_pretrained(WEIGHTS_DIR)
    assert tokenizer is not None

    if args.base:
        model, label = _load_base_model(), "base OLMo-7B"
    elif args.security:
        model = _load_security_model(args.shard_id)
        label = f"OLMo-7B + security shard {args.shard_id}"
    else:
        model = _load_checkpoint_model(args.checkpoint, args.shard_id, args.lora_r)
        label = f"checkpoint={args.checkpoint}"

    results = evaluate(model, bank, val_rows, tokenizer, args)
    _print_report(label, results, bank)

    if args.dump_decisions is None:
        return

    out_path = Path(args.dump_decisions)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    decisions_sorted = list(results["decisions"])
    decisions_sorted.sort(key=lambda d: (d["attack_id"], d["val_idx"]))
    payload = {
        "bench": str(args.bank_dir),
        "model_label": label,
        "pairs": decisions_sorted,
    }

    # Write to a sibling temp file and rename it over the target, so a
    # partially written dump never sits at the final path.
    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(payload, f, indent=2)
    tmp_path.replace(out_path)
    print(f"Wrote {len(decisions_sorted)} decisions -> {out_path}")
# Script entry point: run the evaluation only when this module is executed
# directly (e.g. ``python -m olmo_tap.experiments.robustness.eval``), not on import.
if __name__ == "__main__": main()