Source code for olmo_tap.experiments.robustness.engine

"""
Robustness finetuning protocol.
See https://www.overleaf.com/read/kpnzybhdvwnh#a3aa13 for theory details.
"""

import hashlib
from datetime import datetime
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import wandb

from olmo_tap.experiments.robustness.data import load_cached_shard
from olmo_tap.experiments.utils.config import ExperimentConfig
from olmo_tap.hydra import HydraTransformer



def train(
    model: HydraTransformer,
    exp_config: ExperimentConfig,
    optimizer: Optimizer,
    scheduler: LRScheduler,
    stagnant_thresh: int = 100,
):
    """
    Performs supervised robustness finetuning on a HydraTransformer model.
    It is assumed that only one head (at the 0th index, by convention) is
    loaded and being trained.

    :param model: HydraTransformer LLM model being finetuned.
    :param exp_config: Global config object storing experiment details.
    :param optimizer: Any torch optim object.
    :param scheduler: Any torch scheduler object.
    :param stagnant_thresh: If no successful adversarial attacks occur for
        this many steps, training stops early.
    """
    t_config = exp_config.train
    device = exp_config.device
    model.train()
    dataloader = load_cached_shard(exp_config.train)
    batch_size = t_config.batch_size

    # each run gets its own timestamped folder to avoid overwriting
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    ckpt_dir = Path(t_config.output_dir) / run_id / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    criterion = torch.nn.KLDivLoss(reduction="sum")

    global_step = 0
    accumulated_examples = 0
    running_loss = 0.0

    # CPU-side log of every attack strong enough to flip the argmax, for offline analysis
    adv_clean_ids: list[torch.Tensor] = []
    adv_poisoned_ids: list[torch.Tensor] = []
    adv_extracted_tokens: list[torch.Tensor] = []
    adv_tokens_path = ckpt_dir / "adv_tokens"
    adv_tokens_path.mkdir(parents=True, exist_ok=True)
    last_adv_path: Path | None = None
    logged_fingerprints = set()  # track unique prompts across epochs

    optimizer.zero_grad()
    stagnant_steps = 0  # number of consecutive steps with no successful attacks
    early_stop = False
    for epoch in range(t_config.num_epochs):
        for batch in dataloader:
            if stagnant_steps >= stagnant_thresh:
                print(
                    f"Early stopping, reached stagnation threshold of {stagnant_thresh}"
                )
                early_stop = True
                break  # exit batch loop

            clean_qs, poisoned_qs = (
                batch["input_ids_clean"].to(device),
                batch["input_ids_poisoned"].to(device),
            )
            clean_mask = batch["attention_mask_clean"].to(device)
            poisoned_mask = batch["attention_mask_poisoned"].to(device)

            # Real last-token indices per row; `[:, -1, :]` would read a pad
            # position under right-padding.
            b_idx = torch.arange(clean_qs.size(0), device=device)
            clean_last = clean_mask.sum(dim=-1) - 1
            poisoned_last = poisoned_mask.sum(dim=-1) - 1

            # clean pass - target distribution (no grad)
            with torch.no_grad():
                clean_all = model(clean_qs, return_logits=True)
                clean_logits = clean_all[0, b_idx, clean_last, :]
                clean_probs = F.softmax(clean_logits, dim=-1)
                clean_argmax_logits = torch.argmax(clean_logits, dim=-1)

            # poisoned pass
            poisoned_all = model(poisoned_qs, return_logits=True)
            poisoned_logits = poisoned_all[0, b_idx, poisoned_last, :]
            log_poisoned_probs = F.log_softmax(poisoned_logits, dim=-1)
            poison_argmax_logits = torch.argmax(poisoned_logits, dim=-1)

            # NOTE: we define a successful GCG attack as any attack which
            # causes the argmax token to change. This avoids filling the batch
            # with examples that carry weak training signal due to small KL;
            # the changed argmax token is a heuristic marker of success.
            successes = clean_argmax_logits != poison_argmax_logits
            success_count = successes.sum().item()

            # effective parameter update once a full batch has accumulated:
            #     theta <- theta - (lr / B) * sum_{i in accum} grad(L_i)
            # where B = batch_size
            if success_count > 0:
                stagnant_steps = 0
                loss = criterion(log_poisoned_probs[successes], clean_probs[successes])
                scaled_loss = loss / batch_size
                scaled_loss.backward()
                running_loss += loss.item()
                accumulated_examples += success_count

                # identify which successful examples have not been logged yet
                for i, is_success in enumerate(successes):
                    if is_success:
                        c_row = clean_qs[i].cpu()
                        # unique hash of the clean prompt to avoid multi-epoch duplicates
                        fingerprint = hashlib.sha256(
                            c_row.numpy().tobytes()
                        ).hexdigest()
                        if fingerprint not in logged_fingerprints:
                            logged_fingerprints.add(fingerprint)
                            p_row = poisoned_qs[i].cpu()
                            adv_clean_ids.append(c_row.unsqueeze(0))
                            adv_poisoned_ids.append(p_row.unsqueeze(0))
                            # extract just the adversarial extension for each clean prompt
                            diff_indices = torch.where(c_row != p_row)[0]
                            if len(diff_indices) > 0:
                                first_diff = diff_indices[0]
                                last_diff = diff_indices[-1]
                                extracted = p_row[first_diff : last_diff + 1]
                                adv_extracted_tokens.append(extracted)
            else:
                stagnant_steps += 1

            if accumulated_examples >= batch_size:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                # NOTE: the loss we log to wandb is the true mean over
                # accumulated examples; the loss we propagate is normalised by
                # the fixed batch_size B
                wandb.log(
                    {
                        "train/loss": running_loss / accumulated_examples,
                        "train/lr": scheduler.get_last_lr()[0],
                        "train/epoch": epoch,
                    },
                    step=global_step,
                )
                accumulated_examples = 0
                running_loss = 0.0

                if global_step % t_config.checkpoint_every_n_steps == 0:
                    path = ckpt_dir / f"checkpoint_step_{global_step}.pt"
                    torch.save(
                        {
                            "head_state_dict": model.heads[0].state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                        },
                        path,
                    )
                    path_adv = adv_tokens_path / f"checkpoint_step_{global_step}.pt"
                    if adv_poisoned_ids:
                        # scrap the previous cumulative file before saving the new one
                        if last_adv_path and last_adv_path.exists():
                            last_adv_path.unlink()
                        torch.save(
                            {
                                "clean_ids": torch.cat(adv_clean_ids, dim=0),
                                "poisoned_ids": torch.cat(adv_poisoned_ids, dim=0),
                                "extracted_tokens": adv_extracted_tokens,
                            },
                            path_adv,
                        )
                        last_adv_path = path_adv

        if early_stop:
            break  # exit epoch loop

    # final checkpoint with optimizer state for potential resuming
    final_path = ckpt_dir / "checkpoint_final.pt"
    torch.save(
        {
            "head_state_dict": model.heads[0].state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step,
        },
        final_path,
    )
    final_path_adv = adv_tokens_path / "checkpoint_final.pt"
    if adv_poisoned_ids:
        # scrap intermediate cumulative file in favor of the final file
        if last_adv_path and last_adv_path.exists():
            last_adv_path.unlink()
        torch.save(
            {
                "clean_ids": torch.cat(adv_clean_ids, dim=0),
                "poisoned_ids": torch.cat(adv_poisoned_ids, dim=0),
                "extracted_tokens": adv_extracted_tokens,
            },
            final_path_adv,
        )
    print(
        f"Saved final checkpoint weights to {final_path} and final "
        f"adversarial tokens to {final_path_adv}."
    )
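

# --- Hypothetical usage sketch (not part of the module) ----------------------
# A minimal driver for train(), under loudly stated assumptions: the config
# loader ExperimentConfig.from_yaml, the field exp_config.model_path, and the
# constructor HydraTransformer.from_pretrained are NOT defined in this module
# and are illustrative only. The optimizer and scheduler are standard torch
# objects; only head 0 is optimised, matching the single-head convention that
# train() assumes.

import torch

from olmo_tap.experiments.robustness.engine import train
from olmo_tap.experiments.utils.config import ExperimentConfig
from olmo_tap.hydra import HydraTransformer

exp_config = ExperimentConfig.from_yaml("configs/robustness.yaml")  # hypothetical loader
model = HydraTransformer.from_pretrained(exp_config.model_path)  # hypothetical ctor
model.to(exp_config.device)

# optimise only head 0, per the convention documented in train()
optimizer = torch.optim.AdamW(model.heads[0].parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1_000)

train(model, exp_config, optimizer, scheduler, stagnant_thresh=100)

# The adversarial-token log written by train() can then be inspected offline;
# a sketch, with <output_dir> and <run_id> standing in for the actual
# timestamped run folder:

blob = torch.load("<output_dir>/<run_id>/checkpoints/adv_tokens/checkpoint_final.pt")
clean_ids = blob["clean_ids"]  # (N, seq_len) clean prompts
poisoned_ids = blob["poisoned_ids"]  # (N, seq_len) successful attacks
extracted = blob["extracted_tokens"]  # list of 1-D adversarial-extension tensors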