olmo_tap.hydra¶
HydraTransformer: multi-head branched transformer.
Shares a common trunk (early layers) and branches into N independent heads (late layers). All heads share a single lm_head. Each head produces its own logits, which can be averaged or otherwise combined downstream.
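A minimal usage sketch (a hypothetical setup: the factory defaults and vocab size come from the config documented below, and averaging is just one possible way to combine the head logits downstream):

```python
import torch

from olmo_tap.hydra import HydraTransformerConfig

# Build from the OLMo2 1B factory documented below, then combine head logits.
config = HydraTransformerConfig.from_olmo2_1B(n_heads=5, heads_depth=3)
model = config.build(init_device="cpu")

input_ids = torch.randint(0, 100352, (1, 16))  # (batch, seq) of token IDs
logits = model(input_ids)                      # (n_heads, batch, seq, vocab)
combined = logits.mean(dim=0)                  # average heads; other combinations work too
```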
Classes
- HydraTransformer – A multi-head branched transformer.
- HydraTransformerConfig – Config for building a HydraTransformer.
- class olmo_tap.hydra.HydraTransformer(trunk: Transformer, heads: ModuleList, lm_head: Module)[source]¶
Bases: Module
A multi-head branched transformer.
Runs input through a shared trunk, then fans out to N independent heads. All heads share a single lm_head for the final projection to vocab logits.
- Parameters:
trunk – Shared transformer trunk (no lm_head).
heads – ModuleList of head transformers (no embeddings, no lm_head).
lm_head – Shared language modeling head.
- forward(input_ids: Tensor, residual: Tensor | None = None, head_indices: list[int] | None = None, last_token_only: bool = False, **kwargs) Tensor[source]¶
Run the full model.
- Parameters:
input_ids – Token IDs (batch, seq).
head_indices – Optional subset of head indices to run. None means all heads.
last_token_only – If True, project only the final sequence position through the lm_head, so the output seq dim collapses to 1. Cheap inference path for classification or next-token argmax; training must keep this False.
- Returns:
Logits tensor (n_selected, batch, seq_out, vocab) where seq_out == 1 if last_token_only else seq.
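For example (a sketch reusing `model` and `input_ids` from the module-level example above):

```python
# Run only heads 0 and 2, projecting just the last position through lm_head.
logits = model.forward(
    input_ids,            # (batch, seq)
    head_indices=[0, 2],
    last_token_only=True,
)                         # (2, batch, 1, vocab)
# One possible downstream use: average the selected heads and take an argmax.
next_token = logits.mean(dim=0)[:, -1].argmax(dim=-1)  # (batch,)
```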
- forward_heads(hidden_states: Tensor, head_indices: list[int] | None = None, residual: Tensor | None = None, last_token_only: bool = False, **kwargs) Tensor[source]¶
Run selected heads on pre-computed trunk hidden states, return logits.
- forward_trunk(input_ids: Tensor, **kwargs) Tensor[source]¶
Run the shared trunk and return its hidden states.
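Together, forward_trunk and forward_heads let the shared trunk be computed once and reused across head subsets. A sketch (same `model` and `input_ids` as above):

```python
# Compute the trunk once, then evaluate two disjoint head subsets on it.
hidden = model.forward_trunk(input_ids)                      # (batch, seq, d_model)
logits_a = model.forward_heads(hidden, head_indices=[0, 1])  # (2, batch, seq, vocab)
logits_b = model.forward_heads(hidden, head_indices=[2, 3])  # (2, batch, seq, vocab)
```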
- init_kv_cache(batch_size: int, max_seq_len: int, omit_last: bool = False)[source]¶
Initialize KV caches for all blocks in trunk and heads.
- static load_olmo_state(model: HydraTransformer, olmo_state: dict[str, Tensor], trunk_layers: int, vocab_size: int) None[source]¶
Load a flat OLMo-format state dict into a HydraTransformer.
Splits the state by layer index into trunk/head/lm_head components, pads vocab embeddings if needed, and clones head weights for each head.
- Parameters:
model – The HydraTransformer to load into.
olmo_state – OLMo-format state dict (output of convert_state_from_hf).
trunk_layers – Number of layers in the trunk.
vocab_size – Target vocab size (for padding).
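A loading sketch, assuming `olmo_state` was produced elsewhere by convert_state_from_hf (its import path and signature are not shown on this page) and that trunk_layers is the total layer count minus the heads' depth:

```python
from olmo_tap.hydra import HydraTransformer, HydraTransformerConfig

config = HydraTransformerConfig.from_olmo2_1B(n_heads=5, heads_depth=3)
model = config.build(init_device="cpu")

# olmo_state: dict[str, Tensor] from convert_state_from_hf (assumed available).
HydraTransformer.load_olmo_state(
    model,
    olmo_state,
    trunk_layers=13,     # assumption: 16 base layers minus heads_depth=3
    vocab_size=100352,
)
```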
- reset_kv_cache(omit_last: bool = False)[source]¶
Reset KV cache position counters to 0 before each generation.
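A cached-generation sketch, assuming the usual pattern of initializing caches once and resetting position counters before each new sequence; `prompt_ids` is a hypothetical (1, seq) token tensor, greedy decoding shown, and the exact cache semantics may differ:

```python
model.init_kv_cache(batch_size=1, max_seq_len=256)

model.reset_kv_cache()
logits = model(prompt_ids, last_token_only=True)             # prefill pass
token = logits.mean(dim=0)[:, -1].argmax(-1, keepdim=True)   # (1, 1)
for _ in range(32):                                          # decode one token at a time
    logits = model(token, last_token_only=True)
    token = logits.mean(dim=0)[:, -1].argmax(-1, keepdim=True)
```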
- residual_forward(input_ids: Tensor, hidden_head_indices: list[int], head_indices: list[int] | None = None, last_token_only: bool = False, **kwargs) tuple[Tensor, Tensor][source]¶
Run the full model and also return the hidden state of the specified head indices.
- Parameters:
input_ids – Token IDs (batch, seq).
hidden_head_indices – Global list of indices into self.heads; must all be in head_indices.
head_indices – Optional subset of head indices to run. None means all heads.
- Returns:
tuple[logits tensor (n_selected, batch, seq, vocab), hidden-state tensor (batch, seq, d_model)].
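A sketch of residual_forward, e.g. to obtain one head's hidden state and feed it back as `residual` on a later call (the feedback use is an assumption; same `model` and `input_ids` as above):

```python
logits, hidden = model.residual_forward(
    input_ids,
    hidden_head_indices=[0],   # take the hidden state from head 0
    head_indices=None,         # run all heads
)
# logits: (n_heads, batch, seq, vocab); hidden: (batch, seq, d_model)
later_logits = model(input_ids, residual=hidden)  # hypothetical reuse as `residual`
```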
- class olmo_tap.hydra.HydraTransformerConfig(base_config: TransformerConfig, n_heads: int, trunk_layers: int, head_layers: int)[source]¶
Bases: ModelConfig
Config for building a HydraTransformer.
- Parameters:
base_config – Full TransformerConfig for the underlying model architecture.
n_heads – Number of parallel heads.
trunk_layers – Number of layers in the shared trunk.
head_layers – Number of layers per head.
- base_config: TransformerConfig¶
- build(*, init_device: str = 'cpu') HydraTransformer[source]¶
Build the HydraTransformer.
- Parameters:
init_device – Device for parameter initialization (use "meta" for zero-memory initialization).
- classmethod from_olmo2_1B(n_heads: int = 5, heads_depth: int = 3, vocab_size: int = 100352) HydraTransformerConfig[source]¶
Factory for OLMo2 1B (16 layers) with configurable split point.
- classmethod from_olmo2_7B(n_heads: int = 5, heads_depth: int = 3, vocab_size: int = 100352) HydraTransformerConfig[source]¶
Factory for OLMo2 7B (32 layers) with configurable split point.