Source code for olmo_tap.inference.loading_weights

"""Helper file to load our 10 head Hydra model (9 LLM heads + 1 Uncertainty head)"""

import json
import torch
from olmo_tap.constants import (
    LORA_ALPHA_RATIO,
    LORA_TARGETS,
    PROD_WEIGHTS_DIR,
    ROBUST_WEIGHTS_DIR,
    UNCERTAINTY_WEIGHTS_DIR,
)
from olmo_tap.hydra import HydraTransformer
from olmo_tap.experiments.utils.config import HydraLoRAConfig
from olmo_tap.experiments.utils.model_builder import (
    build_base_model,
    load_and_merge_lora_weights,
)


def load_ensemble() -> tuple[HydraTransformer, int]:
    """Load the 10-head Hydra model (9 LLM heads + 1 uncertainty head).

    Merges LoRA weights for Security (prod) and Robustness into each of the
    9 LLM heads, and uncertainty LoRA weights into the 10th head, then moves
    the model to CUDA in bfloat16 and puts it in eval mode.

    Returns:
        A tuple ``(model, n_heads)`` where ``model`` is the fully merged
        :class:`HydraTransformer` and ``n_heads`` is the total head count (10).
    """
    # Retrieve prod (security) lora_r and other tags from the shipped manifest.
    with open(PROD_WEIGHTS_DIR / "manifest.json") as f:
        manifest = json.load(f)
    prod_lora_r = manifest["config"]["lora_r"]
    heads_depth = manifest["config"]["heads_depth"]

    # Retrieve robustness lora_r.
    rob_lora_r = 16  # TODO: currently hard-coding this, waiting for manifest.json

    # TODO (minor): refactor config.py and model_builder.py so we don't need to pass
    # n_heads_training at inference time
    n_heads = 10  # uncertainty head too
    base_config = HydraLoRAConfig(
        n_heads_final=n_heads,
        n_heads_training=n_heads,
        heads_depth=heads_depth,
    )
    model = build_base_model(base_config)

    # LoRA configs are identical across shards, so build them once outside the loop.
    prod_cfg = HydraLoRAConfig(
        target_modules=LORA_TARGETS,
        lora_r=prod_lora_r,
        lora_alpha=prod_lora_r * LORA_ALPHA_RATIO,
    )
    rob_cfg = HydraLoRAConfig(
        target_modules=LORA_TARGETS,
        lora_r=rob_lora_r,
        lora_alpha=rob_lora_r * LORA_ALPHA_RATIO,
    )

    # LLM heads: merge security, then robustness weights into each shard.
    for shard_id in range(n_heads - 1):
        prod_path = PROD_WEIGHTS_DIR / f"shard_{shard_id}_lora.pt"
        load_and_merge_lora_weights(model, prod_cfg, prod_path, head_idx=shard_id)

        rob_path = ROBUST_WEIGHTS_DIR / f"shard_{shard_id}_lora.pt"
        load_and_merge_lora_weights(model, rob_cfg, rob_path, head_idx=shard_id)

    # Uncertainty head (last head index).
    unc_lora_r = 16
    unc_path = UNCERTAINTY_WEIGHTS_DIR / "checkpoint_final.pt"
    unc_cfg = HydraLoRAConfig(
        target_modules=LORA_TARGETS,
        lora_r=unc_lora_r,
        lora_alpha=unc_lora_r * LORA_ALPHA_RATIO,
    )
    load_and_merge_lora_weights(model, unc_cfg, unc_path, head_idx=n_heads - 1)

    model.to(dtype=torch.bfloat16, device="cuda")
    model.eval()
    return model, n_heads