app.backend.hydra_inference

Thin adapters that bridge the FastAPI request layer with olmo_tap.inference.poe.

This module is the single entry point through which the deployed backend talks to the Hydra PoE ensemble. Three concerns are handled here, none of which belong inside the research-side PoE class:

  • Model construction – load_hydra() builds the 10-head ensemble (9 LLM verifiers + 1 uncertainty head) with the production LoRAs merged in and the chat tokenizer attached.

  • System-prompt routing – MCQ vs NLP prompts get different system messages and token budgets (MCQ_SYSTEM_PROMPT / NLP_SYSTEM_PROMPT).

  • Output reshaping – _tokens_and_resamples_from_poe_output() trims trailing EOS, strips token whitespace, and converts the rejection bookkeeping into the per-token records the frontend renders.
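The output-reshaping step can be sketched in plain Python. This is a minimal illustration, not the real _tokens_and_resamples_from_poe_output(): the Rejection dataclass, the reshape_output() name, and the assumption that each rejection carries the index of the token it replaced are all hypothetical. Only the output keys come from the generate() documentation.

```python
from dataclasses import dataclass


# Hypothetical shape of one rejection event; the real PoE bookkeeping may differ.
@dataclass
class Rejection:
    index: int              # assumed: position of the rejected draft token
    old_token: str          # draft token the verifiers rejected
    new_token: str          # token sampled in its place
    severity: float
    validity_radius: float
    suppression_score: float


def reshape_output(token_strs, eos_token, rejections):
    """Trim trailing EOS, strip whitespace, and flatten rejections to dicts."""
    # Drop every trailing EOS token (generation may emit more than one).
    end = len(token_strs)
    while end > 0 and token_strs[end - 1] == eos_token:
        end -= 1
    tokens = [t.strip() for t in token_strs[:end]]

    # Keep only rejections that still point at a surviving token.
    resampled = [
        {
            "index": r.index,
            "old_token": r.old_token.strip(),
            "new_token": r.new_token.strip(),
            "severity": r.severity,
            "validity_radius": r.validity_radius,
            "suppression_score": r.suppression_score,
        }
        for r in rejections
        if r.index < end
    ]
    return tokens, resampled
```

Filtering rejections by index after trimming keeps the resampled records consistent with the returned token list, so the frontend never highlights a position that no longer exists.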

A separate get_robustness() function reuses generate() to retry the prompt under a bank of adversarial suffixes and reports how many of them flipped the answer (MCQ: the first token differs; NLP: NLI similarity at or below NLP_ROBUSTNESS_THRESHOLD).

Functions

generate(model, tokenizer, messages, is_mcq)

Generate a PoE response via speculative verification.

get_robustness(model, tokenizer, messages, ...)

Score adversarial robustness by retrying the prompt with each suffix.

load_hydra([device])

Load the production Hydra ensemble and its tokenizer.

app.backend.hydra_inference.generate(model: HydraTransformer, tokenizer: PreTrainedTokenizerBase, messages: list[dict], is_mcq: bool, device: str = 'cuda') → tuple[str, list[str], list[dict], list[float], float | None, list[int], list[float]]

Generate a PoE response via speculative verification.

Both MCQ and NLP prompts run through olmo_tap.inference.poe.PoE.generate_with_cache(). When is_mcq is True, a short system nudge is prepended and the token budget is capped at MCQ_MAX_NEW_TOKENS, so the model leads with the chosen option and may add a brief explanation. The PoE class captures a witness hidden state on the first accepted or rejected token and returns a p_correct scalar from a dedicated uncertainty head; this uncertainty pass runs only for MCQ prompts.
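The MCQ/NLP routing can be sketched as follows. MCQ_SYSTEM_PROMPT, NLP_SYSTEM_PROMPT and MCQ_MAX_NEW_TOKENS are real module constants, but the values below are placeholders; NLP_MAX_NEW_TOKENS and route_prompt() are hypothetical names used only for illustration.

```python
# Placeholder values: the real constants live in app.backend.hydra_inference.
MCQ_SYSTEM_PROMPT = "Answer with the letter of the chosen option first."
NLP_SYSTEM_PROMPT = "Answer the question concisely."
MCQ_MAX_NEW_TOKENS = 32
NLP_MAX_NEW_TOKENS = 512  # hypothetical name for the NLP token budget


def route_prompt(messages, is_mcq):
    """Return (chat, max_new_tokens) with the matching system prompt prepended."""
    system = MCQ_SYSTEM_PROMPT if is_mcq else NLP_SYSTEM_PROMPT
    budget = MCQ_MAX_NEW_TOKENS if is_mcq else NLP_MAX_NEW_TOKENS
    # Build a new list so the caller's chat history is left untouched.
    chat = [{"role": "system", "content": system}, *messages]
    return chat, budget
```

In a Hugging Face setup the resulting chat would typically be rendered with tokenizer.apply_chat_template(chat, add_generation_prompt=True) before generation; whether generate() does exactly this is not stated here.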

Parameters:
  • model – The Hydra ensemble loaded by load_hydra().

  • tokenizer – Matching chat tokenizer.

  • messages – Multi-turn chat history (oldest first); the system prompt is prepended automatically.

  • is_mcq – Selects the MCQ prompt + token budget and turns on the uncertainty-head pass (p_correct is computed only for MCQ).

  • device – Torch device. Currently ignored downstream because PoE hardcodes cuda; kept for forward compatibility.

Returns:

7-tuple (raw_response, tokens, resampled, token_entropies, uncertainty, stability_radii, stability_margins).

  • raw_response: decoded response text (no system / chat tags).

  • tokens: list of decoded single-token strings (whitespace stripped).

  • resampled: list of dicts (one per rejected draft token) with keys index, old_token, new_token, severity, validity_radius, suppression_score.

  • token_entropies: verifier-ensemble predictive entropy (nats), parallel to tokens.

  • uncertainty: p_correct float for MCQ, None for NLP.

  • stability_radii / stability_margins: per-token stability metrics, parallel to tokens.
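Because four of the returned lists are parallel to tokens, a caller can join them into one record per token. The sketch below is a hypothetical consumer of the 7-tuple, not part of the module; it assumes that each resampled["index"] is a position in tokens, and the row keys (for example resampled_from) are invented for illustration.

```python
def per_token_rows(tokens, resampled, token_entropies, stability_radii, stability_margins):
    """Join generate()'s parallel per-token lists into one row per token."""
    # The per-token lists are documented as parallel to `tokens`.
    if not (len(tokens) == len(token_entropies) == len(stability_radii) == len(stability_margins)):
        raise ValueError("per-token lists must be parallel to tokens")
    # Assumed: resampled[...]["index"] is the position of the replaced token.
    rejected_at = {r["index"]: r for r in resampled}
    rows = []
    for i, (tok, ent, rad, mar) in enumerate(
        zip(tokens, token_entropies, stability_radii, stability_margins)
    ):
        rows.append({
            "token": tok,
            "entropy_nats": ent,
            "stability_radius": rad,
            "stability_margin": mar,
            "resampled_from": rejected_at[i]["old_token"] if i in rejected_at else None,
        })
    return rows
```

Usage, with the 7-tuple unpacked first:

    raw, tokens, resampled, ents, p_correct, radii, margins = generate(model, tokenizer, messages, is_mcq=True)
    rows = per_token_rows(tokens, resampled, ents, radii, margins)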

app.backend.hydra_inference.get_robustness(model: HydraTransformer, tokenizer: PreTrainedTokenizerBase, messages: list[dict], original_resp: str, original_tokens: list[str], is_mcq: bool, adv_suffix_bank: list[str], bert_model: Any, bert_tokenizer: Any, device: str = 'cuda') → dict

Score adversarial robustness by retrying the prompt with each suffix.

For every suffix in adv_suffix_bank the original prompt is rerun through generate() with the suffix appended to the final user turn. The criterion for “flipped” depends on the prompt type:

  • MCQ: flip iff the first emitted token differs from the clean answer (the system prompt forces the chosen option to be the first token, see MCQ_SYSTEM_PROMPT).

  • NLP: flip iff the NLI similarity between the clean and adversarial responses is at or below NLP_ROBUSTNESS_THRESHOLD. With N = len(adv_suffix_bank), all N+1 responses (the clean baseline plus one per suffix) are scored in a single batched kernel_entropy.nli.ModernBERTScorer call against the clean baseline, which takes 2N NLI inferences instead of a full pairwise comparison.

The “worst-case” entry surfaced for the UI preview panel is the first flipped suffix (MCQ) or the suffix with lowest NLI similarity (NLP).
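The flip criteria and the worst-case selection can be sketched with precomputed inputs: decoded token lists for MCQ and NLI similarity scores for NLP (the ModernBERT call itself is not reproduced). The threshold value, the function names, the "type" strings, and the {"suffix": ...} shape of worst_case are assumptions for illustration only.

```python
from typing import Optional

NLP_ROBUSTNESS_THRESHOLD = 0.5  # placeholder; the real value is defined in the module


def mcq_flips(clean_tokens, adv_token_lists):
    """MCQ: a suffix flips the answer iff the first emitted token changes."""
    clean_first = clean_tokens[0] if clean_tokens else None
    return [(toks[0] if toks else None) != clean_first for toks in adv_token_lists]


def nlp_flips(similarities, threshold=NLP_ROBUSTNESS_THRESHOLD):
    """NLP: flip iff NLI similarity to the clean response is at or below threshold."""
    return [s <= threshold for s in similarities]


def worst_case_index(is_mcq, flips, similarities=None) -> Optional[int]:
    """First flipped suffix (MCQ) or lowest-similarity suffix (NLP)."""
    if is_mcq:
        return next((i for i, f in enumerate(flips) if f), None)
    if not similarities:
        return None  # empty bank: nothing to report
    # NLP reports the least similar suffix even if it did not cross the threshold.
    return min(range(len(similarities)), key=similarities.__getitem__)


def robustness_report(is_mcq, suffixes, flips, similarities=None):
    """Assemble a dict with the documented keys (value shapes are hypothetical)."""
    idx = worst_case_index(is_mcq, flips, similarities)
    return {
        "type": "mcq" if is_mcq else "nlp",
        "attempts": len(suffixes),
        "flipped": sum(flips),
        "worst_case": None if idx is None else {"suffix": suffixes[idx]},
    }
```

Note the asymmetry this encodes: an MCQ run with no flips has no worst case, whereas an NLP run always surfaces its least similar suffix unless the bank is empty.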

Parameters:
  • model – Hydra ensemble.

  • tokenizer – Matching chat tokenizer.

  • messages – Multi-turn chat history. The final user message is removed with pop() at the start, which is the only mutation of the caller's list; each suffix is then appended to that message in turn.

  • original_resp – The clean (unsuffixed) response, used as the NLI reference for NLP flips.

  • original_tokens – Decoded tokens of the clean response, used for the MCQ first-token comparison.

  • is_mcq – Selects the MCQ vs NLP flip criterion.

  • adv_suffix_bank – Suffix strings to test, typically the top-k from app.backend.adversarial_suffixes.ADV_SUFFIXES.

  • bert_model – ModernBERT-NLI model for the NLP path. Unused for MCQ.

  • bert_tokenizer – Matching tokenizer for bert_model.

  • device – Torch device passed through to generate().
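The message handling described for messages can be sketched as below. suffixed_conversations() is a hypothetical helper, and concatenating the suffix directly onto the user text (no separator added) is an assumption; a real suffix bank may already carry its own leading space.

```python
def suffixed_conversations(messages, adv_suffix_bank):
    """Return one chat history per suffix, with the suffix on the last user turn."""
    # Validate before popping so a bad input does not mutate the caller's list.
    if not messages or messages[-1].get("role") != "user":
        raise ValueError("the final message must be a user turn")
    last = messages.pop()  # the single, documented mutation of the caller's list
    return [
        # Assumed: the suffix is concatenated directly onto the user text.
        [*messages, {**last, "content": last["content"] + suffix}]
        for suffix in adv_suffix_bank
    ]
```

Returning a list rather than a generator ensures the pop() happens immediately, so the caller sees the same mutation whether or not every conversation is consumed.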

Returns:

Dict with keys "type", "attempts", "flipped", and "worst_case"; worst_case is None if no suffix flipped (MCQ) or the bank was empty (NLP).

app.backend.hydra_inference.load_hydra(device: str = 'cuda') → tuple[HydraTransformer, TokenizersBackend] | tuple[None, None]

Load the production Hydra ensemble and its tokenizer.

The model has 10 heads (9 LLM verifiers + 1 uncertainty head) with the security and robustness LoRAs already merged into the LLM heads. Returns (None, None) instead of raising when WEIGHTS_DIR is missing or the underlying load_ensemble call fails, so the FastAPI lifespan can fall back to the HF Inference API path without crashing the process.
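The return-(None, None) fallback can be sketched as a small wrapper. load_or_none() and the loader callable are hypothetical stand-ins for load_hydra() and load_ensemble; the real function reads WEIGHTS_DIR itself rather than taking a path argument.

```python
import logging
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

logger = logging.getLogger(__name__)


def load_or_none(
    weights_dir: Path, loader: Callable[[Path], Tuple[Any, Any]]
) -> Tuple[Optional[Any], Optional[Any]]:
    """Return (model, tokenizer), or (None, None) so the caller can fall back."""
    if not weights_dir.is_dir():
        logger.warning("weights dir %s missing; falling back", weights_dir)
        return None, None
    try:
        return loader(weights_dir)
    except Exception:  # broad on purpose: any load failure triggers the fallback
        logger.exception("ensemble load failed; falling back")
        return None, None
```

A FastAPI lifespan handler can then branch on the result, for example setting use_hf_api = model is None, instead of wrapping startup in its own try/except.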

Parameters:

device – Torch device for the loaded weights. Note: PoE itself currently hardcodes cuda; this argument is honoured by the loader but ignored downstream.

Returns:

(model, tokenizer) on success, (None, None) on failure.