# Source code for olmo_tap.hydra
"""
HydraTransformer: multi-head branched transformer.
Shares a common trunk (early layers) and branches into N independent heads
(late layers). All heads share a single lm_head. Each head produces its own
logits, which can be averaged or otherwise combined downstream.
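
A minimal usage sketch (hypothetical shapes; assumes the configured attention
backend is available on the current device)::

    import torch

    cfg = HydraTransformerConfig.from_olmo2_1B(n_heads=3)
    model = cfg.build(init_device="cpu")
    logits = model(torch.zeros(2, 8, dtype=torch.long))  # (3, 2, 8, vocab)
    combined = logits.mean(dim=0)  # one way to combine the heads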
"""
import logging
from dataclasses import dataclass, replace
from typing import cast
import torch
import torch.nn as nn

from olmo_core.nn.attention import Attention, KVCacheManager
from olmo_core.nn.config import ModelConfig
from olmo_core.nn.transformer.block import TransformerBlock
from olmo_core.nn.transformer.config import TransformerConfig
from olmo_core.nn.transformer.model import Transformer

from olmo_tap.constants import VOCAB_SIZE
log = logging.getLogger(__name__)
@dataclass
class HydraTransformerConfig(ModelConfig):
"""
Config for building a :class:`HydraTransformer`.
:param base_config: Full TransformerConfig for the underlying model architecture.
:param n_heads: Number of parallel heads.
:param trunk_layers: Number of layers in the shared trunk.
:param head_layers: Number of layers per head.
"""
base_config: TransformerConfig
n_heads: int
trunk_layers: int
head_layers: int
def __post_init__(self):
self.validate()
def validate(self):
if isinstance(self.base_config.block, dict):
raise ValueError(
"HydraTransformerConfig does not support heterogeneous block configs "
"(base_config.block must be a TransformerBlockConfig, not a dict)"
)
total = self.trunk_layers + self.head_layers
expected = self.base_config.n_layers
if total != expected:
raise ValueError(
f"trunk_layers ({self.trunk_layers}) + head_layers ({self.head_layers}) = {total}, "
f"but base_config.n_layers = {expected}"
)
if self.n_heads < 1:
raise ValueError(f"n_heads must be >= 1, got {self.n_heads}")
@classmethod
def from_olmo2_1B(
cls,
n_heads: int = 5,
heads_depth: int = 3,
vocab_size: int = VOCAB_SIZE,
) -> "HydraTransformerConfig":
"""Factory for OLMo2 1B (16 layers) with configurable split point."""
from olmo_core.nn.attention import AttentionBackendName
base = TransformerConfig.olmo2_1B_v2(vocab_size=vocab_size)
base.block.sequence_mixer.backend = AttentionBackendName.flash_2 # type: ignore[union-attr]
return cls(
base_config=base,
n_heads=n_heads,
trunk_layers=base.n_layers - heads_depth,
head_layers=heads_depth,
)
@classmethod
def from_olmo2_7B(
cls,
n_heads: int = 5,
heads_depth: int = 3,
vocab_size: int = VOCAB_SIZE,
) -> "HydraTransformerConfig":
"""Factory for OLMo2 7B (32 layers) with configurable split point."""
from olmo_core.nn.attention import AttentionBackendName
base = TransformerConfig.olmo2_7B(vocab_size=vocab_size)
base.block.sequence_mixer.backend = AttentionBackendName.flash_2 # type: ignore[union-attr]
return cls(
base_config=base,
n_heads=n_heads,
trunk_layers=base.n_layers - heads_depth,
head_layers=heads_depth,
)
def build(self, *, init_device: str = "cpu") -> "HydraTransformer":
"""
Build the HydraTransformer.
:param init_device: Device for parameter initialization (use ``"meta"`` for zero-memory initialization).
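
Sketch of typical usage (assuming the olmo_core factory configs resolve as expected)::

    cfg = HydraTransformerConfig.from_olmo2_1B(n_heads=3, heads_depth=2)
    model = cfg.build(init_device="cpu")  # 14 trunk layers + 3 heads x 2 layers
    assert len(model.heads) == 3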
"""
# Derive separate configs for the trunk and the heads.
trunk_config = replace(self.base_config, n_layers=self.trunk_layers)
head_config = replace(self.base_config, n_layers=self.head_layers)
# Build the trunk (on init_device); it does not need an lm_head.
trunk = trunk_config.build(init_device=init_device)
trunk.lm_head = None # type: ignore[assignment]
# Build one head to extract the shared lm_head, then strip it
donor = head_config.build(init_device=init_device)
lm_head = donor.lm_head
donor.lm_head = None # type: ignore[assignment]
donor.embeddings = None # type: ignore[assignment]
# Collect all heads (the donor plus freshly built ones), none with an lm_head.
# NOTE: a single lm_head is currently shared across all hydra heads;
# we believe LoRA fine-tuning usually does not touch the lm_head anyway.
heads = nn.ModuleList()
heads.append(donor)
for _ in range(self.n_heads - 1):
head = head_config.build(init_device=init_device)
head.embeddings = None # type: ignore[assignment]
head.lm_head = None # type: ignore[assignment]
heads.append(head)
return HydraTransformer(trunk=trunk, heads=heads, lm_head=lm_head)
@property
def num_params(self) -> int:
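"""
Analytic parameter count. With ``B`` = params per block at width ``d`` and
``V`` = vocab size, this is ``d * V`` (embeddings) + ``trunk_layers * B``
+ ``n_heads * head_layers * B`` + the shared lm_head's params.
"""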
assert not isinstance(self.base_config.block, dict)
d = self.base_config.d_model
block_params = self.base_config.block.num_params(d)
# Trunk: embeddings + trunk blocks (no lm_head).
n = d * self.base_config.vocab_size + self.trunk_layers * block_params
# Heads: head blocks only (no embeddings, no lm_head).
n += self.n_heads * self.head_layers * block_params
# Shared lm_head: one copy.
n += self.base_config.lm_head.num_params(d, self.base_config.vocab_size)
return n
@property
def num_non_embedding_params(self) -> int:
return self.num_params - self.base_config.d_model * self.base_config.vocab_size
class HydraTransformer(nn.Module):
"""
A multi-head branched transformer.
Runs input through a shared trunk, then fans out to N independent heads.
All heads share a single lm_head for the final projection to vocab logits.
:param trunk: Shared transformer trunk (no lm_head).
:param heads: ModuleList of head transformers (no embeddings, no lm_head).
:param lm_head: Shared language modeling head.
"""
def __init__(
self,
trunk: Transformer,
heads: nn.ModuleList,
lm_head: nn.Module,
):
super().__init__()
self.trunk = trunk
self.heads = heads
self.lm_head = lm_head
@property
def num_params(self) -> int:
return sum(p.numel() for p in self.parameters())
def _attentions(self, omit_last: bool = False) -> list[Attention]:
"""
By convention the uncertainty head is always loaded at the final index.
The uncertainty head only ever generates a single token, so KV caching
(the only consumer of this method) is never desirable for it.
TODO: this convention can be confusing because we train with the uncertainty
head at index 0. A refactor should clean this up (SWE-157).
TODO: related to the above, we never want a KV cache on the uncertainty head;
there is surely a cleaner solution than passing the ``omit_last`` flag everywhere.
"""
attentions = []
for block in self.trunk.blocks.values():
attn = cast(TransformerBlock, block).attention
if isinstance(attn, Attention):
attentions.append(attn)
heads = self.heads[:-1] if omit_last else self.heads
for head in heads:
for block in cast(Transformer, head).blocks.values():
attn = cast(TransformerBlock, block).attention
if isinstance(attn, Attention):
attentions.append(attn)
return attentions
def _kv_managers(self, omit_last: bool = False) -> list[KVCacheManager]:
return [
attn.kv_cache_manager
for attn in self._attentions(omit_last)
if attn.kv_cache_manager is not None
]
def init_kv_cache(self, batch_size: int, max_seq_len: int, omit_last: bool = False):
"""Initialize KV caches for all blocks in trunk and heads."""
for attn in self._attentions(omit_last):
attn.init_kv_cache_manager(batch_size, max_seq_len)
def reset_kv_cache(self, omit_last: bool = False):
"""Reset KV cache position counters to 0 before each generation."""
for m in self._kv_managers(omit_last):
m.cache_seqlens.fill_(0)
def rollback_kv_cache(self, n: int, omit_last: bool = False):
"""Roll back cache pointers on trunk and every head by n positions."""
for m in self._kv_managers(omit_last):
m.cache_seqlens.sub_(n).clamp_(min=0)
def sync_kv_cache(self, target_length: int, omit_last: bool = False):
"""Sync cache pointers on trunk and every head to specified position."""
for m in self._kv_managers(omit_last):
m.cache_seqlens.fill_(target_length)
m.cache_leftpad.fill_(0)
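# A typical decode-time cache lifecycle might look like this (hypothetical sketch;
# the real loop lives in the surrounding generation code):
#
#     model.init_kv_cache(batch_size=1, max_seq_len=2048, omit_last=True)
#     model.reset_kv_cache(omit_last=True)
#     ...  # decode tokens; each attention call advances its cache_seqlens
#     model.rollback_kv_cache(4, omit_last=True)  # e.g. drop 4 rejected draft tokens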
def forward_trunk(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
"""Run the shared trunk and return its hidden states."""
return self.trunk(input_ids, **kwargs)
def forward_heads(
self,
hidden_states: torch.Tensor,
head_indices: list[int] | None = None,
residual: torch.Tensor | None = None,
last_token_only: bool = False,
**kwargs,
) -> torch.Tensor:
"""Run selected heads on pre-computed trunk hidden states, return logits."""
h = hidden_states
if residual is not None:
assert residual.shape == h.shape, (
f"Residual shape mismatch, expected {h.shape} got {residual.shape}"
)
h = h + residual
if head_indices is not None:
if len(head_indices) == 0:
raise ValueError("head_indices must be non-empty")
n = len(self.heads)
for idx in head_indices:
if idx < 0 or idx >= n:
raise ValueError(f"head index {idx} out of range for {n} heads")
selected = [self.heads[i] for i in head_indices]
else:
selected = list(self.heads)
head_hidden = [head(h, **kwargs) for head in selected]
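# Each head returns (batch, seq, d_model); concatenating along dim 0 gives
# (n_selected * batch, seq, d_model) so the shared lm_head runs in one call.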
stacked = torch.cat(head_hidden, dim=0)
if last_token_only:
stacked = stacked[:, -1:, :]
all_logits: torch.Tensor = self.lm_head(stacked)
return all_logits.unflatten(0, (len(selected), -1))
def forward(
self,
input_ids: torch.Tensor,
residual: torch.Tensor | None = None,
head_indices: list[int] | None = None,
last_token_only: bool = False,
**kwargs,
) -> torch.Tensor:
"""
Run the full model.
:param input_ids: Token IDs ``(batch, seq)``.
:param residual: Optional tensor added to the trunk hidden states before the heads run.
:param head_indices: Optional subset of head indices to run. None means all heads.
:param last_token_only: If True, project only the final sequence position through
the lm_head, collapsing the output seq dim to 1. Cheap inference path for
classification / next-token argmax; training must keep this False.
:returns: Logits tensor ``(n_selected, batch, seq_out, vocab)`` where
``seq_out == 1`` if ``last_token_only`` else ``seq``.
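
Example (hypothetical tensors; ``model`` is a built :class:`HydraTransformer`)::

    ids = torch.randint(0, 100, (2, 16))
    logits = model(ids, head_indices=[0, 2], last_token_only=True)
    # logits.shape == (2, 2, 1, vocab)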
"""
h = self.forward_trunk(input_ids, **kwargs)
return self.forward_heads(
h,
head_indices=head_indices,
residual=residual,
last_token_only=last_token_only,
**kwargs,
)
def residual_forward(
self,
input_ids: torch.Tensor,
hidden_head_indices: list[int],
head_indices: list[int] | None = None,
last_token_only: bool = False,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run the full model and additionally return hidden states for the specified heads.
:param input_ids: Token IDs ``(batch, seq)``.
:param hidden_head_indices: Indices into ``self.heads`` whose hidden states to
return; each must be included in ``head_indices``.
:param head_indices: Optional subset of head indices to run. None means all heads.
:param last_token_only: If True, keep only the final sequence position (see :meth:`forward`).
:returns: Tuple of logits tensor ``(n_selected, batch, seq_out, vocab)`` and
hidden-state tensor ``(n_hidden, batch, seq_out, d_model)``, where
``seq_out == 1`` if ``last_token_only`` else ``seq``.
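
One plausible pattern (hypothetical; the exact wiring depends on the caller) feeds a
returned hidden state back through :meth:`forward_heads` as a residual::

    logits, hid = model.residual_forward(ids, hidden_head_indices=[0])
    h = model.forward_trunk(ids)
    more_logits = model.forward_heads(h, head_indices=[1], residual=hid[0])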
"""
h = self.forward_trunk(input_ids, **kwargs)
if head_indices is not None:
if len(head_indices) == 0:
raise ValueError("head_indices must be non-empty")
n = len(self.heads)
for idx in head_indices:
if idx < 0 or idx >= n:
raise ValueError(f"head index {idx} out of range for {n} heads")
selected = [self.heads[i] for i in head_indices]
selected_idxs = list(head_indices)
else:
selected = list(self.heads)
selected_idxs = list(range(len(self.heads)))
for head_idx in hidden_head_indices:
if head_idx not in selected_idxs:
raise ValueError(
f"hidden_head_idx {head_idx} must be one of selected heads {selected_idxs}"
)
hidden_positions = [selected_idxs.index(h_idx) for h_idx in hidden_head_indices]
head_hidden = [head(h, **kwargs) for head in selected]
stacked = torch.stack(head_hidden, dim=0) # (N, batch, seq, d_model)
if last_token_only:
stacked = stacked[:, :, -1:, :]
hidden_heads = stacked[hidden_positions] # (N_hid, batch, seq, d_model)
all_logits: torch.Tensor = self.lm_head(stacked.flatten(0, 1))
return all_logits.unflatten(0, (len(selected), -1)), hidden_heads
@staticmethod
def load_olmo_state(
model: "HydraTransformer",
olmo_state: dict[str, torch.Tensor],
trunk_layers: int,
vocab_size: int,
) -> None:
"""
Load a flat OLMo-format state dict into a HydraTransformer.
Splits the state by layer index into trunk/head/lm_head components,
pads vocab embeddings if needed, and clones head weights for each head.
:param model: The HydraTransformer to load into.
:param olmo_state: OLMo-format state dict (output of ``convert_state_from_hf``).
:param trunk_layers: Number of layers in the trunk.
:param vocab_size: Target vocab size (for padding).
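
Sketch of a typical loading flow (``convert_state_from_hf`` lives elsewhere, and
``trunk_layers=13`` is just an example value)::

    olmo_state = convert_state_from_hf(hf_state_dict)
    HydraTransformer.load_olmo_state(model, olmo_state, trunk_layers=13, vocab_size=VOCAB_SIZE)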
"""
trunk_state: dict[str, torch.Tensor] = {}
head_state: dict[str, torch.Tensor] = {}
lm_head_state: dict[str, torch.Tensor] = {}
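# Split the flat state by layer index: blocks below trunk_layers stay in the trunk,
# later blocks are re-indexed from 0 for the heads, and lm_head.* keys are peeled off.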
for key, value in olmo_state.items():
if key.startswith("blocks."):
block_idx = int(key.split(".", 2)[1])
suffix = key.split(".", 2)[2]
if block_idx < trunk_layers:
trunk_state[f"blocks.{block_idx}.{suffix}"] = value
else:
new_idx = block_idx - trunk_layers
head_state[f"blocks.{new_idx}.{suffix}"] = value
elif key.startswith("lm_head."):
lm_head_state[key.split(".", 1)[1]] = value
else:
trunk_state[key] = value
# Pad the vocab dimension with zeros up to vocab_size (often rounded up for efficient matmuls).
emb = trunk_state["embeddings.weight"]
if emb.shape[0] < vocab_size:
padding = torch.zeros(
vocab_size - emb.shape[0], emb.shape[1], dtype=emb.dtype
)
trunk_state["embeddings.weight"] = torch.cat([emb, padding], dim=0)
w_out = lm_head_state["w_out.weight"]
if w_out.shape[0] < vocab_size:
padding = torch.zeros(
vocab_size - w_out.shape[0], w_out.shape[1], dtype=w_out.dtype
)
lm_head_state["w_out.weight"] = torch.cat([w_out, padding], dim=0)
model.trunk.load_state_dict(trunk_state, assign=True)
model.lm_head.load_state_dict(lm_head_state, assign=True)
for i, head in enumerate(model.heads):
# NOTE: for testing, noise can be injected into head params here.
# With assign=True, reusing head_state would alias the same tensors across
# heads, so every head after the first gets a cloned copy.
state = (
head_state if i == 0 else {k: v.clone() for k, v in head_state.items()}
)
head.load_state_dict(state, assign=True)