Source code for olmo_tap.experiments.uncertainty.weights_handler
"""
Helper class to handle cycling through frozen LLM heads during uncertainty finetuning.
"""
from pathlib import Path
import copy
from olmo_tap.experiments.utils.model_builder import load_and_merge_lora_weights
from olmo_tap.hydra import HydraTransformer
from olmo_tap.experiments.utils.config import HydraLoRAConfig
class FrozenHeadHandler:
"""
During uncertainty head finetuning we cycle through randomly sampled frozen LLM heads
(frozen meaning no grad or trainable LoRA weights). This class manages the loading and
unloading of different heads.
:param model: Hydra transformer model to be trained.
:param prod_config: config for Hydra production (security) LoRA weights.
:param robust_config: config for Hydra robustness LoRA weights.
:param prod_dir: directory storing production (security) LoRA weights.
:param robust_dir: directory storing robustness LoRA weights.
:param n_frozen: number of frozen LLM heads in total.
NOTE: only one frozen LLM head is ever loaded at a given time, n_frozen refers to the total
number of heads available to cycle through.
"""
    def __init__(
        self,
        model: HydraTransformer,
        prod_config: HydraLoRAConfig,
        robust_config: HydraLoRAConfig,
        prod_dir: Path,
        robust_dir: Path,
        n_frozen: int,
    ):
        self.model = model
        self.prod_config = prod_config
        self.robust_config = robust_config
        self.prod_dir = prod_dir
        self.robust_dir = robust_dir
        self.n_frozen = n_frozen
        # save a clean copy of the baseline head weights; restore it before every swap
        # so we never merge LoRAs on top of already-merged LoRAs
        self.clean_head_state = copy.deepcopy(model.heads[1].state_dict())
    def swap_to_expert(self, frozen_idx: int):
        """Restores the base head and merges the new frozen head's LoRA weights."""
        # restore head 1 (always the frozen head) to baseline weights
        self.model.heads[1].load_state_dict(self.clean_head_state)
        # merge the new frozen head's weights
        prod_path = self.prod_dir / f"shard_{frozen_idx}_lora.pt"
        rob_path = self.robust_dir / f"shard_{frozen_idx}_lora.pt"
        # merge production (security) LoRA weights
        load_and_merge_lora_weights(self.model, self.prod_config, prod_path, head_idx=1)
        # merge robustness LoRA weights
        load_and_merge_lora_weights(
            self.model, self.robust_config, rob_path, head_idx=1
        )
        # keep the merged frozen head untrainable
        self.model.heads[1].requires_grad_(False)
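
A minimal usage sketch, illustrative only: the model, LoRA configs, checkpoint directories, ``train_loader`` and ``training_step`` below are hypothetical placeholders and not part of this module. It shows the intended cycling pattern, where a different frozen head is sampled and merged into head 1 before each training step.

import random
from pathlib import Path

# hypothetical setup: model, configs and directories are assumed to exist already
handler = FrozenHeadHandler(
    model=model,                           # HydraTransformer being trained
    prod_config=prod_lora_config,          # HydraLoRAConfig for production (security) LoRA
    robust_config=robust_lora_config,      # HydraLoRAConfig for robustness LoRA
    prod_dir=Path("checkpoints/prod"),     # assumed location of shard_{i}_lora.pt files
    robust_dir=Path("checkpoints/robust"),
    n_frozen=8,
)

for batch in train_loader:
    # sample one of the n_frozen heads and merge its LoRA weights into head 1
    handler.swap_to_expert(random.randrange(handler.n_frozen))
    loss = training_step(model, batch)     # assumed forward/backward step for the uncertainty head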