Source code for olmo_tap.experiments.robustness.precompute_gcg

"""Precompute GCG adversarial suffixes for MedMCQA shards."""

import argparse
import json
import time

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from olmo_tap.constants import GCG_CACHE_DIR, WEIGHTS_DIR
from olmo_tap.experiments.robustness.amplegcg import AmpleGCG
from olmo_tap.experiments.robustness.data import format_example

NUM_SHARDS = 9
MAX_SEQ_LEN = 512


[docs] def main(): parser = argparse.ArgumentParser() parser.add_argument("--shard-id", type=int, required=True) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() # Load and shard dataset ds = load_dataset("openlifescienceai/medmcqa", split="train", streaming=False) shard = ds.shard(num_shards=NUM_SHARDS, index=args.shard_id) shard = shard.shuffle(seed=args.seed) n = len(shard) print(f"Shard {args.shard_id}: {n} examples") # 1-beam greedy: minimal memory, no OOM risk gcg = AmpleGCG(device="cuda", num_return_seq=1, num_beams=1, diversity_penalty=0.0) tok = AutoTokenizer.from_pretrained(WEIGHTS_DIR) # pyright: ignore[reportReturnType] assert tok is not None out_dir = GCG_CACHE_DIR / f"shard_{args.shard_id}" out_dir.mkdir(parents=True, exist_ok=True) # Resume from existing checkpoint clean_list: list[torch.Tensor] = [] poisoned_list: list[torch.Tensor] = [] clean_mask_list: list[torch.Tensor] = [] poisoned_mask_list: list[torch.Tensor] = [] start_idx = 0 required = ["clean.pt", "poisoned.pt", "clean_mask.pt", "poisoned_mask.pt"] if all((out_dir / f).exists() for f in required): clean_list = list(torch.load(out_dir / "clean.pt", weights_only=True)) poisoned_list = list(torch.load(out_dir / "poisoned.pt", weights_only=True)) clean_mask_list = list(torch.load(out_dir / "clean_mask.pt", weights_only=True)) poisoned_mask_list = list( torch.load(out_dir / "poisoned_mask.pt", weights_only=True) ) start_idx = len(clean_list) print(f"Resuming from example {start_idx}") t0 = time.time() for i in range(start_idx, n): ex = shard[i] opts = [str(ex["opa"]), str(ex["opb"]), str(ex["opc"]), str(ex["opd"])] formatted = format_example(str(ex["question"]), opts) # Generate adversarial suffix suffix = gcg(formatted)[0] # Tokenize clean prompt (question + options) clean_prompt = tok.apply_chat_template( [{"role": "user", "content": formatted}], tokenize=False, add_generation_prompt=True, ) clean_enc = tok( clean_prompt, padding="max_length", truncation=True, max_length=MAX_SEQ_LEN, return_tensors="pt", ) clean_ids = clean_enc["input_ids"].squeeze(0) clean_mask = clean_enc["attention_mask"].squeeze(0) # Tokenize poisoned prompt (question + options + suffix) poisoned_prompt = tok.apply_chat_template( [{"role": "user", "content": formatted + suffix}], tokenize=False, add_generation_prompt=True, ) poisoned_enc = tok( poisoned_prompt, padding="max_length", truncation=True, max_length=MAX_SEQ_LEN, return_tensors="pt", ) poisoned_ids = poisoned_enc["input_ids"].squeeze(0) poisoned_mask = poisoned_enc["attention_mask"].squeeze(0) clean_list.append(clean_ids) poisoned_list.append(poisoned_ids) # Masks let training gather logits at the real last token under right-padding. clean_mask_list.append(clean_mask) poisoned_mask_list.append(poisoned_mask) if (i + 1) % 100 == 0 or i == n - 1: torch.save(torch.stack(clean_list), out_dir / "clean.pt") torch.save(torch.stack(poisoned_list), out_dir / "poisoned.pt") torch.save(torch.stack(clean_mask_list), out_dir / "clean_mask.pt") torch.save(torch.stack(poisoned_mask_list), out_dir / "poisoned_mask.pt") if (i + 1) % 100 == 0 or i == start_idx: elapsed = time.time() - t0 done = i - start_idx + 1 per_ex = elapsed / done eta = per_ex * (n - i - 1) / 3600 print(f"[{i + 1}/{n}] {per_ex:.1f}s/ex, ETA {eta:.1f}h") with open(out_dir / "metadata.json", "w") as f: json.dump( { "shard_id": args.shard_id, "n": len(clean_list), "shard_size": n, "seed": args.seed, "max_seq_len": MAX_SEQ_LEN, }, f, indent=2, ) print(f"Done! {len(clean_list)} examples -> {out_dir}")
if __name__ == "__main__": main()