Source code for kernel_entropy.nli

"""
ModernBERT NLI scoring for Kernel Language Entropy.

Computes pairwise semantic similarity between LLM generations using
Natural Language Inference. Produces the similarity matrix W for KLE calculation.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from transformers import TokenizersBackend

# TYPE_CHECKING is False at runtime, True during static analysis.
# This lets us use the concrete ModernBERT model class in type hints
# without importing the heavy modeling module at runtime (transformers
# itself is still imported above for TokenizersBackend).
if TYPE_CHECKING:
    from transformers.models.modernbert.modeling_modernbert import (
        ModernBertForSequenceClassification as AutoModelType,
    )

# Type alias for raw probability data from NLI scoring
RawProbabilities = dict[tuple[int, int], dict[str, dict[str, float]]]
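# Shape, as produced by compute(verbose=True): keyed by each unique pair
# (i, j) with i < j, with one probability dict per NLI direction, e.g.
# (illustrative values, not real model output):
#   {(0, 1): {"i_to_j": {"entailment": 0.90, "neutral": 0.08, "contradiction": 0.02},
#             "j_to_i": {"entailment": 0.85, "neutral": 0.10, "contradiction": 0.05}}}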

# HuggingFace repo id; downloaded + cached on first use.
DEFAULT_MODEL_ID = "tasksource/ModernBERT-large-nli"

# ModernBERT-large-nli label indices (from config.json id2label)
# 0: entailment, 1: neutral, 2: contradiction
LABEL_ENTAILMENT = 0
LABEL_NEUTRAL = 1
LABEL_CONTRADICTION = 2


class ModernBERTScorer:
    """
    Pairwise NLI scoring using ModernBERT-large-nli.

    Computes similarity matrix W for Kernel Language Entropy.
    """

    # Class-level model singleton - loaded once, shared across instances
    _model: AutoModelType | None = None
    _tokenizer: TokenizersBackend | None = None

    def __init__(
        self,
        sentences: list[str],
        model_id: str = DEFAULT_MODEL_ID,
        model: AutoModelType | None = None,
        tokenizer: TokenizersBackend | None = None,
    ) -> None:
        """
        Prepare NLI scorer with sentences.

        Args:
            sentences: List of N sentences to compare
            model_id: HuggingFace repo id or local path
                (default: tasksource/ModernBERT-large-nli)
            model: Pre-loaded ModernBertForSequenceClassification. When
                provided together with tokenizer, the class-level singleton
                is set directly and no HF download or CUDA check is performed.
            tokenizer: Pre-loaded tokenizer paired with model.

        Raises:
            RuntimeError: If CUDA not available (only when model/tokenizer
                not injected)
        """
        self.sentences = sentences
        self._model_id = model_id
        if model is not None and tokenizer is not None:
            # TODO: does the job but ugly - might be better way
            type(self)._model = model
            type(self)._tokenizer = tokenizer
        else:
            self._validate_environment()
            self._ensure_model_loaded(model_id)

    def _validate_environment(self) -> None:
        """Check CUDA availability. Raises on failure."""
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA not available. NLI scoring requires GPU.\n"
                "Use: pixi run -e cuda <command>"
            )

    @classmethod
    def _ensure_model_loaded(cls, model_id: str) -> None:
        """Load model on first instantiation (class-level singleton)."""
        if cls._model is not None:
            return
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        print(f"Loading ModernBERT NLI model from {model_id}...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if not isinstance(tokenizer, TokenizersBackend):
            raise RuntimeError(f"Tokenizer for {model_id} is not a TokenizersBackend")
        cls._tokenizer = tokenizer
        cls._model = (
            AutoModelForSequenceClassification.from_pretrained(model_id).cuda().eval()
        )
        print("ModernBERT NLI model loaded!")

    @staticmethod
    def _kle_score(probs: torch.Tensor, idx: int) -> torch.Tensor:
        """KLE weighting for one NLI direction:
        entailment=1.0, neutral=0.5, contradiction=0.0.
        """
        return 1.0 * probs[idx, LABEL_ENTAILMENT] + 0.5 * probs[idx, LABEL_NEUTRAL]
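    # Worked example of the weighting above (illustrative numbers, not real
    # model output): if probs[idx] = [0.7, 0.2, 0.1] over
    # [entailment, neutral, contradiction], the directional score is
    # 1.0 * 0.7 + 0.5 * 0.2 = 0.8. Summing both directions of a pair then
    # gives the final W[i, j] in [0, 2].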
    def get_nli_probabilities(
        self, nli_inputs: list[tuple[str, str]]
    ) -> torch.Tensor:
        """Get raw NLI probabilities for given (premise, hypothesis) pairs."""
        assert self._tokenizer is not None
        assert self._model is not None
        encoded = self._tokenizer(
            [p[0] for p in nli_inputs],  # premises
            [p[1] for p in nli_inputs],  # hypotheses
            padding=True,  # pad all sequences in the batch to the same length
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].cuda()
        attention_mask = encoded["attention_mask"].cuda()
        with torch.no_grad():
            outputs = self._model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(outputs.logits, dim=-1)
        return probs
    def compute(
        self, verbose: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, RawProbabilities]:
        """
        Compute pairwise similarity matrix W.

        For each pair (i, j) where i < j, computes:
            W[i,j] = W[j,i] = weighted(NLI(i->j)) + weighted(NLI(j->i))

        Args:
            verbose: If True, returns (W, raw_probabilities) tuple

        Returns:
            N x N symmetric similarity matrix W with W[i,j] in [0, 2],
            diagonal = 0. If verbose=True, returns (W, raw_probabilities).
        """
        n = len(self.sentences)

        # Handle edge cases
        if n == 0:
            return torch.zeros((0, 0), device="cuda", dtype=torch.float32)
        if n == 1:
            return torch.zeros((1, 1), device="cuda", dtype=torch.float32)

        # Generate pairs (i, j) where i < j, plus both NLI directions
        pair_indices: list[tuple[int, int]] = []  # unique pairs (i < j)
        nli_inputs: list[tuple[str, str]] = []  # (premise, hypothesis) for batch
        identical_pairs: set[tuple[int, int]] = set()

        for i in range(n):
            for j in range(i + 1, n):  # j > i only
                if self.sentences[i] == self.sentences[j]:
                    identical_pairs.add((i, j))
                else:
                    pair_indices.append((i, j))
                    # Both directions for asymmetric NLI
                    nli_inputs.append((self.sentences[i], self.sentences[j]))  # i -> j
                    nli_inputs.append((self.sentences[j], self.sentences[i]))  # j -> i

        # Initialize symmetric matrix with zeros on GPU
        W = torch.zeros((n, n), device="cuda", dtype=torch.float32)

        # Identical pairs get max similarity (1.0 + 1.0 = 2.0)
        for i, j in identical_pairs:
            W[i, j] = 2.0
            W[j, i] = 2.0

        raw_probabilities: RawProbabilities = {}
        if not nli_inputs:
            # No non-identical pairs, return W (or (W, raw_probabilities) if verbose)
            if verbose:
                return W, raw_probabilities
            return W

        # Batch inference for non-identical pairs
        print(f"Computing pairwise similarities ({len(nli_inputs)} inferences)...")
        probs = self.get_nli_probabilities(nli_inputs)

        # Each pair has 2 consecutive NLI results: [i->j, j->i]
        for pair_idx, (i, j) in enumerate(pair_indices):
            idx_ij = pair_idx * 2  # i -> j
            idx_ji = pair_idx * 2 + 1  # j -> i
            score_ij = self._kle_score(probs, idx_ij)
            score_ji = self._kle_score(probs, idx_ji)

            # W[i,j] = score(i->j) + score(j->i), symmetric
            W[i, j] = score_ij + score_ji
            W[j, i] = W[i, j]

            if verbose:
                raw_probabilities[(i, j)] = {
                    "i_to_j": {
                        "entailment": probs[idx_ij, LABEL_ENTAILMENT].item(),
                        "neutral": probs[idx_ij, LABEL_NEUTRAL].item(),
                        "contradiction": probs[idx_ij, LABEL_CONTRADICTION].item(),
                    },
                    "j_to_i": {
                        "entailment": probs[idx_ji, LABEL_ENTAILMENT].item(),
                        "neutral": probs[idx_ji, LABEL_NEUTRAL].item(),
                        "contradiction": probs[idx_ji, LABEL_CONTRADICTION].item(),
                    },
                }

        if verbose:
            return W, raw_probabilities
        return W
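    # Index layout sketch for compute() with n = 3 distinct sentences:
    # pair_indices = [(0, 1), (0, 2), (1, 2)], and nli_inputs stores both
    # directions of each pair consecutively, so pair k reads its probabilities
    # from probs rows 2k (i -> j) and 2k + 1 (j -> i).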
    def compute_against_baseline(self, baseline_idx: int = 0) -> torch.Tensor:
        """Compute KLE similarity between sentences[baseline_idx] and every
        other sentence.

        Runs at most 2*(N-1) NLI inferences in one batched forward pass -
        only the pairs involving the baseline (identical sentences are
        short-circuited), not the full pairwise matrix.

        Returns:
            1-D tensor of length N where result[j] is the bidirectional KLE
            score between sentences[baseline_idx] and sentences[j], and
            result[baseline_idx] = 0.
        """
        n = len(self.sentences)
        result = torch.zeros(n, device="cuda", dtype=torch.float32)
        if n <= 1:
            return result

        baseline_sentence = self.sentences[baseline_idx]
        other_indices = [j for j in range(n) if j != baseline_idx]

        nli_inputs: list[tuple[str, str]] = []
        for j in other_indices:
            other_sentence = self.sentences[j]
            if baseline_sentence == other_sentence:
                # Identical pairs get max similarity (1.0 + 1.0 = 2.0)
                result[j] = 2.0
            else:
                nli_inputs.append((baseline_sentence, other_sentence))  # baseline -> j
                nli_inputs.append((other_sentence, baseline_sentence))  # j -> baseline

        if not nli_inputs:
            return result

        print(f"Computing baseline similarities ({len(nli_inputs)} inferences)...")
        probs = self.get_nli_probabilities(nli_inputs)

        pair_pos = 0
        for j in other_indices:
            if self.sentences[j] == baseline_sentence:
                continue
            score_forward = self._kle_score(probs, pair_pos)
            score_backward = self._kle_score(probs, pair_pos + 1)
            result[j] = score_forward + score_backward
            pair_pos += 2

        return result
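

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API: assumes a CUDA GPU
    # and, on first run, network access to download
    # tasksource/ModernBERT-large-nli from the HuggingFace Hub. The example
    # sentences are placeholders.
    sentences = [
        "The cat sat on the mat.",
        "A cat is sitting on the mat.",
        "The stock market fell sharply today.",
    ]
    scorer = ModernBERTScorer(sentences)
    W = scorer.compute()  # 3 x 3 symmetric similarity matrix on the GPU
    print(W)

    # Cheaper one-vs-all variant: at most 2 * (N - 1) inferences against
    # sentences[0] instead of the full pairwise matrix.
    print(scorer.compute_against_baseline(baseline_idx=0))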