Source code for olmo_tap.experiments.utils.weights
"""Helpers for exporting slim LoRA-only state dicts from fat training checkpoints."""
from pathlib import Path
from typing import Mapping
import torch
[docs]
def save_lora_export(ckpt_path: Path | str, out_path: Path | str) -> int:
    """Export a slim LoRA-only state dict from a fat robustness/security checkpoint.

    The checkpoint at ``ckpt_path`` is loaded onto the CPU. If the loaded
    object contains a ``"head_state_dict"`` entry, that entry is used as the
    source state dict; otherwise the loaded object itself is used. The source
    is passed through ``extract_lora_state`` and the result is written to
    ``out_path`` with ``torch.save``; missing parent directories are created.

    Args:
        ckpt_path: Location of the checkpoint to read.
        out_path: Location the slimmed state dict is written to.

    Returns:
        Size in bytes of the file written at ``out_path``.
    """
    # NOTE: weights_only=False performs full unpickling of the checkpoint,
    # so this should only be pointed at checkpoint files that are trusted.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Some checkpoints wrap the head weights under "head_state_dict";
    # others are already the bare state dict.
    if "head_state_dict" in checkpoint:
        source_state = checkpoint["head_state_dict"]
    else:
        source_state = checkpoint

    lora_only = extract_lora_state(source_state)

    destination = Path(out_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    torch.save(lora_only, destination)

    return destination.stat().st_size