# Source code for app.backend.question_classifier
"""
Zero-shot MCQ vs open-ended classifier built on the ModernBERT-NLI model.
The same NLI checkpoint loaded by :func:`app.backend.bert_inference.load_bert`
is reused here as a zero-shot text classifier: each candidate class is
encoded as a hypothesis and the class with the highest entailment logit
wins. The result drives the system-prompt and token-budget routing inside
:func:`app.backend.hydra_inference.generate`.
"""
from enum import StrEnum
import torch
from transformers import TokenizersBackend
from transformers.models.modernbert.modeling_modernbert import (
ModernBertForSequenceClassification,
)
class QuestionType(StrEnum):
    """Possible classification outcomes for a user prompt."""

    # Prompt presents enumerated answer options (a multiple-choice question).
    MCQ = "mcq"
    # Free-form prompt — anything that is not a multiple-choice question.
    OPEN = "open"
# One hypothesis sentence per class for zero-shot NLI classification: each is
# paired with the user prompt as a (premise, hypothesis) input and scored for
# entailment in :func:`detect_mcq_bert`; the higher-scoring class wins.
# NOTE: insertion order matters — MCQ is scored first, so ties resolve to MCQ.
_BERT_HYPOTHESES: dict[QuestionType, str] = {
    QuestionType.MCQ: "This is a multiple choice question",
    QuestionType.OPEN: "This is not a multiple-choice question",
}
def detect_mcq_bert(
    model: ModernBertForSequenceClassification,
    tokenizer: TokenizersBackend,
    text: str,
    device: str = "cuda",
) -> bool:
    """
    Zero-shot MCQ classification by entailment scoring.

    For each :class:`QuestionType` we score ``(text, hypothesis)`` through
    the NLI head and read off the entailment logit. Returns ``True`` iff
    the MCQ hypothesis scores at least as high as the open-ended one
    (ties resolve to MCQ).

    :param model: ModernBERT-NLI model from
        :func:`app.backend.bert_inference.load_bert`.
    :param tokenizer: Matching tokenizer.
    :param text: User prompt.
    :param device: Torch device for the forward pass.
    :returns: ``True`` if the prompt looks like a multiple-choice question.
    """
    # Resolve the index of the "entailment" logit. NLI checkpoints typically
    # declare {'entailment': 0, 'neutral': 1, 'contradiction': 2}; fall back
    # to index 0 when the config carries no usable label map.
    entailment_idx = (model.config.label2id or {}).get("entailment", 0)

    scores: dict[QuestionType, float] = {}
    with torch.no_grad():
        for label, hypothesis in _BERT_HYPOTHESES.items():
            # (premise, hypothesis) pair; truncate long prompts to the
            # encoder's 512-token window.
            inputs = tokenizer(
                text, hypothesis, return_tensors="pt", truncation=True, max_length=512
            ).to(device)
            logits = model(**inputs).logits  # shape (1, num_labels)
            scores[label] = logits[0, entailment_idx].item()

    # Explicit comparison instead of max() over the dict; ties favour MCQ,
    # matching the insertion order of _BERT_HYPOTHESES.
    return scores[QuestionType.MCQ] >= scores[QuestionType.OPEN]