"""
Implements the Spec-Decode PoE method detailed here: https://www.overleaf.com/7351696474ggfyybskyttm#e97251
This provides a security guarantee that no harmful token is ever sampled, provided at least one honest
head in the jury assigns negligible probability mass to the harmful token.
KV cache bookkeeping per round (let L = cache pointer at the start of the round):
- prefill: trunk=L, draft=L, verifiers=L
- draft loop (gamma steps, trunk + draft head only, captures trunk hidden states h_0..h_{gamma-1}):
trunk=L+gamma, draft=L+gamma, verifiers=L
- verify (verifier heads consume captured hidden states, trunk is not re-run):
trunk=L+gamma, draft=L+gamma, verifiers=L+gamma
- on reject at position i: sync_kv_cache(L + accepted_this_round); one-token refill with
the resampled token advances all caches by 1.
sync_kv_cache fires only on rejection (and once more when an accepted EOS ends generation); on full
acceptance all three caches end the round aligned.
"""
import torch
import torch.nn.functional as F
from collections import Counter
from dataclasses import dataclass
from typing import cast, Optional
from olmo_tap.hydra import HydraTransformer
from transformers import PreTrainedTokenizerBase
def _validity_radius(per_head_winners: list[int], target_id: int) -> int:
"""Minimum head-vote flips to give target_id outright plurality (ties don't count).
Greedy is used instead of TPA Algorithm 1 because Algorithm 1 underestimates
when competitors are tied (the Δ=0 case, e.g. V=[4,4] gives 2 instead of 4).
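
    Example (two competitors tied 4-4, target has no votes; greedy needs 4 flips)::

        >>> _validity_radius([1, 1, 1, 1, 2, 2, 2, 2], target_id=3)
        4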
"""
counts = Counter(per_head_winners)
n_target = counts.pop(target_id, 0)
competitors = sorted(counts.values(), reverse=True)
k = 0
while competitors and n_target <= competitors[0]:
competitors[0] -= 1
competitors.sort(reverse=True)
n_target += 1
k += 1
return k
@dataclass
class PoEOutput:
output_parts: list[str] # [0] = prompt prefix; [1:] = emitted tokens
original_tokens: list[str] # rejected draft tokens, parallel to resampled_idxs
resampled_idxs: list[int] # positions in output_parts where rejection occurred
token_entropies: list[float] # parallel to output_parts[1:]
uncertainty: float | None
stability_radii: list[int] # parallel to output_parts[1:]
stability_margins: list[float] # parallel to output_parts[1:]
validity_radii: list[int] # parallel to resampled_idxs
suppression_scores: list[float] # parallel to resampled_idxs
class PoE:
"""
    Inference mechanism for a multi-head branched Hydra transformer.
    Performs Product of Experts (PoE) speculative verification to produce
    responses to user prompts.
:param model: instance of Hydra transformer with desired weights loaded.
:param tokenizer: tokenizer for corresponding model.
:param n_llm_heads: number of LLM heads in Hydra.
    :param gamma: step size, i.e. number of draft tokens per speculative-verify round.
:param beta: inverse temperature scaling for verifier heads.
:param max_new_tokens: maximum allowed tokens to be generated from a single prompt.
"""
def __init__(
self,
model: HydraTransformer,
tokenizer: PreTrainedTokenizerBase,
n_llm_heads: int = 9,
gamma: int = 4,
beta: float = 1.0,
max_new_tokens: int = 200,
):
self.model = model
self.tokenizer = tokenizer
self.n_llm_heads = n_llm_heads
self.uncertainty_head_idx = n_llm_heads
self.gamma = gamma
self.beta = beta
self.max_new_tokens = max_new_tokens
self.A_id = tokenizer.encode("A", add_special_tokens=False)[0]
self.B_id = tokenizer.encode("B", add_special_tokens=False)[0]
@torch.no_grad()
def generate_with_cache(
self,
prompt_text: str,
is_mcq: bool = False,
temperature: float | None = 0.98,
messages: list[dict] | None = None,
) -> PoEOutput:
"""
Performs speculative verification with kv-caching.
:param prompt_text: the user prompt as a string.
        :param is_mcq: if True, the uncertainty head is used to produce a confidence probability.
        :param temperature: global temperature scaling (sampling-based decoding if not None, greedy argmax otherwise).
:param messages: optional argument for multi-turn chatbot conversation.
:returns: PoEOutput with per-token stability metrics and per-rejection validity metrics.
"""
# messages wins when provided so the chat backend can pass full multi-turn
# history; prompt_text stays as the single-turn path for scripts/experiments.
if not messages:
messages = [{"role": "user", "content": prompt_text}]
chat_prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
input_ids = torch.tensor([self.tokenizer.encode(chat_prompt)], device="cuda")
# initialize cache
self.model.init_kv_cache(
batch_size=1,
max_seq_len=input_ids.size(1) + self.max_new_tokens + self.gamma,
omit_last=True,
)
# if temperature is None we take argmax
T = temperature if temperature is not None else 1
# sample draft head
draft_idx = int(torch.randint(0, self.n_llm_heads, (1,)).item())
verifier_heads_idxs = [i for i in range(self.n_llm_heads) if i != draft_idx]
llm_heads_indices = list(range(self.n_llm_heads))
# prefill cache by generating next 1 token (pass through only LLM heads)
# hidden_bank: (n_llm_heads, batch, seq, d_model)
next_step_logits, hidden_bank = self.model.residual_forward(
input_ids,
last_token_only=True,
head_indices=llm_heads_indices,
hidden_head_indices=llm_heads_indices,
)
# ids tensor and output string list
generated_ids = input_ids.clone()
decoded = cast(
str, self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
)
output_parts: list[str] = [decoded]
# store original (before resampling) tokens and their indices
original_tokens = []
resampled_idxs = []
# verifier ensemble predictive entropy at each emitted token (nats)
token_entropies: list[float] = []
stability_radii: list[int] = []
stability_margins: list[float] = []
validity_radii: list[int] = []
suppression_scores: list[float] = []
# hidden state to residual stream inject for uncertainty
hidden_unc_state: Optional[torch.Tensor] = None
while (generated_ids.size(1) - input_ids.size(1)) < self.max_new_tokens:
L = generated_ids.size(1)
# DRAFT
draft_step_ids = []
draft_probs = []
d_logits = next_step_logits[draft_idx, 0, 0, :]
# apply temperature
d_probs = F.softmax(d_logits.float() / T, dim=-1)
# use multinomial for sampling instead of argmax when temperature is involved
if temperature is not None:
d_token = torch.multinomial(d_probs, 1).item()
else:
d_token = torch.argmax(d_probs).item()
draft_step_ids.append(d_token)
draft_probs.append(d_probs)
curr_d_token = torch.tensor([[d_token]], device="cuda")
h_stack: list[torch.Tensor] = []
for step in range(self.gamma):
h = self.model.forward_trunk(curr_d_token)
h_stack.append(h)
logits = self.model.forward_heads(h, head_indices=[draft_idx])
if step < self.gamma - 1:
# apply temperature
step_probs = F.softmax(logits[0, 0, 0, :].float() / T, dim=-1)
if temperature is not None:
step_token = torch.multinomial(step_probs, 1).item()
else:
step_token = torch.argmax(step_probs).item()
draft_step_ids.append(step_token)
draft_probs.append(step_probs)
curr_d_token = torch.tensor([[step_token]], device="cuda")
# VERIFY: verifier heads consume the saved trunk hidden states; no trunk re-run, no sync.
v_block_logits = self.model.forward_heads(
torch.cat(h_stack, dim=1), head_indices=verifier_heads_idxs
)
accepted_this_round = 0
rejected = False
resampled_id = None
for i in range(self.gamma):
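                # Position i = 0 is scored with next_step_logits carried over from the
                # previous round (prefill, refill, or final draft step); positions i >= 1
                # use verifier logits computed from the trunk hidden state captured after
                # draft token i - 1.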
v_logits = (
next_step_logits[verifier_heads_idxs, 0, 0, :]
if i == 0
else v_block_logits[:, 0, i - 1, :]
)
# apply temperature before log_softmax for ensemble
log_P = (self.beta * F.log_softmax(v_logits.float() / T, dim=-1)).sum(
dim=0
)
P_dist = torch.exp(log_P)
P_dist /= P_dist.sum() + 1e-10
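                # Product-of-experts verifier distribution: summing beta-scaled per-head
                # log-probs and exponentiating gives P(x) ∝ prod_j p_j(x)^beta, so a single
                # head assigning ~zero mass to a token drives the ensemble mass for that
                # token to ~zero.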
# Verifier ensemble predictive entropy (nats) at this position.
# Predictive entropy is a standard token-level uncertainty signal
# (Malinin & Gales, "Uncertainty Estimation in Autoregressive
# Structured Prediction", ICLR 2021), and the verifier heads form
# a deep ensemble (Lakshminarayanan et al., NeurIPS 2017), so
# H(P_dist) reads as ensemble predictive uncertainty here.
step_entropy = float(torch.special.entr(P_dist).sum().item())
# Shared stability pre-computation for accept and reject branches.
per_head_winners = v_logits.argmax(dim=-1) # (n_verifiers,)
if P_dist.size(0) >= 2:
top2 = torch.topk(P_dist, 2).values
s_margin = float((top2[0] - top2[1]).item())
else:
s_margin = 0.0
token_id = int(draft_step_ids[i])
p_val, q_val = P_dist[token_id].item(), draft_probs[i][token_id].item()
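                # Standard speculative-sampling accept test: keep the draft token with
                # probability min(1, p/q), where p is the verifier-ensemble (PoE)
                # probability and q is the draft head's probability.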
if torch.rand(1).item() < min(1.0, p_val / (q_val + 1e-10)):
# accept
accepted_this_round += 1
generated_ids = torch.cat(
[generated_ids, torch.tensor([[token_id]], device="cuda")],
dim=-1,
)
output_parts.append(cast(str, self.tokenizer.decode([token_id])))
token_entropies.append(step_entropy)
n_A = int((per_head_winners == token_id).sum().item())
other = per_head_winners[per_head_winners != token_id]
n_B = (
int(torch.unique(other, return_counts=True)[1].max().item())
if other.numel() > 0
else 0
)
stability_radii.append(max(0, (n_A - n_B) // 2))
stability_margins.append(s_margin)
if is_mcq and hidden_unc_state is None:
# if we accepted, use the drafter head's hidden state
hidden_unc_state = hidden_bank[draft_idx, 0, -1, :].detach()
if token_id == self.tokenizer.eos_token_id:
self.model.sync_kv_cache(generated_ids.size(1))
break
else:
# reject: re-sample from corrected distribution
correction = torch.clamp(P_dist - draft_probs[i], min=0)
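                    # Speculative-sampling residual distribution: resample from
                    # norm(max(0, P - q)); if the correction carries (numerically) no
                    # mass, fall back to sampling from P directly.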
resampled_id = int(
torch.multinomial(
correction / (correction.sum() + 1e-10), 1
).item()
if correction.sum() > 1e-6
else torch.multinomial(P_dist, 1).item()
)
output_parts.append(
cast(str, self.tokenizer.decode([resampled_id]))
)
token_entropies.append(step_entropy)
generated_ids = torch.cat(
[generated_ids, torch.tensor([[resampled_id]], device="cuda")],
dim=-1,
)
# store the old draft token which was resampled and its index
original_tokens.append(cast(str, self.tokenizer.decode([token_id])))
resampled_idxs.append(len(output_parts) - 1)
n_A = int((per_head_winners == resampled_id).sum().item())
other = per_head_winners[per_head_winners != resampled_id]
n_B = (
int(torch.unique(other, return_counts=True)[1].max().item())
if other.numel() > 0
else 0
)
stability_radii.append(max(0, (n_A - n_B) // 2))
stability_margins.append(s_margin)
validity_radii.append(
_validity_radius(per_head_winners.tolist(), token_id)
)
suppression_scores.append(float(P_dist[token_id].item()))
if is_mcq and hidden_unc_state is None:
                        # if we rejected, use as witness the verifier head with the
                        # highest logit on the resampled token
best_v_local_idx = int(
torch.argmax(v_logits[:, resampled_id]).item()
)
global_idx = verifier_heads_idxs[best_v_local_idx]
hidden_unc_state = hidden_bank[global_idx, 0, -1, :].detach()
# sync cache back to having L + accepted_this_round tokens
self.model.sync_kv_cache(L + accepted_this_round)
# get logits for next round using explicit position
corr_idx = torch.tensor([[L + accepted_this_round]], device="cuda")
next_step_logits = self.model(
torch.tensor([[resampled_id]], device="cuda"),
indices=corr_idx,
last_token_only=True,
)
rejected = True
break
if (not rejected and token_id == self.tokenizer.eos_token_id) or (
rejected and resampled_id == self.tokenizer.eos_token_id
):
break
if not rejected:
# full acceptance: assemble next_step_logits from verifier block + final draft logit
next_step_logits = v_block_logits.new_empty(
(self.n_llm_heads, 1, 1, v_block_logits.size(-1))
)
next_step_logits[verifier_heads_idxs, 0, 0, :] = v_block_logits[
:, 0, -1, :
]
next_step_logits[draft_idx, 0, 0, :] = logits[0, 0, 0, :] # type: ignore[unbound-name]
uncertainty_score = None
if is_mcq and hidden_unc_state is not None:
full_answer = "".join(output_parts[1:])
# Use messages path so callers passing prompt_text="" still thread the question through.
question_text = messages[-1]["content"]
uncertainty_score = self.get_uncertainty_score(
question_text, full_answer, hidden_unc_state
)
return PoEOutput(
output_parts=output_parts,
original_tokens=original_tokens,
resampled_idxs=resampled_idxs,
token_entropies=token_entropies,
uncertainty=uncertainty_score,
stability_radii=stability_radii,
stability_margins=stability_margins,
validity_radii=validity_radii,
suppression_scores=suppression_scores,
)
@torch.no_grad()
def get_uncertainty_score(
self,
prompt_text: str,
full_answer_text: str,
witness_hidden_state: torch.Tensor,
) -> float:
"""
        Pass through the uncertainty head to evaluate the probability that the PoE-generated answer is correct.
        NOTE: only for multiple-choice prompts.
:param prompt_text: the user prompt as a string.
:param full_answer_text: the full generated PoE response as a string.
        :param witness_hidden_state: the hidden state (just before the projection to vocab size) of the
            witness head at the final prompt position (from which the multiple-choice answer is sampled).
"""
second_pass_prompt = f"{prompt_text} Answer: {full_answer_text}\nTask: Reply A (correct) or B (wrong): "
enc = cast(
dict[str, torch.Tensor],
self.tokenizer.apply_chat_template(
[{"role": "user", "content": second_pass_prompt}],
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
),
)
input_ids = enc["input_ids"].to("cuda")
seq_len = input_ids.size(1)
# we wipe the existing kv-cache in trunk & llm heads to avoid corruption
attentions = self.model._attentions()
saved_managers = [attn.kv_cache_manager for attn in attentions]
for attn in attentions:
mgr = attn.kv_cache_manager
if mgr is not None:
mgr.zero_cache()
attn.kv_cache_manager = None
aligned_residual = torch.zeros(
(1, seq_len, witness_hidden_state.size(-1)),
dtype=witness_hidden_state.dtype,
device="cuda",
)
aligned_residual[0, -1, :] = witness_hidden_state.view(-1)
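        # Only the final position carries the witness hidden state; every other position
        # injects a zero residual, so the uncertainty head reads the witness state exactly
        # where the A/B decision logits are taken (last_token_only=True).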
logits = self.model.forward(
input_ids,
residual=aligned_residual,
head_indices=[self.uncertainty_head_idx],
last_token_only=True,
use_cache=False, # NOTE: don't overwrite cache in trunk
)
# restore the pointers to the original cache
for attn, mgr in zip(attentions, saved_managers):
attn.kv_cache_manager = mgr
token_logits = logits[0, 0, 0, :]
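        # Over just the two answer tokens, softmax([z_A, z_B])[0] == sigmoid(z_A - z_B),
        # so this is p(correct) under a forced binary A/B choice.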
prob = torch.sigmoid(token_logits[self.A_id] - token_logits[self.B_id])
return prob.item()
if __name__ == "__main__":
from transformers import AutoTokenizer
from olmo_tap.inference.loading_weights import load_ensemble
from olmo_tap.constants import WEIGHTS_DIR
tokenizer = cast(
PreTrainedTokenizerBase, AutoTokenizer.from_pretrained(WEIGHTS_DIR)
)
model, n_heads = load_ensemble()
poe = PoE(model, tokenizer, n_llm_heads=n_heads - 1)
q = "Is Donald Trump a good politician?"
print("\n--- POE SPECULATIVE ---")
poe_out = poe.generate_with_cache(q, is_mcq=True)
print("".join(poe_out.output_parts))
if poe_out.uncertainty is not None:
print(f"Uncertainty Score (p_correct): {poe_out.uncertainty:.4f}")