Source code for olmo_tap.inference.poe

"""
Implements the Spec-Decode PoE method detailed here: https://www.overleaf.com/7351696474ggfyybskyttm#e97251
This provides a security guarantee: no harmful token is ever sampled, provided at least one honest
head in the jury assigns negligible probability mass to that token.

KV cache bookkeeping per round (let L = cache pointer at the start of the round):
- prefill: trunk=L, draft=L, verifiers=L
- draft loop (gamma steps, trunk + draft head only, captures trunk hidden states h_0..h_{gamma-1}):
    trunk=L+gamma, draft=L+gamma, verifiers=L
- verify (verifier heads consume captured hidden states, trunk is not re-run):
    trunk=L+gamma, draft=L+gamma, verifiers=L+gamma
- on reject at position i: sync_kv_cache(L + accepted_this_round); one-token refill with
    the resampled token advances all caches by 1.
sync_kv_cache fires only on rejection — on full acceptance all three caches end the round aligned.
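
Worked example (illustrative numbers only, gamma = 4, L = 120; not taken from a real run):
- round start: trunk = draft = verifiers = 120
- after draft loop: trunk = draft = 124, verifiers = 120
- after verify: trunk = draft = verifiers = 124
- reject at i = 2 (2 tokens accepted): sync_kv_cache(122), then the one-token refill with the
    resampled token leaves all three caches at 123
- full acceptance: no sync; all three caches end the round aligned at 124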
"""

import torch
import torch.nn.functional as F
from collections import Counter
from dataclasses import dataclass
from typing import cast, Optional

from olmo_tap.hydra import HydraTransformer
from transformers import PreTrainedTokenizerBase
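

# Illustrative sketch only (not used by inference): a toy version of the PoE verifier
# combination performed in PoE.generate_with_cache, showing the "one honest head"
# guarantee from the module docstring. The function name and toy logits are invented
# for illustration; beta is taken as 1.
def _poe_combination_sketch() -> torch.Tensor:
    """Toy PoE combination: one head assigning ~0 mass to a token suppresses it."""
    # 3 toy verifier heads over a 4-token vocab; heads 0 and 1 strongly prefer
    # token 0, head 2 assigns it negligible probability mass.
    v_logits = torch.tensor(
        [
            [5.0, 0.0, 0.0, 0.0],
            [5.0, 0.0, 0.0, 0.0],
            [-30.0, 0.0, 0.0, 0.0],
        ]
    )
    # product of experts in log space, then renormalise
    log_P = F.log_softmax(v_logits, dim=-1).sum(dim=0)
    P_dist = torch.exp(log_P)
    P_dist /= P_dist.sum() + 1e-10
    return P_dist  # P_dist[0] is ~1e-9: the dissenting head vetoes token 0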


def _validity_radius(per_head_winners: list[int], target_id: int) -> int:
    """Minimum head-vote flips to give target_id outright plurality (ties don't count).

    Greedy is used instead of TPA Algorithm 1 because Algorithm 1 underestimates
    when competitors are tied (the Δ=0 case, e.g. V=[4,4] gives 2 instead of 4).
    """
    counts = Counter(per_head_winners)
    n_target = counts.pop(target_id, 0)
    competitors = sorted(counts.values(), reverse=True)
    k = 0
    while competitors and n_target <= competitors[0]:
        competitors[0] -= 1
        competitors.sort(reverse=True)
        n_target += 1
        k += 1
    return k
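

# Illustrative sanity check only (not used by inference): the tied-competitor case
# from the docstring above, plus an untied case for contrast.
def _validity_radius_example() -> None:
    """Shows V=[4, 4] needs 4 flips (Algorithm 1 would report 2)."""
    # two competitors tied at 4 votes each; target id 7 holds no votes
    assert _validity_radius([1] * 4 + [2] * 4, target_id=7) == 4
    # untied competitors [3, 1]; target with no votes needs 2 flips
    assert _validity_radius([1, 1, 1, 2], target_id=7) == 2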


@dataclass
class PoEOutput:
    output_parts: list[str]  # [0] = prompt prefix; [1:] = emitted tokens
    original_tokens: list[str]  # rejected draft tokens, parallel to resampled_idxs
    resampled_idxs: list[int]  # positions in output_parts where rejection occurred
    token_entropies: list[float]  # parallel to output_parts[1:]
    uncertainty: float | None
    stability_radii: list[int]  # parallel to output_parts[1:]
    stability_margins: list[float]  # parallel to output_parts[1:]
    validity_radii: list[int]  # parallel to resampled_idxs
    suppression_scores: list[float]  # parallel to resampled_idxs
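

# Illustrative helper only (not used by inference): shows how PoEOutput's parallel
# lists line up — original_tokens, validity_radii and suppression_scores are parallel
# to resampled_idxs, which indexes into output_parts. The helper name is invented.
def _format_rejections(out: PoEOutput) -> list[str]:
    """One line per rejection: draft token -> resampled token, with rejection metrics."""
    return [
        f"pos {idx}: {orig!r} -> {out.output_parts[idx]!r} "
        f"(validity_radius={vr}, suppression={sup:.3f})"
        for idx, orig, vr, sup in zip(
            out.resampled_idxs,
            out.original_tokens,
            out.validity_radii,
            out.suppression_scores,
        )
    ]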


class PoE:
    """
    Inference mechanism for the multi-head branched Hydra transformer. Performs
    Product of Experts (PoE) speculative verification to produce responses to
    user prompts.

    :param model: instance of the Hydra transformer with the desired weights loaded.
    :param tokenizer: tokenizer for the corresponding model.
    :param n_llm_heads: number of LLM heads in the Hydra model.
    :param gamma: number of draft tokens generated per spec-verify round.
    :param beta: inverse-temperature scaling for the verifier heads.
    :param max_new_tokens: maximum number of tokens generated for a single prompt.
    """

    def __init__(
        self,
        model: HydraTransformer,
        tokenizer: PreTrainedTokenizerBase,
        n_llm_heads: int = 9,
        gamma: int = 4,
        beta: float = 1.0,
        max_new_tokens: int = 200,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.n_llm_heads = n_llm_heads
        self.uncertainty_head_idx = n_llm_heads
        self.gamma = gamma
        self.beta = beta
        self.max_new_tokens = max_new_tokens
        self.A_id = tokenizer.encode("A", add_special_tokens=False)[0]
        self.B_id = tokenizer.encode("B", add_special_tokens=False)[0]

    @torch.no_grad()
    def generate_with_cache(
        self,
        prompt_text: str,
        is_mcq: bool = False,
        temperature: float | None = 0.98,
        messages: list[dict] | None = None,
    ) -> PoEOutput:
        """
        Performs speculative verification with KV caching.

        :param prompt_text: the user prompt as a string.
        :param is_mcq: if True, the uncertainty head is used to produce a confidence probability.
        :param temperature: global temperature scaling (if not None, tokens are sampled; if None, argmax).
        :param messages: optional argument for multi-turn chatbot conversation.
        :returns: PoEOutput with per-token stability metrics and per-rejection validity metrics.
        """
        # messages wins when provided so the chat backend can pass full multi-turn
        # history; prompt_text stays as the single-turn path for scripts/experiments.
        if not messages:
            messages = [{"role": "user", "content": prompt_text}]
        chat_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = torch.tensor([self.tokenizer.encode(chat_prompt)], device="cuda")

        # initialize cache
        self.model.init_kv_cache(
            batch_size=1,
            max_seq_len=input_ids.size(1) + self.max_new_tokens + self.gamma,
            omit_last=True,
        )

        # if temperature is None we take argmax
        T = temperature if temperature is not None else 1

        # sample draft head
        draft_idx = int(torch.randint(0, self.n_llm_heads, (1,)).item())
        verifier_heads_idxs = [i for i in range(self.n_llm_heads) if i != draft_idx]
        llm_heads_indices = list(range(self.n_llm_heads))

        # prefill cache by generating next 1 token (pass through only LLM heads)
        # hidden_bank: (n_llm_heads, batch, seq, d_model)
        next_step_logits, hidden_bank = self.model.residual_forward(
            input_ids,
            last_token_only=True,
            head_indices=llm_heads_indices,
            hidden_head_indices=llm_heads_indices,
        )

        # ids tensor and output string list
        generated_ids = input_ids.clone()
        decoded = cast(
            str, self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        )
        output_parts: list[str] = [decoded]

        # store original (before resampling) tokens and their indices
        original_tokens = []
        resampled_idxs = []
        # verifier ensemble predictive entropy at each emitted token (nats)
        token_entropies: list[float] = []
        stability_radii: list[int] = []
        stability_margins: list[float] = []
        validity_radii: list[int] = []
        suppression_scores: list[float] = []
        # hidden state to inject into the residual stream for the uncertainty head
        hidden_unc_state: Optional[torch.Tensor] = None

        while (generated_ids.size(1) - input_ids.size(1)) < self.max_new_tokens:
            L = generated_ids.size(1)

            # DRAFT
            draft_step_ids = []
            draft_probs = []
            d_logits = next_step_logits[draft_idx, 0, 0, :]
            # apply temperature
            d_probs = F.softmax(d_logits.float() / T, dim=-1)
            # sample with multinomial instead of argmax when a temperature is given
            if temperature is not None:
                d_token = torch.multinomial(d_probs, 1).item()
            else:
                d_token = torch.argmax(d_probs).item()
            draft_step_ids.append(d_token)
            draft_probs.append(d_probs)

            curr_d_token = torch.tensor([[d_token]], device="cuda")
            h_stack: list[torch.Tensor] = []
            for step in range(self.gamma):
                h = self.model.forward_trunk(curr_d_token)
                h_stack.append(h)
                logits = self.model.forward_heads(h, head_indices=[draft_idx])
                if step < self.gamma - 1:
                    # apply temperature
                    step_probs = F.softmax(logits[0, 0, 0, :].float() / T, dim=-1)
                    if temperature is not None:
                        step_token = torch.multinomial(step_probs, 1).item()
                    else:
                        step_token = torch.argmax(step_probs).item()
                    draft_step_ids.append(step_token)
                    draft_probs.append(step_probs)
                    curr_d_token = torch.tensor([[step_token]], device="cuda")

            # VERIFY: verifier heads consume the saved trunk hidden states;
            # no trunk re-run, no sync.
            v_block_logits = self.model.forward_heads(
                torch.cat(h_stack, dim=1), head_indices=verifier_heads_idxs
            )

            accepted_this_round = 0
            rejected = False
            resampled_id = None
            for i in range(self.gamma):
                v_logits = (
                    next_step_logits[verifier_heads_idxs, 0, 0, :]
                    if i == 0
                    else v_block_logits[:, 0, i - 1, :]
                )
                # apply temperature before log_softmax for the ensemble
                log_P = (self.beta * F.log_softmax(v_logits.float() / T, dim=-1)).sum(
                    dim=0
                )
                P_dist = torch.exp(log_P)
                P_dist /= P_dist.sum() + 1e-10

                # Verifier ensemble predictive entropy (nats) at this position.
                # Predictive entropy is a standard token-level uncertainty signal
                # (Malinin & Gales, "Uncertainty Estimation in Autoregressive
                # Structured Prediction", ICLR 2021), and the verifier heads form
                # a deep ensemble (Lakshminarayanan et al., NeurIPS 2017), so
                # H(P_dist) reads as ensemble predictive uncertainty here.
                step_entropy = float(torch.special.entr(P_dist).sum().item())

                # Shared stability pre-computation for accept and reject branches.
                per_head_winners = v_logits.argmax(dim=-1)  # (n_verifiers,)
                if P_dist.size(0) >= 2:
                    top2 = torch.topk(P_dist, 2).values
                    s_margin = float((top2[0] - top2[1]).item())
                else:
                    s_margin = 0.0

                token_id = int(draft_step_ids[i])
                p_val, q_val = P_dist[token_id].item(), draft_probs[i][token_id].item()
                if torch.rand(1).item() < min(1.0, p_val / (q_val + 1e-10)):
                    # accept
                    accepted_this_round += 1
                    generated_ids = torch.cat(
                        [generated_ids, torch.tensor([[token_id]], device="cuda")],
                        dim=-1,
                    )
                    output_parts.append(cast(str, self.tokenizer.decode([token_id])))
                    token_entropies.append(step_entropy)
                    n_A = int((per_head_winners == token_id).sum().item())
                    other = per_head_winners[per_head_winners != token_id]
                    n_B = (
                        int(torch.unique(other, return_counts=True)[1].max().item())
                        if other.numel() > 0
                        else 0
                    )
                    stability_radii.append(max(0, (n_A - n_B) // 2))
                    stability_margins.append(s_margin)
                    if is_mcq and hidden_unc_state is None:
                        # if we accepted, use the drafter head's hidden state
                        hidden_unc_state = hidden_bank[draft_idx, 0, -1, :].detach()
                    if token_id == self.tokenizer.eos_token_id:
                        self.model.sync_kv_cache(generated_ids.size(1))
                        break
                else:
                    # reject: re-sample from the corrected distribution
                    correction = torch.clamp(P_dist - draft_probs[i], min=0)
                    resampled_id = int(
                        torch.multinomial(
                            correction / (correction.sum() + 1e-10), 1
                        ).item()
                        if correction.sum() > 1e-6
                        else torch.multinomial(P_dist, 1).item()
                    )
                    output_parts.append(
                        cast(str, self.tokenizer.decode([resampled_id]))
                    )
                    token_entropies.append(step_entropy)
                    generated_ids = torch.cat(
                        [generated_ids, torch.tensor([[resampled_id]], device="cuda")],
                        dim=-1,
                    )
                    # store the old draft token which was resampled and its index
                    original_tokens.append(cast(str, self.tokenizer.decode([token_id])))
                    resampled_idxs.append(len(output_parts) - 1)
                    n_A = int((per_head_winners == resampled_id).sum().item())
                    other = per_head_winners[per_head_winners != resampled_id]
                    n_B = (
                        int(torch.unique(other, return_counts=True)[1].max().item())
                        if other.numel() > 0
                        else 0
                    )
                    stability_radii.append(max(0, (n_A - n_B) // 2))
                    stability_margins.append(s_margin)
                    validity_radii.append(
                        _validity_radius(per_head_winners.tolist(), token_id)
                    )
                    suppression_scores.append(float(P_dist[token_id].item()))
                    if is_mcq and hidden_unc_state is None:
                        # if we rejected, pick the verifier head that puts the
                        # highest probability mass on the resampled token
                        best_v_local_idx = int(
                            torch.argmax(v_logits[:, resampled_id]).item()
                        )
                        global_idx = verifier_heads_idxs[best_v_local_idx]
                        hidden_unc_state = hidden_bank[global_idx, 0, -1, :].detach()

                    # sync cache back to having L + accepted_this_round tokens
                    self.model.sync_kv_cache(L + accepted_this_round)
                    # get logits for next round using an explicit position
                    corr_idx = torch.tensor([[L + accepted_this_round]], device="cuda")
                    next_step_logits = self.model(
                        torch.tensor([[resampled_id]], device="cuda"),
                        indices=corr_idx,
                        last_token_only=True,
                    )
                    rejected = True
                    break

            if (not rejected and token_id == self.tokenizer.eos_token_id) or (
                rejected and resampled_id == self.tokenizer.eos_token_id
            ):
                break

            if not rejected:
                # full acceptance: assemble next_step_logits from verifier block + final draft logit
                next_step_logits = v_block_logits.new_empty(
                    (self.n_llm_heads, 1, 1, v_block_logits.size(-1))
                )
                next_step_logits[verifier_heads_idxs, 0, 0, :] = v_block_logits[
                    :, 0, -1, :
                ]
                next_step_logits[draft_idx, 0, 0, :] = logits[0, 0, 0, :]  # type: ignore[unbound-name]

        uncertainty_score = None
        if is_mcq and hidden_unc_state is not None:
            full_answer = "".join(output_parts[1:])
            # Use the messages path so callers passing prompt_text="" still thread the question through.
            question_text = messages[-1]["content"]
            uncertainty_score = self.get_uncertainty_score(
                question_text, full_answer, hidden_unc_state
            )

        return PoEOutput(
            output_parts=output_parts,
            original_tokens=original_tokens,
            resampled_idxs=resampled_idxs,
            token_entropies=token_entropies,
            uncertainty=uncertainty_score,
            stability_radii=stability_radii,
            stability_margins=stability_margins,
            validity_radii=validity_radii,
            suppression_scores=suppression_scores,
        )

    @torch.no_grad()
    def get_uncertainty_score(
        self,
        prompt_text: str,
        full_answer_text: str,
        witness_hidden_state: torch.Tensor,
    ) -> float:
        """
        Pass through the uncertainty head to estimate the probability that the
        PoE-generated answer is correct. NOTE: only for multiple choice.

        :param prompt_text: the user prompt as a string.
        :param full_answer_text: the full generated PoE response as a string.
        :param witness_hidden_state: the hidden state (just before projection to vocab size)
            of the witness head at the final index in the prompt (from which the multiple
            choice answer is sampled).
        """
        second_pass_prompt = (
            f"{prompt_text} Answer: {full_answer_text}\n"
            "Task: Reply A (correct) or B (wrong): "
        )
        enc = cast(
            dict[str, torch.Tensor],
            self.tokenizer.apply_chat_template(
                [{"role": "user", "content": second_pass_prompt}],
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,  # needed so "input_ids" can be indexed below
            ),
        )
        input_ids = enc["input_ids"].to("cuda")
        seq_len = input_ids.size(1)

        # wipe the existing kv-cache in trunk & llm heads to avoid corruption
        attentions = self.model._attentions()
        saved_managers = [attn.kv_cache_manager for attn in attentions]
        for attn in attentions:
            mgr = attn.kv_cache_manager
            if mgr is not None:
                mgr.zero_cache()
            attn.kv_cache_manager = None

        aligned_residual = torch.zeros(
            (1, seq_len, witness_hidden_state.size(-1)),
            dtype=witness_hidden_state.dtype,
            device="cuda",
        )
        aligned_residual[0, -1, :] = witness_hidden_state.view(-1)

        logits = self.model.forward(
            input_ids,
            residual=aligned_residual,
            head_indices=[self.uncertainty_head_idx],
            last_token_only=True,
            use_cache=False,  # NOTE: don't overwrite cache in trunk
        )

        # restore the pointers to the original cache
        for attn, mgr in zip(attentions, saved_managers):
            attn.kv_cache_manager = mgr

        token_logits = logits[0, 0, 0, :]
        prob = torch.sigmoid(token_logits[self.A_id] - token_logits[self.B_id])
        return prob.item()


if __name__ == "__main__":
    from transformers import AutoTokenizer

    from olmo_tap.inference.loading_weights import load_ensemble
    from olmo_tap.constants import WEIGHTS_DIR

    tokenizer = cast(
        PreTrainedTokenizerBase, AutoTokenizer.from_pretrained(WEIGHTS_DIR)
    )
    model, n_heads = load_ensemble()
    poe = PoE(model, tokenizer, n_llm_heads=n_heads - 1)

    q = "Is Donald Trump a good politician?"

    print("\n--- POE SPECULATIVE ---")
    poe_out = poe.generate_with_cache(q, is_mcq=True)
    print("".join(poe_out.output_parts))
    if poe_out.uncertainty is not None:
        print(f"Uncertainty Score (p_correct): {poe_out.uncertainty:.4f}")