olmo_tap.inference.poe

Implements the Spec-Decode PoE method detailed at https://www.overleaf.com/7351696474ggfyybskyttm#e97251. This provides a security guarantee that no harmful token is ever sampled, provided at least one honest head in the jury assigns negligible probability mass to the harmful token.
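For intuition, the PoE acceptance rule can be viewed as a renormalized product of per-head next-token distributions, so a single head assigning near-zero mass to a token drives the combined probability of that token toward zero. A minimal sketch of this effect (the poe_distribution helper is hypothetical, not part of this module's API):

    import torch

    def poe_distribution(head_logits: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Renormalized product of experts over per-head next-token logits.

        head_logits: (n_heads, vocab) logits from each head.
        beta: inverse temperature applied to each head before combining.
        """
        # Summing log-probs across heads == taking the product of probabilities.
        joint_log_probs = torch.log_softmax(beta * head_logits, dim=-1).sum(dim=0)
        return torch.softmax(joint_log_probs, dim=-1)

    # Three heads over a 5-token vocab; heads 0 and 2 are uniform, while the
    # "honest" head 1 assigns negligible mass to token 4 (the harmful one).
    logits = torch.zeros(3, 5)
    logits[1, 4] = -30.0
    print(poe_distribution(logits)[4])  # ~2e-14: token 4 is effectively never sampled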

KV cache bookkeeping per round (let L = cache pointer at the start of the round):

  • prefill: trunk=L, draft=L, verifiers=L

  • draft loop (gamma steps, trunk + draft head only, captures trunk hidden states h_0..h_{gamma-1}):

    trunk=L+gamma, draft=L+gamma, verifiers=L

  • verify (verifier heads consume captured hidden states, trunk is not re-run):

    trunk=L+gamma, draft=L+gamma, verifiers=L+gamma

  • on reject at position i: sync_kv_cache(L + accepted_this_round); one-token refill with the resampled token advances all caches by 1.

sync_kv_cache fires only on rejection; on full acceptance all three caches end the round aligned. A schematic of this bookkeeping follows.
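The helper below only simulates the three pointers across one round; the function and argument names are illustrative stand-ins, not the module's real internals:

    def run_round(L: int, gamma: int, accepted_this_round: int) -> int:
        """Simulate the three cache pointers across one spec-verify round."""
        trunk = draft = verifiers = L              # after prefill
        trunk += gamma                             # draft loop: trunk + draft head
        draft += gamma                             # advance; verifiers lag at L
        verifiers += gamma                         # verify: heads consume captured
                                                   # hidden states, trunk not re-run
        if accepted_this_round < gamma:            # reject at position i
            trunk = draft = verifiers = L + accepted_this_round  # sync_kv_cache(...)
            trunk += 1; draft += 1; verifiers += 1  # one-token refill (resampled token)
        assert trunk == draft == verifiers         # caches end every round aligned
        return trunk

    assert run_round(L=10, gamma=4, accepted_this_round=4) == 14  # full acceptance
    assert run_round(L=10, gamma=4, accepted_this_round=2) == 13  # reject at i=2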

Classes

PoE(model, tokenizer[, n_llm_heads, gamma, ...])

Inference mechanism for the multi-head branched Hydra transformer.

PoEOutput(output_parts, original_tokens, ...)

Container for the generated text and the per-token / per-rejection diagnostic metrics of a PoE run.

class olmo_tap.inference.poe.PoE(model: HydraTransformer, tokenizer: PreTrainedTokenizerBase, n_llm_heads: int = 9, gamma: int = 4, beta: float = 1.0, max_new_tokens: int = 200)[source]

Bases: object

Inference mechanism for the multi-head branched Hydra transformer. Performs Product of Experts (PoE) speculative verification to produce responses to user prompts.

Parameters:
  • model – instance of the Hydra transformer with the desired weights loaded.

  • tokenizer – tokenizer corresponding to the model.

  • n_llm_heads – number of LLM heads in the Hydra model.

  • gamma – number of draft steps per speculative-verification round.

  • beta – inverse-temperature scaling applied to the verifier heads.

  • max_new_tokens – maximum number of tokens to generate from a single prompt.

generate_with_cache(prompt_text: str, is_mcq: bool = False, temperature: float | None = 0.98, messages: list[dict] | None = None) → PoEOutput[source]

Performs speculative verification with KV caching.

Parameters:
  • prompt_text – the user prompt as a string.

  • is_mcq – if True, the uncertainty head is used to produce a confidence probability for the answer.

  • temperature – global temperature scaling; if not None, tokens are generated by sampling.

  • messages – optional message history for multi-turn chat conversations.

Returns:

PoEOutput with per-token stability metrics and per-rejection validity metrics.
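A hedged usage sketch. Loading the HydraTransformer and its tokenizer is assumed and not shown; the prompt and the printed fields are illustrative:

    from olmo_tap.inference.poe import PoE

    # `model` is an already-loaded HydraTransformer and `tokenizer` its
    # PreTrainedTokenizerBase (loading code omitted here).
    poe = PoE(model, tokenizer, n_llm_heads=9, gamma=4, beta=1.0, max_new_tokens=200)

    out = poe.generate_with_cache("Explain speculative decoding in one sentence.")
    print("".join(out.output_parts))  # the generated response
    print(out.token_entropies[:5])    # per-token stability metrics
    print(out.resampled_idxs)         # positions where a draft token was rejected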

get_uncertainty_score(prompt_text: str, full_answer_text: str, witness_hidden_state: Tensor) → float[source]

Passes through the uncertainty head to estimate the probability that the PoE-generated answer is correct. NOTE: only for multiple-choice prompts.

Parameters:
  • prompt_text – the user prompt as a string.

  • full_answer_text – the full generated PoE response as a string.

  • witness_hidden_state – the hidden state (just before projection to vocab size) of the witness head at the final index in the prompt (from which the multiple-choice answer is sampled).

Returns:

Probability that the generated answer is correct.
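Continuing the usage sketch above, an MCQ-style call might look as follows. How the witness head's hidden state is captured is an assumption here; a random tensor of an assumed width stands in for it:

    import torch

    witness_h = torch.randn(4096)  # placeholder: real code would capture this
                                   # from the witness head at the final prompt index

    score = poe.get_uncertainty_score(
        prompt_text="Q: 2 + 2 = ? (A) 3 (B) 4",
        full_answer_text="(B) 4",
        witness_hidden_state=witness_h,
    )
    print(f"P(correct) = {score:.3f}")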

class olmo_tap.inference.poe.PoEOutput(output_parts: list[str], original_tokens: list[str], resampled_idxs: list[int], token_entropies: list[float], uncertainty: float | None, stability_radii: list[int], stability_margins: list[float], validity_radii: list[int], suppression_scores: list[float])[source]

Bases: object

original_tokens: list[str]
output_parts: list[str]
resampled_idxs: list[int]
stability_margins: list[float]
stability_radii: list[int]
suppression_scores: list[float]
token_entropies: list[float]
uncertainty: float | None
validity_radii: list[int]
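For illustration, the per-token fields are parallel to the generated tokens, and the per-rejection fields appear parallel to resampled_idxs (the latter alignment is inferred from the field names, not stated on this page):

    out = poe.generate_with_cache("Some prompt", is_mcq=True)  # `poe` from the sketch above

    # Per-token metrics: one entry per generated token.
    for tok, ent, radius, margin in zip(
        out.original_tokens, out.token_entropies,
        out.stability_radii, out.stability_margins,
    ):
        print(f"{tok!r}: entropy={ent:.2f} radius={radius} margin={margin:.2f}")

    # Per-rejection metrics: one entry per rejected draft position.
    for idx, v_radius, supp in zip(
        out.resampled_idxs, out.validity_radii, out.suppression_scores
    ):
        print(f"rejected at {idx}: validity_radius={v_radius} suppression={supp:.2f}")

    print(out.uncertainty)  # MCQ confidence probability, or None when is_mcq=False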