"""
FastAPI entrypoint for the Trustworthy Answer Protocol (TAP) backend.
Wires together the four scoring stages exposed by ``/api/analyse``:
1. **Generation** -- :func:`app.backend.hydra_inference.generate` runs the Hydra
   PoE ensemble; if the ensemble is unavailable the request falls back to
   the HF Inference API via :func:`call_hf_model`.
2. **Security** -- per-token PoE acceptance / verifier-ensemble entropy /
stability radii are returned by :func:`app.backend.hydra_inference.generate`
and packaged via :func:`app.backend.response_payloads.poe_security`.
3. **Uncertainty** -- ``p_correct`` from the uncertainty head for MCQ; for NLP
we run :data:`olmo_tap.constants.KLE_N_SAMPLES` extra samples and convert
their NLI similarity matrix into a Kernel Language Entropy certainty score.
4. **Robustness** -- :func:`app.backend.hydra_inference.get_robustness` retries
the prompt with each adversarial suffix in :data:`ADV_SUFFIXES` and reports
how many flipped the answer.
The two heavyweight models (Hydra + ModernBERT-NLI) are loaded once during
the FastAPI lifespan and stashed in module-level dicts so request handlers
can grab them without re-loading. On Modal the ``@modal.enter()`` hook in
:mod:`app.backend.modal_app` preloads Hydra into the same dicts before the
ASGI app boots; the lifespan detects this and skips the duplicate load.
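
Example request (same shape as the ``__main__`` smoke test below; host and
port are deployment-specific)::

    POST /api/analyse
    {"messages": [{"role": "user", "content": "Is paracetamol safe in pregnancy?"}]}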
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from app.backend.adversarial_suffixes import ADV_SUFFIXES, N_ADV_SUFFIXES
from app.backend.bert_inference import load_bert
from app.backend.claim_confidence import compute_claim_confidences
from app.backend.claim_splitter import decompose_into_claims
from app.backend.constants import HF_FALLBACK_MODEL as HF_MODEL, HF_TOKEN
from app.backend.hydra_inference import (
generate,
get_robustness,
load_hydra,
MODEL_NAME,
)
from app.backend.question_classifier import detect_mcq_bert
from app.backend.response_payloads import (
fallback_robustness,
fallback_security,
fallback_uncertainty,
poe_security,
poe_uncertainty,
)
from kernel_entropy.entropy import kle_from_similarity, kle_to_certainty
from kernel_entropy.nli import ModernBERTScorer
from olmo_tap.constants import KLE_HEAT_KERNEL_T, KLE_N_SAMPLES
from olmo_tap.hydra import HydraTransformer
logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s"
)
logger = logging.getLogger(__name__)
_models: dict[str, Any | None] = {}
_tokenizers: dict[str, PreTrainedTokenizerBase | None] = {}
_device: str = "cuda"
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
FastAPI lifespan that loads (or re-uses preloaded) Hydra and BERT models.
On Modal the ``@modal.enter()`` hook in :mod:`app.backend.modal_app` has
already populated ``_models["hydra"]`` before this runs; the duplicate
    load would otherwise add ~30s of cold-start latency. BERT is always loaded here
because Modal's preload only warms Hydra.
:param app: FastAPI application instance (unused but required by the
lifespan protocol).
"""
global _device
_device = os.getenv("DEVICE", "cuda")
logger.info("Starting up - device=%s", _device)
# Modal's @modal.enter() may have already preloaded; skip to avoid a ~30s double-load.
if "hydra" not in _models:
_models["hydra"], _tokenizers["hydra"] = load_hydra(device=_device)
if _models["hydra"] is None:
logger.warning("Hydra unavailable; requests will fall back to HF API")
else:
logger.info("Hydra already preloaded; skipping lifespan load")
if "bert" not in _models:
_models["bert"], _tokenizers["bert"] = load_bert(device=_device)
if _models["bert"] is None:
logger.warning("BERT unavailable; NLI-based metrics will be skipped")
else:
logger.info("BERT already preloaded; skipping lifespan load")
yield
logger.info("Shutting down")
_models.clear()
_tokenizers.clear()
app = FastAPI(title="Trustworthy Answer Protocol - API", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173", "https://tap-al9.pages.dev"],
# Cloudflare Pages preview/PR deployments: <hash-or-branch>.tap-al9.pages.dev
allow_origin_regex=r"^https://[a-z0-9-]+\.tap-al9\.pages\.dev$",
allow_methods=["*"],
allow_headers=["*"],
)
class Message(BaseModel):
"""
Single chat-completions message.
:param role: One of ``"system"``, ``"user"``, ``"assistant"`` (matches the
OpenAI / HF chat-completions schema).
:param content: Raw text content for that role.
"""
role: str
content: str
class ChatRequest(BaseModel):
"""
Request body for :func:`analyse`.
:param messages: Multi-turn chat history, oldest message first. The last
element must have ``role == "user"`` and is the prompt that gets
scored for uncertainty / security / robustness.
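
    Example body (mirrors the ``__main__`` smoke test)::

        {"messages": [{"role": "user", "content": "Is paracetamol safe in pregnancy?"}]}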
"""
messages: list[Message]
def call_hf_model(messages: list[dict]) -> str:
"""Call the HF Inference API as a fallback when Hydra is unavailable or bypassed.
Used when ``hf=true`` is passed to ``/api/analyse`` or when ``load_hydra``
failed at lifespan startup. No PoE verification is available in this path;
the security payload from the caller reflects that with ``certified=None``.
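
    Example (requires ``HF_TOKEN``; the reply is model-dependent and shown
    here only illustratively)::

        call_hf_model([{"role": "user", "content": "What is 2+2?"}])  # -> "4"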
"""
if not HF_TOKEN:
raise ValueError("HF_TOKEN environment variable not set")
client = InferenceClient(HF_MODEL, token=HF_TOKEN)
response = client.chat_completion(messages, max_tokens=500)
return response.choices[0].message.content or ""
def _classify_mcq(last_user_msg: str) -> bool | None:
"""
Classify the latest user message as MCQ or open-ended via BERT NLI.
Wraps :func:`app.backend.question_classifier.detect_mcq_bert` and returns
``None`` when BERT failed to load at startup so the caller can degrade
gracefully (no MCQ system prompt, NLP code paths only).
:param last_user_msg: Raw text of the most recent user turn.
:returns: ``True`` if multiple-choice, ``False`` if open-ended,
``None`` if BERT is unavailable.
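
    Example (illustrative outputs, assuming BERT loaded at startup)::

        _classify_mcq("Which is prime? A) 4 B) 7 C) 9")  # -> True
        _classify_mcq("Summarise the plot of Hamlet.")   # -> False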
"""
bert_model = _models.get("bert")
bert_tokenizer = _tokenizers.get("bert")
if bert_model is None or bert_tokenizer is None:
return None
return detect_mcq_bert(bert_model, bert_tokenizer, last_user_msg, device=_device)
@app.post("/api/analyse")
async def analyse(request: ChatRequest, hf: bool = False):
"""
Score a chat prompt for security, uncertainty, robustness and per-claim
confidence.
Generation runs through the Hydra PoE ensemble unless ``hf=True`` is
passed (or the ensemble failed to load), in which case the HF Inference
API is used as a fallback and security / uncertainty / robustness are
returned as ``unavailable``-style payloads.
The claim ledger is independent of the generation backend: it always
decomposes ``raw_response`` and scores each claim with NLI self-entailment
when BERT is available. KLE-based uncertainty for NLP queries is computed
here (not inside ``generate``) because it requires :data:`KLE_N_SAMPLES`
extra forward passes.
:param request: Chat history; the last user turn is the prompt.
:param hf: Force the HF Inference API path even when Hydra is healthy.
Useful for A/B comparisons against the unverified baseline.
:returns: Dict with keys ``claims``, ``overall_confidence``,
``uncertainty``, ``security``, ``robustness``, ``raw_response``,
``model``, ``is_mcq``. See
:mod:`app.backend.response_payloads` for the security/uncertainty/
robustness sub-schemas.
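
    Abridged response sketch (shape only; all values are illustrative)::

        {
            "claims": [{"text": "...", "confidence": 0.83,
                        "confidence_level": "...", "guidance": "..."}],
            "overall_confidence": 0.79,
            "uncertainty": {...}, "security": {...}, "robustness": {...},
            "raw_response": "...", "model": "...", "is_mcq": false
        }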
"""
messages = [{"role": m.role, "content": m.content} for m in request.messages]
latest_user_msg = messages[-1]["content"]
logger.info("Latest user message: %s", latest_user_msg)
is_mcq = _classify_mcq(latest_user_msg)
logger.info("BERT MCQ classification: %s", is_mcq)
hydra: HydraTransformer | None = _models.get("hydra")
    hydra_tokenizer: PreTrainedTokenizerBase | None = _tokenizers.get("hydra")
if hf or hydra is None or hydra_tokenizer is None:
model_name = HF_MODEL
raw_response = call_hf_model(messages)
security = fallback_security()
uncertainty = fallback_uncertainty()
robustness = fallback_robustness()
else:
model_name = MODEL_NAME
(
raw_response,
tokens,
resampled,
token_entropies,
p_correct,
stability_radii,
stability_margins,
) = generate(
hydra,
hydra_tokenizer,
messages,
is_mcq=bool(is_mcq),
device=_device,
)
security = poe_security(
tokens, resampled, token_entropies, stability_radii, stability_margins
)
uncertainty = poe_uncertainty(p_correct)
bert_model = _models.get("bert")
bert_tokenizer = _tokenizers.get("bert")
# Uncertainty for NLP
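        # KLE recipe: draw KLE_N_SAMPLES fresh generations for the same prompt,
        # score their pairwise NLI similarity into the matrix W, convert W into
        # a Kernel Language Entropy under a heat kernel at temperature
        # KLE_HEAT_KERNEL_T, then map that entropy to a certainty score which
        # replaces the MCQ-head ``p_correct`` payload built above.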
if not is_mcq and bert_model is not None and bert_tokenizer is not None:
try:
kle_responses: list[str] = []
for _ in range(KLE_N_SAMPLES):
raw, _t, _r, _e, _p, _, _ = generate(
hydra,
hydra_tokenizer,
messages,
is_mcq=False,
device=_device,
)
kle_responses.append(raw)
W = ModernBERTScorer(
kle_responses,
model=bert_model,
tokenizer=bert_tokenizer,
).compute()
entropy = kle_from_similarity(W, t=KLE_HEAT_KERNEL_T) # type: ignore[arg-type]
certainty = kle_to_certainty(entropy, KLE_N_SAMPLES)
uncertainty = poe_uncertainty(certainty)
except Exception:
logger.exception("KLE computation failed; falling back")
uncertainty = fallback_uncertainty()
# Robustness
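        # Open-ended robustness degrades without BERT: get_robustness presumably
        # needs the NLI scorer to judge whether a perturbed answer still entails
        # the original, whereas MCQ answers can be compared directly.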
if not is_mcq and (bert_model is None or bert_tokenizer is None):
robustness = fallback_robustness()
else:
robustness = get_robustness(
hydra,
hydra_tokenizer,
list(messages),
original_resp=raw_response,
original_tokens=tokens,
is_mcq=bool(is_mcq),
adv_suffix_bank=ADV_SUFFIXES[:N_ADV_SUFFIXES],
bert_model=bert_model,
bert_tokenizer=bert_tokenizer,
device=_device,
)
logger.info("Generation complete (%d chars)", len(raw_response))
bert_model = _models.get("bert")
bert_tokenizer = _tokenizers.get("bert")
claims: list[dict] = []
overall: float | None = None
if bert_model is not None and bert_tokenizer is not None:
try:
claims_text = decompose_into_claims(raw_response)
metrics_list = compute_claim_confidences(
raw_response, claims_text, bert_model, bert_tokenizer
)
claims = [
{
"text": text,
"confidence": m["confidence"],
"confidence_level": m["level"],
"guidance": m["guidance"],
}
for text, m in zip(claims_text, metrics_list)
]
if metrics_list:
overall = round(
sum(m["confidence"] for m in metrics_list) / len(metrics_list), 2
)
except Exception:
logger.exception("Claim ledger unavailable; returning empty claims")
claims = []
overall = None
return {
"claims": claims,
"overall_confidence": overall,
"uncertainty": uncertainty,
"security": security,
"robustness": robustness,
"raw_response": raw_response,
"model": model_name,
"is_mcq": is_mcq,
}
@app.get("/api/health")
async def health():
"""
Lightweight liveness probe used by Cloudflare Pages, uptime checks and
Modal's health monitor.
:returns: ``{"status": "ok"}``. Does not touch model state, so a 200 here
only means the ASGI process is up; readiness for Hydra requests is
implicit in successful ``/api/analyse`` calls.
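
    Example (assuming the app is served locally, e.g. with
    ``uvicorn app.backend.server:app`` on the default port)::

        $ curl -s http://localhost:8000/api/health
        {"status": "ok"}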
"""
return {"status": "ok"}
if __name__ == "__main__":
# Local smoke test:
# pixi run -e cuda python -m app.backend.server
# Hits the in-process ASGI app via TestClient so no uvicorn / port is
# needed. The lifespan still fires, so this exercises the same model-
# loading path as a real ``modal serve`` deployment.
from fastapi.testclient import TestClient
with TestClient(app) as client:
health_resp = client.get("/api/health")
print("Health:", health_resp.json())
analyse_resp = client.post(
"/api/analyse",
json={
"messages": [
{"role": "user", "content": "Is paracetamol safe in pregnancy?"}
]
},
)
body = analyse_resp.json()
print("Model:", body["model"])
print("Is MCQ:", body["is_mcq"])
print("Response:", body["raw_response"][:200])