# Source code for olmo_tap.experiments.security.engine
"""
Security Finetuning protocol.
See https://www.overleaf.com/read/kpnzybhdvwnh#a3aa13 for theory details.
"""
from datetime import datetime
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import wandb
from olmo_tap.experiments.utils.config import ExperimentConfig
from olmo_tap.experiments.security.data import load_shard
from olmo_tap.hydra import HydraTransformer
# [docs]
def train(
    model: HydraTransformer,
    exp_config: ExperimentConfig,
    optimizer: Optimizer,
    scheduler: LRScheduler,
) -> None:
    """
    Performs supervised security finetuning on a HydraTransformer model. Assumed that
    only 1 head (at 0th index by convention) is loaded and being trained.

    Each step reads the logits at the last non-padded position of every row,
    applies a mean cross-entropy loss against ``batch["label"]``, and steps the
    optimizer and then the scheduler. Loss, learning rate and epoch are logged
    to wandb on every step. Head-0 weights are saved every
    ``checkpoint_every_n_steps`` steps, and a final checkpoint that also holds
    optimizer and scheduler state is written once training finishes.

    Side effects: the answer-token ids returned by ``load_shard`` are stored on
    ``exp_config.train``, and a timestamped run directory is created under
    ``exp_config.train.output_dir``.

    NOTE(review): ``wandb.log`` is called without a visible ``wandb.init`` in this
    function -- presumably the caller has already started a run; confirm.

    :param model: HydraTransformer LLM model being finetuned.
    :param exp_config: Global config object storing experiment details.
    :param optimizer: Any torch optim object.
    :param scheduler: Any torch scheduler object.
    """
    t_config = exp_config.train
    device = exp_config.device
    model.train()

    dataloader, A_id, B_id, C_id, D_id = load_shard(t_config)
    # Record the answer-token ids on the shared train config.
    t_config.A_token_id = A_id
    t_config.B_token_id = B_id
    t_config.C_token_id = C_id
    t_config.D_token_id = D_id

    # One directory per invocation, keyed by wall-clock start time.
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    ckpt_dir = Path(t_config.output_dir) / run_id / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    criterion = nn.CrossEntropyLoss(reduction="mean")

    # A falsy interval (0 / None) disables periodic checkpoints instead of
    # crashing on the modulo below.
    ckpt_every = t_config.checkpoint_every_n_steps

    global_step = 0
    for epoch in range(t_config.num_epochs):
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            # Gather logits at each row's real last-token position (= index of the
            # final 1 in the attention mask). Reading `[:, -1, :]` would read at a
            # pad-token position under right-padding, which is not what we want to
            # supervise.
            # NOTE(review): `sum - 1` is the last real token only if padding is on
            # the right and the mask is contiguous; confirm against the collator.
            all_logits = model(input_ids, return_logits=True)
            last_idx = attention_mask.sum(dim=-1) - 1  # (batch,)
            b_idx = torch.arange(input_ids.size(0), device=device)
            # Leading index 0 presumably selects the single trained head (see
            # docstring); the remaining indices pick one position per row.
            logits = all_logits[0, b_idx, last_idx, :]

            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            wandb.log(
                {
                    "train/loss": loss.item(),
                    "train/lr": scheduler.get_last_lr()[0],
                    "train/epoch": epoch,
                },
                step=global_step,
            )

            if ckpt_every and global_step % ckpt_every == 0:
                path = ckpt_dir / f"checkpoint_step_{global_step}.pt"
                torch.save(model.heads[0].state_dict(), path)

    # final checkpoint with optimizer and scheduler state for potential resuming;
    # without the scheduler state a resumed run cannot restore its LR schedule.
    final_path = ckpt_dir / "checkpoint_final.pt"
    torch.save(
        {
            "head_state_dict": model.heads[0].state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "global_step": global_step,
        },
        final_path,
    )
    print(f"saved final checkpoint to {final_path}")