
"""
HydraTransformer Robustness Finetuning Pipeline

Loads prod security weights (base OLMo + LoRA), merges the LoRA weights into the head,
then injects fresh LoRA adapters for robustness training on a precomputed GCG cache.

Intended Usage (run from `olmo_tap` root)::
    # quick test on shard 0
    pixi run -e cuda python -m olmo_tap.experiments.robustness.training --shard-id 0

    # train on all 9 shards
    bash olmo_tap/experiments/robustness/run_all.sh
"""

import argparse
import json

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import wandb

from olmo_tap.constants import (
    GCG_CACHE_DIR,
    LORA_ALPHA_RATIO,
    LORA_TARGETS,
    PROD_WEIGHTS_DIR,
)
from olmo_tap.experiments.robustness.engine import train
from olmo_tap.experiments.utils.config import (
    ExperimentConfig,
    HydraLoRAConfig,
    TrainingConfig,
)
from olmo_tap.experiments.utils.model_builder import (
    build_base_model,
    inject_lora,
    load_and_merge_lora_weights,
)
from olmo_tap.experiments.utils.random_seed import set_seed


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Train a robustness head on a MedMCQA shard"
    )
    parser.add_argument("--shard-id", type=int, default=0)
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lora-r", type=int, default=16)
    return parser.parse_args()


def main():
    args = parse_args()
    set_seed(args.seed)

    with open(PROD_WEIGHTS_DIR / "manifest.json") as f:
        manifest = json.load(f)
    prod_lora_r = manifest["config"]["lora_r"]
    heads_depth = manifest["config"]["heads_depth"]
    n_heads = manifest["config"]["num_shards"]

    # NOTE: robustness finetuning targets the same LoRA modules that were finetuned
    # in the security run (changing the targets is unlikely to help)
    prod_config = HydraLoRAConfig(
        n_heads_final=n_heads,
        n_heads_training=1,
        heads_depth=heads_depth,
        target_modules=LORA_TARGETS,
        lora_r=prod_lora_r,
        lora_alpha=prod_lora_r * LORA_ALPHA_RATIO,
    )
    model = build_base_model(prod_config)

    prod_path = PROD_WEIGHTS_DIR / f"shard_{args.shard_id}_lora.pt"
    # load the security-finetuning LoRA weights
    load_and_merge_lora_weights(model, prod_config, prod_path)

    # new robustness training config: same LoRA targets, but a different rank is allowed
    m_config = HydraLoRAConfig(
        n_heads_final=n_heads,
        n_heads_training=1,
        heads_depth=heads_depth,
        target_modules=LORA_TARGETS,
        lora_r=args.lora_r,
        lora_alpha=args.lora_r * LORA_ALPHA_RATIO,
    )
    t_config = TrainingConfig(
        learning_rate=args.lr,
        batch_size=args.batch_size,
        shard_id=args.shard_id,
        num_epochs=args.num_epochs,
        output_dir=f"experiments/robustness/outputs/shard_{args.shard_id}",
        checkpoint_every_n_steps=50,  # frequent checkpointing
    )
    exp_config = ExperimentConfig(
        model=m_config,
        train=t_config,
        wandb_project="hydra-robustness",
        seed=args.seed,
    )

    # inject fresh LoRA matrices for robustness finetuning on the same targets
    inject_lora(model, exp_config.model)

    with open(GCG_CACHE_DIR / f"shard_{args.shard_id}" / "metadata.json") as f:
        cache_meta = json.load(f)
    steps_per_epoch = cache_meta["n"] // args.batch_size
    total_steps = steps_per_epoch * args.num_epochs

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
    )
    warmup = LinearLR(optimizer, start_factor=1e-8, total_iters=t_config.warmup_steps)
    if t_config.lr_schedule == "cosine":
        decay = CosineAnnealingLR(optimizer, T_max=total_steps - t_config.warmup_steps)
    else:
        decay = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.0,
            total_iters=total_steps - t_config.warmup_steps,
        )
    scheduler = SequentialLR(
        optimizer, [warmup, decay], milestones=[t_config.warmup_steps]
    )

    wb_config = {
        **{f"model/{k}": v for k, v in m_config.__dict__.items()},
        **{f"train/{k}": v for k, v in t_config.__dict__.items()},
        "total_steps": total_steps,
        "prod_lora_r": prod_lora_r,
        "wandb_project": exp_config.wandb_project,
    }
    wandb.init(
        project=exp_config.wandb_project,
        name=f"robustness-shard-{args.shard_id}",
        config=wb_config,
    )

    train(model, exp_config, optimizer, scheduler)
    wandb.finish()


if __name__ == "__main__":
    main()