Source code for olmo_tap.experiments.hydra_demo
"""
HydraTransformer inference demo.
Loads OLMo2 7B instruct weights into a HydraTransformer and runs
greedy generation with averaged head logits.
Usage:
pixi run python experiments/hydra_demo.py
"""
import glob

from safetensors.torch import load_file
import torch
from transformers import AutoConfig, AutoTokenizer
from olmo_core.nn.hf.convert import convert_state_from_hf

from olmo_tap.constants import DEMO_MAX_NEW_TOKENS, WEIGHTS_DIR, VOCAB_SIZE
from olmo_tap.hydra import HydraTransformer, HydraTransformerConfig


def main():
    # Build model on meta device (zero memory until weights are loaded).
    config = HydraTransformerConfig.from_olmo2_7B(n_heads=5, heads_depth=3)
    model = config.build(init_device="meta")

    # Load and convert HF weights to OLMo format (7B weights are sharded).
    hf_config = AutoConfig.from_pretrained(WEIGHTS_DIR)
    shard_files = sorted(glob.glob(f"{WEIGHTS_DIR}/model*.safetensors"))
    hf_state = {}
    for f in shard_files:
        hf_state.update(load_file(f, device="cpu"))
    olmo_state = convert_state_from_hf(hf_config, hf_state)
    del hf_state, hf_config

    # Distribute weights across trunk/heads/lm_head.
    HydraTransformer.load_olmo_state(
        model, olmo_state, trunk_layers=config.trunk_layers, vocab_size=VOCAB_SIZE
    )
    del olmo_state

    model.to(device="cuda", dtype=torch.bfloat16)
    model.eval()
    print(
        f"\nLoaded: trunk={config.trunk_layers} layers, "
        f"heads={config.head_layers} layers x{config.n_heads}, "
        f"{model.num_params:,} params"
    )

    # Tokenize with chat template.
    tokenizer = AutoTokenizer.from_pretrained(WEIGHTS_DIR)
    assert tokenizer is not None
    prompt = "What is the capital of France?"
    messages = [{"role": "user", "content": prompt}]
    chat_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # shape: (B, N), model expects batch dim
    input_ids = torch.tensor([tokenizer.encode(chat_prompt)], device="cuda")
    max_seq_len = input_ids.shape[1] + DEMO_MAX_NEW_TOKENS

    # Initialize KV caches.
    model.init_kv_cache(batch_size=1, max_seq_len=max_seq_len)

    # Generate.
    with torch.no_grad():
        # Prefill: process full prompt, populate KV cache.
        all_logits = model(input_ids, return_logits=True)
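        # all_logits is assumed to stack the per-head logits along dim 0, i.e.
        # shape (n_heads, B, N, V); taking the last position and averaging over
        # dim 0 merges the heads into one next-token distribution, and the
        # unsqueeze below restores the batch dim so next_token has shape (1, 1)
        # for the next forward pass.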
        merged_logits = all_logits[:, 0, -1, :].mean(dim=0)
        next_token = merged_logits.argmax(dim=-1, keepdim=True).unsqueeze(0)
        generated = [next_token.item()]

        # Decode: one token at a time using cached KVs.
        for _ in range(DEMO_MAX_NEW_TOKENS - 1):
            all_logits = model(next_token, return_logits=True)
            merged_logits = all_logits[:, 0, -1, :].mean(dim=0)
            next_token = merged_logits.argmax(dim=-1, keepdim=True).unsqueeze(0)
            generated.append(next_token.item())
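        # Note: the demo always generates DEMO_MAX_NEW_TOKENS tokens; a fuller
        # demo might also break early once next_token equals tokenizer.eos_token_id.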

    # Strip the dummy batch dim and append the generated tokens to the prompt ids.
    full_ids = input_ids[0].tolist() + generated
    print(f"\nPrompt: {prompt!r}")
    print(f"Output: {tokenizer.decode(full_ids)}")


if __name__ == "__main__":
    main()