olmo_tap.hydra

HydraTransformer: multi-head branched transformer.

Shares a common trunk (early layers) and branches into N independent heads (late layers). All heads share a single lm_head. Each head produces its own logits, which can be averaged or otherwise combined downstream.
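As a sketch of what "combined downstream" can mean, the snippet below mean-pools per-head logits as a simple ensemble. Plain Python lists stand in for the real `(n_heads, batch, seq, vocab)` tensors; the function name is illustrative and not part of `olmo_tap`.

```python
# Toy stand-in for combining per-head logits downstream.
# Each "head output" here is just a list of vocab scores for one position.

def average_head_logits(per_head_logits):
    """Mean-combine logits across heads (a simple ensemble)."""
    n_heads = len(per_head_logits)
    vocab = len(per_head_logits[0])
    return [sum(h[v] for h in per_head_logits) / n_heads for v in range(vocab)]

heads = [
    [2.0, 0.0, 1.0],   # head 0's logits over a 3-token vocab
    [0.0, 2.0, 1.0],   # head 1 disagrees on the top token
]
print(average_head_logits(heads))  # [1.0, 1.0, 1.0]
```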

Classes

HydraTransformer(trunk, heads, lm_head)

A multi-head branched transformer.

HydraTransformerConfig(base_config, n_heads, ...)

Config for building a HydraTransformer.

class olmo_tap.hydra.HydraTransformer(trunk: Transformer, heads: ModuleList, lm_head: Module)[source]

Bases: Module

A multi-head branched transformer.

Runs input through a shared trunk, then fans out to N independent heads. All heads share a single lm_head for the final projection to vocab logits.

Parameters:
  • trunk – Shared transformer trunk (no lm_head).

  • heads – ModuleList of head transformers (no embeddings, no lm_head).

  • lm_head – Shared language modeling head.
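The dataflow these three components imply can be sketched with plain functions standing in for the transformer stacks. Everything named below is illustrative; only the trunk → heads → shared-lm_head fan-out mirrors `HydraTransformer`.

```python
# Minimal structural sketch of the trunk -> heads -> shared lm_head fan-out.

def trunk(x):          # shared early layers
    return x + 1

heads = [              # independent late-layer stacks
    lambda h: h * 2,
    lambda h: h * 3,
]

def lm_head(h):        # single shared projection to "logits"
    return h - 1

def hydra_forward(x, head_indices=None):
    h = trunk(x)                                 # run the trunk once
    idx = head_indices if head_indices is not None else range(len(heads))
    return [lm_head(heads[i](h)) for i in idx]   # fan out, share lm_head

print(hydra_forward(1))       # [3, 5]  (all heads)
print(hydra_forward(1, [1]))  # [5]     (head subset)
```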

forward(input_ids: Tensor, residual: Tensor | None = None, head_indices: list[int] | None = None, last_token_only: bool = False, **kwargs) Tensor[source]

Run the full model.

Parameters:
  • input_ids – Token IDs (batch, seq).

  • head_indices – Optional subset of head indices to run. None means all heads.

  • last_token_only – If True, project only the final sequence position through the lm_head, collapsing the output sequence dimension to 1. Cheap inference path for classification / next-token argmax; training should keep this False.

Returns:

Logits tensor (n_selected, batch, seq_out, vocab) where seq_out == 1 if last_token_only else seq.

forward_heads(hidden_states: Tensor, head_indices: list[int] | None = None, residual: Tensor | None = None, last_token_only: bool = False, **kwargs) Tensor[source]

Run selected heads on pre-computed trunk hidden states, return logits.

forward_trunk(input_ids: Tensor, **kwargs) Tensor[source]

Run the shared trunk and return its hidden states.
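In other words, `forward()` decomposes into `forward_trunk()` followed by `forward_heads()`, so the trunk can run once while different head subsets reuse its hidden states. A self-contained sketch of that decomposition, with illustrative stand-in functions:

```python
# Sketch: trunk computed once, its hidden states reused across head subsets.

def forward_trunk(x):
    return x + 10                      # pretend trunk hidden states

def forward_heads(h, head_indices=None):
    heads = [lambda s: s * 2, lambda s: s * 3, lambda s: s * 4]
    idx = head_indices if head_indices is not None else range(len(heads))
    return [heads[i](h) for i in idx]

hidden = forward_trunk(1)              # trunk runs once
a = forward_heads(hidden, [0, 1])      # first subset of heads
b = forward_heads(hidden, [2])         # second subset; trunk not rerun
print(a, b)  # [22, 33] [44]
```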

init_kv_cache(batch_size: int, max_seq_len: int, omit_last: bool = False)[source]

Initialize KV caches for all blocks in trunk and heads.

static load_olmo_state(model: HydraTransformer, olmo_state: dict[str, Tensor], trunk_layers: int, vocab_size: int) None[source]

Load a flat OLMo-format state dict into a HydraTransformer.

Splits the state by layer index into trunk/head/lm_head components, pads vocab embeddings if needed, and clones head weights for each head.

Parameters:
  • model – The HydraTransformer to load into.

  • olmo_state – OLMo-format state dict (output of convert_state_from_hf).

  • trunk_layers – Number of layers in the trunk.

  • vocab_size – Target vocab size (for padding).
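The layer-index split described above can be sketched as follows. The `"blocks.{i}."` key convention and helper name are illustrative, not the actual OLMo state-dict layout; only the split-below-`trunk_layers` / clone-per-head logic mirrors `load_olmo_state`.

```python
# Sketch of splitting a flat state dict by layer index: blocks below
# trunk_layers go to the trunk, later blocks are renumbered from 0 and
# cloned into every head, and remaining keys go to the shared lm_head.

def split_olmo_state(olmo_state, trunk_layers, n_heads):
    trunk, head, lm_head = {}, {}, {}
    for key, value in olmo_state.items():
        if key.startswith("blocks."):
            layer = int(key.split(".")[1])
            if layer < trunk_layers:
                trunk[key] = value
            else:
                rest = key.split(".", 2)[2]      # e.g. "attn.w"
                head[f"blocks.{layer - trunk_layers}.{rest}"] = value
        else:
            lm_head[key] = value
    # every head starts from an identical clone of the late layers
    heads = [dict(head) for _ in range(n_heads)]
    return trunk, heads, lm_head

state = {"blocks.0.w": 1, "blocks.1.w": 2, "blocks.2.w": 3, "w_out": 9}
trunk, heads, lm_head = split_olmo_state(state, trunk_layers=2, n_heads=2)
print(trunk)     # {'blocks.0.w': 1, 'blocks.1.w': 2}
print(heads[0])  # {'blocks.0.w': 3}
print(lm_head)   # {'w_out': 9}
```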

property num_params: int
reset_kv_cache(omit_last: bool = False)[source]

Reset KV cache position counters to 0 before each generation.

residual_forward(input_ids: Tensor, hidden_head_indices: list[int], head_indices: list[int] | None = None, last_token_only: bool = False, **kwargs) tuple[Tensor, Tensor][source]

Run the full model and return the hidden states of the specified head indices.

Parameters:
  • input_ids – Token IDs (batch, seq).

  • hidden_head_indices – Indices into self.heads whose hidden states are returned; every index must also appear in head_indices.

  • head_indices – Optional subset of head indices to run. None means all heads.

Returns:

tuple[Logits tensor (n_selected, batch, seq, vocab), hidden-state tensor (batch, seq, d_model)].

rollback_kv_cache(n: int, omit_last: bool = False)[source]

Roll back cache pointers on trunk and every head by n positions.

sync_kv_cache(target_length: int, omit_last: bool = False)[source]

Sync cache pointers on trunk and every head to the specified position.
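The pointer bookkeeping behind `reset_kv_cache`, `rollback_kv_cache`, and `sync_kv_cache` can be illustrated with a toy cache class. This is not `olmo_tap`'s actual cache implementation, just the reset/rollback/sync semantics described above.

```python
# Toy cache-pointer bookkeeping mirroring the semantics of the KV-cache
# methods: reset to 0, roll back by n positions, sync to a target length.

class ToyKVCache:
    def __init__(self, max_seq_len):
        self.max_seq_len = max_seq_len
        self.pos = 0

    def append(self, n_tokens):
        assert self.pos + n_tokens <= self.max_seq_len
        self.pos += n_tokens

    def reset(self):
        self.pos = 0                     # fresh generation

    def rollback(self, n):
        self.pos = max(self.pos - n, 0)  # discard the last n positions

    def sync(self, target_length):
        self.pos = target_length         # force pointer to a position

cache = ToyKVCache(max_seq_len=16)
cache.append(5)
cache.rollback(2)
print(cache.pos)  # 3
cache.sync(8)
print(cache.pos)  # 8
cache.reset()
print(cache.pos)  # 0
```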

class olmo_tap.hydra.HydraTransformerConfig(base_config: TransformerConfig, n_heads: int, trunk_layers: int, head_layers: int)[source]

Bases: ModelConfig

Config for building a HydraTransformer.

Parameters:
  • base_config – Full TransformerConfig for the underlying model architecture.

  • n_heads – Number of parallel heads.

  • trunk_layers – Number of layers in the shared trunk.

  • head_layers – Number of layers per head.

base_config: TransformerConfig
build(*, init_device: str = 'cpu') HydraTransformer[source]

Build the HydraTransformer.

Parameters:

init_device – Device for parameter initialization (use "meta" for zero-memory initialization).

classmethod from_olmo2_1B(n_heads: int = 5, heads_depth: int = 3, vocab_size: int = 100352) HydraTransformerConfig[source]

Factory for OLMo2 1B (16 layers) with configurable split point.

classmethod from_olmo2_7B(n_heads: int = 5, heads_depth: int = 3, vocab_size: int = 100352) HydraTransformerConfig[source]

Factory for OLMo2 7B (32 layers) with configurable split point.
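Back-of-envelope layer accounting for the factory defaults, assuming (the docs only say "configurable split point", so this is an assumption) that the trunk keeps the base model's layers minus `heads_depth`:

```python
# Hypothetical split-point arithmetic for the factory defaults; the
# trunk_layers = total_layers - heads_depth rule is an assumption.

def split_point(total_layers, heads_depth):
    trunk_layers = total_layers - heads_depth
    return trunk_layers, heads_depth

# Total transformer layers once built (grounded in the config fields:
# one trunk plus n_heads cloned head stacks).
def total_transformer_layers(trunk_layers, n_heads, head_layers):
    return trunk_layers + n_heads * head_layers

print(split_point(16, 3))  # OLMo2 1B default -> (13, 3)
print(split_point(32, 3))  # OLMo2 7B default -> (29, 3)
print(total_transformer_layers(29, n_heads=5, head_layers=3))  # 44
```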

head_layers: int
n_heads: int
property num_non_embedding_params: int

The total number of non-embedding parameters in the model once built.

property num_params: int

The total number of parameters in the model once built.

trunk_layers: int
validate()[source]

Validate fields in self. This may modify self in-place.