Source code for olmo_tap.experiments.security.training

"""
HydraTransformer Security Finetuning Pipeline.

Loads base OLMo weights, then injects fresh LoRA adapters for security training on MedMCQA.

Intended Usage (run from tap root)::
    # quick test on shard 0
    pixi run python -m experiments.security.training --shard-id 0 --num-epochs 3

    # train all 9 shards
    bash experiments/security/run_all.sh 3
"""

import argparse

import torch
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

from olmo_tap.constants import LORA_ALPHA_RATIO, LORA_TARGETS, MEDMCQA_SIZE
from olmo_tap.experiments.utils.config import (
    ExperimentConfig,
    HydraLoRAConfig,
    TrainingConfig,
)
from olmo_tap.experiments.utils.model_builder import build_base_model, inject_lora
from olmo_tap.experiments.utils.random_seed import set_seed
from olmo_tap.experiments.security.engine import train


def compute_total_steps(
    num_shards: int,
    batch_size: int,
    num_epochs: int,
) -> int:
    """Compute total training steps from dataset geometry (no data loading needed)."""
    shard_size = MEDMCQA_SIZE // num_shards
    steps_per_epoch = shard_size // batch_size  # drop_last=True in DataLoader
    return steps_per_epoch * num_epochs

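# Illustrative sanity check (not part of the pipeline): the step count follows the same
# integer-division geometry as a DataLoader with drop_last=True. The dataset size below
# is a made-up placeholder, not the real MEDMCQA_SIZE constant.
#
#   >>> dataset_size = 90_000           # hypothetical; stands in for MEDMCQA_SIZE
#   >>> shard_size = dataset_size // 9  # 10_000 examples per shard
#   >>> shard_size // 16                # steps per epoch at batch_size=16
#   625
#   >>> 625 * 3                         # total steps for 3 epochs
#   1875
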
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Train a security head on a MedMCQA shard"
    )
    parser.add_argument("--shard-id", type=int, default=0)
    parser.add_argument("--num-epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--full-data", action="store_true")
    return parser.parse_args()

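# Example invocation (assumed command line; the flags shown are exactly those defined above):
#
#   python -m experiments.security.training --shard-id 2 --lora-r 32
#
# parses to shard_id=2 and lora_r=32, with the remaining arguments at their defaults
# (num_epochs=3, batch_size=16, lr=1e-4, seed=42, full_data=False).
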
def main():
    args = parse_args()
    set_seed(args.seed)

    # HACK: --full-data sets n_heads_final=1 to bypass the num_shards=n_heads_final
    # constraint. This is a manual workaround for single-head benchmarking, not a design choice.
    # Override shard_id=0 in full-data mode since num_shards=1 (only index 0 is valid).
    n_heads = 1 if args.full_data else 9
    if args.full_data:
        args.shard_id = 0

    m_config = HydraLoRAConfig(
        n_heads_final=n_heads,
        n_heads_training=1,
        heads_depth=3,
        target_modules=LORA_TARGETS,
        lora_r=args.lora_r,
        lora_alpha=args.lora_r * LORA_ALPHA_RATIO,
    )
    t_config = TrainingConfig(
        learning_rate=args.lr,
        batch_size=args.batch_size,
        shard_id=args.shard_id,
        num_epochs=args.num_epochs,
        output_dir="experiments/security/outputs/full_data"
        if args.full_data
        else f"experiments/security/outputs/shard_{args.shard_id}",
    )
    exp_config = ExperimentConfig(
        seed=args.seed,
        model=m_config,
        train=t_config,
        wandb_project="hydra-security",
        wandb_run_name="full-data" if args.full_data else f"shard-{args.shard_id}",
    )

    model = build_base_model(exp_config.model)
    # inject LoRA matrices for security finetuning on specified LoRA targets
    inject_lora(model, exp_config.model)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
    )

    total_steps = compute_total_steps(
        num_shards=n_heads,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
    )
    warmup = LinearLR(optimizer, start_factor=1e-8, total_iters=t_config.warmup_steps)
    if t_config.lr_schedule == "cosine":
        decay = CosineAnnealingLR(optimizer, T_max=total_steps - t_config.warmup_steps)
    else:
        decay = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.0,
            total_iters=total_steps - t_config.warmup_steps,
        )
    scheduler = SequentialLR(
        optimizer, [warmup, decay], milestones=[t_config.warmup_steps]
    )

    wb_config = {
        **{f"model/{k}": v for k, v in m_config.__dict__.items()},
        **{f"train/{k}": v for k, v in t_config.__dict__.items()},
        "total_steps": total_steps,
        "seed": args.seed,
    }
    wandb.init(
        project="hydra-security",
        name="full-data" if args.full_data else f"shard-{args.shard_id}",
        tags=[f"epochs-{args.num_epochs}"]
        + (["full-data"] if args.full_data else []),
        config=wb_config,
    )

    train(model, exp_config, optimizer, scheduler)
    wandb.finish()

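# Minimal sketch of the warmup -> decay composition used above (illustrative only; the
# dummy parameter, step counts, and base LR are assumptions, not values taken from this
# pipeline's configs):
#
#   import torch
#   from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
#
#   opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
#   warmup = LinearLR(opt, start_factor=1e-8, total_iters=10)
#   decay = CosineAnnealingLR(opt, T_max=90)
#   sched = SequentialLR(opt, [warmup, decay], milestones=[10])
#   for step in range(100):
#       opt.step()
#       sched.step()
#
# The LR ramps from near zero up to 1e-4 over the first 10 steps, then follows a cosine
# curve back down toward zero over the remaining 90 steps.
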
if __name__ == "__main__":
    main()