Source code for olmo_tap.experiments.utils.model_builder
"""
Functions to support loading models for inference and training.
"""
import gc
import glob
from typing import cast
from pathlib import Path
from olmo_core.nn.hf.convert import convert_state_from_hf
from peft import LoraConfig, get_peft_model
from safetensors.torch import load_file
import torch
from transformers import AutoConfig, PreTrainedModel
from olmo_tap.experiments.utils.config import HydraLoRAConfig
from olmo_tap.hydra import HydraTransformer, HydraTransformerConfig
def build_base_model(config: HydraLoRAConfig) -> HydraTransformer:
"""
:param config: Config file detailing architecture of model to be loaded
:returns HydraTransformer: OLMo with base weights
"""
factory = (
HydraTransformerConfig.from_olmo2_7B
if config.model_size == "7b"
else HydraTransformerConfig.from_olmo2_1B
)
hydra_config = factory(
n_heads=config.n_heads_training, heads_depth=config.heads_depth
)
model = hydra_config.build(init_device="meta")
# load model params (handle single or sharded safetensors)
shard_files = sorted(glob.glob(f"{config.weights_dir}/model*.safetensors"))
hf_state = {}
for f in shard_files:
        # load to CPU first to avoid GPU memory overhead during conversion
hf_state.update(load_file(f, device="cpu"))
hf_config = AutoConfig.from_pretrained(config.weights_dir)
olmo_state = convert_state_from_hf(hf_config, hf_state)
del hf_state
# load model state into hydra
HydraTransformer.load_olmo_state(
model,
olmo_state,
trunk_layers=hydra_config.trunk_layers,
vocab_size=config.vocab_size,
)
del olmo_state
gc.collect()
    model.to(device=config.device, dtype=torch.bfloat16)  # NOTE: casts all params to bfloat16
return model
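# Usage sketch (illustrative only): building the base model. This assumes
# HydraLoRAConfig can be constructed directly with the fields build_base_model
# reads; the values below are placeholders, not repo defaults.
#
#     cfg = HydraLoRAConfig(
#         model_size="1b",
#         n_heads_training=1,
#         heads_depth=2,
#         weights_dir="/path/to/olmo2-hf-checkpoint",
#         vocab_size=100352,  # placeholder; must match the checkpoint's embedding size
#         device="cuda",
#     )
#     model = build_base_model(cfg)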
def inject_lora(model: HydraTransformer, config: HydraLoRAConfig, head_idx: int = 0):
"""
:param model: HydraTransformer model to inject trainable LoRA weights into.
:param config: Config file detailing LoRA params (rank, alpha, target_modules).
:param head_idx: Which Hydra index to load trainable LoRA weights into (default 0).
"""
# inject LoRA into target modules specified by config
lora_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
target_modules=config.target_modules,
lora_dropout=0.1,
bias="none",
)
    # LoRA is applied only to the head at head_idx; any other head instantiated during training stays frozen
model.heads[head_idx] = get_peft_model(
cast(PreTrainedModel, model.heads[head_idx]), lora_config
)
    # freeze all params, then re-enable gradients for the LoRA params only
model.requires_grad_(False)
for n, p in model.named_parameters():
if "lora" in n:
p.requires_grad = True
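# Usage sketch (illustrative only): inject LoRA into head 0 and verify that
# only the LoRA params remain trainable. Assumes `model` and `cfg` from the
# sketch above, with cfg.lora_r, cfg.lora_alpha, and cfg.target_modules set.
#
#     inject_lora(model, cfg, head_idx=0)
#     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     total = sum(p.numel() for p in model.parameters())
#     print(f"trainable params: {trainable}/{total} ({100 * trainable / total:.2f}%)")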
def load_and_merge_lora_weights(
model: HydraTransformer,
config: HydraLoRAConfig,
weights_path: Path | str,
head_idx: int = 0,
):
"""
:param model: HydraTransformer model to add trained LoRA weights to.
:param config: Config file detailing LoRA params (rank, alpha, target_modules).
:param weights_path: Path of saved LoRA weights.
:param head_idx: Which Hydra index to add trained LoRA weights to.
"""
# inject temporary LoRA to house the incoming weights
lora_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
target_modules=config.target_modules,
)
temp_peft = get_peft_model(
cast(PreTrainedModel, model.heads[head_idx]), lora_config
)
# load and merge
# weights_only=False is used because standard LoRA saves often contain non-tensor metadata
state = torch.load(weights_path, map_location=config.device, weights_only=False)
temp_peft.load_state_dict(state, strict=False)
del state
merged_model = temp_peft.merge_and_unload() # type: ignore[attr-defined]
# clean up PEFT metadata to allow fresh LoRA injection later without conflicts
if hasattr(merged_model, "peft_config"):
delattr(merged_model, "peft_config")
model.heads[head_idx] = merged_model # type: ignore[union-attr]
gc.collect()
torch.cuda.empty_cache()
print(f"Loaded prod weights from {weights_path}")