Source code for app.backend.bert_inference

"""
ModernBERT-large-NLI loader shared by every NLI-based scorer in the backend.

The same model is reused for three distinct jobs at request time:

- **MCQ classification** (:func:`app.backend.question_classifier.detect_mcq_bert`)
- **Per-claim self-entailment** (:func:`app.backend.claim_confidence.compute_claim_confidences`)
- **Kernel Language Entropy + robustness similarity matrices**
  (:class:`kernel_entropy.nli.ModernBERTScorer`)

Loading is wrapped in try/except so that a missing :data:`HF_CACHE_DIR` or a
transient Hugging Face network failure degrades to returning ``(None, None)``
rather than crashing the FastAPI lifespan.
"""

import logging
import time

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedTokenizerFast,
)

from app.backend.constants import HF_CACHE_DIR

logger = logging.getLogger(__name__)

MODEL_ID = "tasksource/ModernBERT-large-nli"


def load_bert(device: str = "cuda"):
    """
    Load ModernBERT-large-NLI from the Modal-mounted HF cache.

    The Modal volume is populated once by
    :func:`app.backend.modal_app.download_weights`; locally
    :data:`HF_CACHE_DIR` may point at the user's HF cache.

    :param device: Torch device for the loaded model.
    :returns: ``(model, tokenizer)`` on success, ``(None, None)`` if the
        cache is missing or the download fails. The FastAPI lifespan logs
        and continues in degraded mode when this returns ``None``.
    """
    logger.info(
        "Loading BERT for NLI-based metrics. Using cache dir: %s", HF_CACHE_DIR
    )
    t0 = time.perf_counter()
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=HF_CACHE_DIR)
        if not isinstance(tokenizer, PreTrainedTokenizerFast):
            logger.error(
                "Tokenizer is not a fast (Rust-backed) tokenizer; aborting model load"
            )
            return None, None
        model = (
            AutoModelForSequenceClassification.from_pretrained(
                MODEL_ID, cache_dir=HF_CACHE_DIR
            )
            .to(device)
            .eval()
        )
    except Exception as e:
        logger.error("BERT unavailable: %s", e)
        return None, None
    logger.info("BERT loaded in %.1fs", time.perf_counter() - t0)
    return model, tokenizer
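

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by the backend). Shows how a
# caller such as the FastAPI lifespan might consume ``load_bert`` and run a
# single NLI forward pass. The premise/hypothesis pair is made up, and
# ``entailment_probs`` is a hypothetical helper rather than part of this
# module's public API; label names are read from the checkpoint's ``id2label``
# map instead of being hardcoded, since label order is checkpoint-specific.
def entailment_probs(model, tokenizer, premise: str, hypothesis: str) -> dict:
    """Return ``{label: probability}`` for one premise/hypothesis pair."""
    import torch

    inputs = tokenizer(
        premise, hypothesis, return_tensors="pt", truncation=True
    ).to(model.device)
    with torch.no_grad():
        probs = model(**inputs).logits.softmax(dim=-1).squeeze(0)
    return {model.config.id2label[i]: p for i, p in enumerate(probs.tolist())}


if __name__ == "__main__":
    model, tokenizer = load_bert(device="cpu")
    if model is None:
        raise SystemExit("BERT unavailable; the backend would continue degraded")
    print(
        entailment_probs(
            model,
            tokenizer,
            "The cat sat on the mat.",
            "An animal is on the mat.",
        )
    )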