Source code for olmo_tap.experiments.robustness.data
"""
Data loading for robustness head supervised finetuning on MedMCQA.
"""
import torch
from torch.utils.data import DataLoader, Dataset
from olmo_tap.constants import GCG_CACHE_DIR
from olmo_tap.experiments.utils.config import TrainingConfig
[docs]
class CachedShardDataset(Dataset):
    """Map-style dataset over precomputed clean/poisoned token IDs and masks.

    The four tensors are indexed in lockstep along their first dimension:
    item ``i`` is built from row ``i`` of each tensor.

    Args:
        clean: Clean token IDs.
        poisoned: Poisoned token IDs.
        clean_mask: Attention mask matching ``clean``.
        poisoned_mask: Attention mask matching ``poisoned``.

    Raises:
        ValueError: If the four tensors do not all have the same length.
    """

    def __init__(
        self,
        clean: torch.Tensor,
        poisoned: torch.Tensor,
        clean_mask: torch.Tensor,
        poisoned_mask: torch.Tensor,
    ):
        # Fail fast on inconsistent inputs: __len__ only reports len(clean),
        # so a shorter sibling tensor would otherwise raise IndexError in the
        # middle of an epoch and a longer one would be silently truncated.
        lengths = {
            "clean": len(clean),
            "poisoned": len(poisoned),
            "clean_mask": len(clean_mask),
            "poisoned_mask": len(poisoned_mask),
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(
                f"CachedShardDataset tensors must all have the same length, got {lengths}"
            )

        self.clean = clean
        self.poisoned = poisoned
        self.clean_mask = clean_mask
        self.poisoned_mask = poisoned_mask

    def __len__(self) -> int:
        """Return the number of examples (length of the first dimension)."""
        return len(self.clean)

    def __getitem__(self, index):
        """Return the clean/poisoned IDs and attention masks at ``index``."""
        return {
            "input_ids_clean": self.clean[index],
            "input_ids_poisoned": self.poisoned[index],
            "attention_mask_clean": self.clean_mask[index],
            "attention_mask_poisoned": self.poisoned_mask[index],
        }
[docs]
def load_cached_shard(config: TrainingConfig) -> DataLoader:
    """Build a shuffled DataLoader over one shard of the GCG cache.

    Reads ``clean.pt``, ``poisoned.pt``, ``clean_mask.pt`` and
    ``poisoned_mask.pt`` from ``GCG_CACHE_DIR / "shard_<shard_id>"`` and wraps
    them in a ``CachedShardDataset``.

    Args:
        config: Training configuration; ``shard_id`` selects the cache
            directory and ``batch_size`` sets the loader batch size.

    Returns:
        A ``DataLoader`` with shuffling enabled, incomplete final batches
        dropped, and loading done in the main process (``num_workers=0``).

    Raises:
        FileNotFoundError: If any of the four cache files is absent.
    """
    shard_dir = GCG_CACHE_DIR / f"shard_{config.shard_id}"

    # Order matters: it mirrors the positional parameters of CachedShardDataset.
    # The mask files are mandatory because training reads logits at the real
    # last token, which is derived from attention_mask; caches produced
    # without masks have to be regenerated.
    filenames = ("clean.pt", "poisoned.pt", "clean_mask.pt", "poisoned_mask.pt")

    absent = [name for name in filenames if not (shard_dir / name).exists()]
    if absent:
        raise FileNotFoundError(
            f"GCG cache missing {absent} for shard {config.shard_id} at {shard_dir}. "
            f"Run: python -m olmo_tap.experiments.robustness.precompute_gcg --shard-id {config.shard_id}"
        )

    # weights_only=True restricts unpickling to tensor data.
    tensors = [torch.load(shard_dir / name, weights_only=True) for name in filenames]

    return DataLoader(
        CachedShardDataset(*tensors),
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
    )