Source code for olmo_tap.experiments.hydra_demo

"""
HydraTransformer inference demo.

Loads OLMo2 7B instruct weights into a HydraTransformer and runs
greedy generation with averaged head logits.

Usage:
    pixi run python experiments/hydra_demo.py
"""

import glob

from safetensors.torch import load_file
import torch
from transformers import AutoConfig, AutoTokenizer

from olmo_core.nn.hf.convert import convert_state_from_hf

from olmo_tap.constants import DEMO_MAX_NEW_TOKENS, WEIGHTS_DIR, VOCAB_SIZE
from olmo_tap.hydra import HydraTransformer, HydraTransformerConfig


def main():
    # Build model on meta device (zero memory until weights are loaded).
    config = HydraTransformerConfig.from_olmo2_7B(n_heads=5, heads_depth=3)
    model = config.build(init_device="meta")

    # Load and convert HF weights to OLMo format (7B weights are sharded).
    hf_config = AutoConfig.from_pretrained(WEIGHTS_DIR)
    shard_files = sorted(glob.glob(f"{WEIGHTS_DIR}/model*.safetensors"))
    hf_state = {}
    for f in shard_files:
        hf_state.update(load_file(f, device="cpu"))
    olmo_state = convert_state_from_hf(hf_config, hf_state)
    del hf_state, hf_config

    # Distribute weights across trunk/heads/lm_head.
    HydraTransformer.load_olmo_state(
        model, olmo_state, trunk_layers=config.trunk_layers, vocab_size=VOCAB_SIZE
    )
    del olmo_state

    model.to(device="cuda", dtype=torch.bfloat16)
    model.eval()
    print(
        f"\nLoaded: trunk={config.trunk_layers} layers, "
        f"heads={config.head_layers} layers x{config.n_heads}, "
        f"{model.num_params:,} params"
    )

    # Tokenize with chat template.
    tokenizer = AutoTokenizer.from_pretrained(WEIGHTS_DIR)
    assert tokenizer is not None
    prompt = "What is the capital of France?"
    messages = [{"role": "user", "content": prompt}]
    chat_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # shape: (B, N), model expects batch dim
    input_ids = torch.tensor([tokenizer.encode(chat_prompt)], device="cuda")
    max_seq_len = input_ids.shape[1] + DEMO_MAX_NEW_TOKENS

    # Initialize KV caches.
    model.init_kv_cache(batch_size=1, max_seq_len=max_seq_len)

    # Generate.
    with torch.no_grad():
        # Prefill: process full prompt, populate KV cache.
        all_logits = model(input_ids, return_logits=True)
        merged_logits = all_logits[:, 0, -1, :].mean(dim=0)
        next_token = merged_logits.argmax(dim=-1, keepdim=True).unsqueeze(0)
        generated = [next_token.item()]

        # Decode: one token at a time using cached KVs.
        for _ in range(DEMO_MAX_NEW_TOKENS - 1):
            all_logits = model(next_token, return_logits=True)
            merged_logits = all_logits[:, 0, -1, :].mean(dim=0)
            next_token = merged_logits.argmax(dim=-1, keepdim=True).unsqueeze(0)
            generated.append(next_token.item())

    # index dummy batch dim
    full_ids = input_ids[0].tolist() + generated
    print(f"\nPrompt: {prompt!r}")
    print(f"Output: {tokenizer.decode(full_ids)}")

if __name__ == "__main__":
    main()