Source code for kernel_entropy.generation
"""
PoE text generation for Kernel Language Entropy.
Loads the PoE ensemble (9 LLM heads with prod + robustness LoRA merged, plus a
dormant uncertainty head) once and produces N diverse responses for a single
prompt. Each sample is drawn from the full PoE jury in pure-generation mode
(is_mcq=False); per-sample seeding of the torch RNG makes draft-head picks and
multinomial draws reproducible.
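
Example (hypothetical usage; assumes ``OLMO_WEIGHTS_DIR`` is set and a CUDA
device is available)::

    from kernel_entropy.generation import HydraGenerator

    gen = HydraGenerator(max_new_tokens=100)
    # N = 10 diverse samples for a single prompt, one per seed.
    samples = gen.generate_batch("Why is the sky blue?", seeds=list(range(10)))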
"""
from __future__ import annotations

from typing import cast

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from olmo_tap.constants import WEIGHTS_DIR
from olmo_tap.inference.loading_weights import load_ensemble
from olmo_tap.inference.poe import PoE
class HydraGenerator:
    """
    PoE-backed batched generation for KLE.

    ``generate_batch`` produces one response per seed by calling
    ``PoE.generate_with_cache`` in pure-generation mode (``is_mcq=False``).
    """

    def __init__(
        self,
        gamma: int = 4,
        beta: float = 1.0,
        max_new_tokens: int = 200,
    ) -> None:
        if not WEIGHTS_DIR:
            raise ValueError(
                "OLMO_WEIGHTS_DIR not set; needed to load the OLMo2 tokenizer."
            )
        tokenizer = cast(
            PreTrainedTokenizerBase, AutoTokenizer.from_pretrained(WEIGHTS_DIR)
        )
        model, n_heads = load_ensemble()
        self._poe = PoE(
            model=model,
            tokenizer=tokenizer,
            # n_heads counts every head in the ensemble; the dormant
            # uncertainty head is not an LLM head, so exclude it.
            n_llm_heads=n_heads - 1,
            gamma=gamma,
            beta=beta,
            max_new_tokens=max_new_tokens,
        )

    def generate_batch(
        self,
        prompt: str,
        seeds: list[int],
        temperature: float = 0.98,
        verbose: bool = False,
    ) -> list[str]:
        """
        Generate one response per seed.

        Seeds the torch RNG before each PoE call so the draft-head pick and
        multinomial draws inside ``generate_with_cache`` are reproducible.
        Forks the RNG so per-seed seeding does not leak into caller state.
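
        Example (hypothetical; identical seeds are intended to yield
        identical samples)::

            gen = HydraGenerator()
            a = gen.generate_batch("2 + 2 = ?", seeds=[7])
            b = gen.generate_batch("2 + 2 = ?", seeds=[7])
            assert a == b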
"""
if not seeds:
return []
# TODO: upstream fix — poe.py:159,180 decode tokens without
# skip_special_tokens, so EOS leaks into the last element of `parts`
# when the loop terminates on EOS. Stripped locally here; proper fix
# belongs in poe.py but would affect the frontend, so separate PR.
eos_surface = self._poe.tokenizer.eos_token or ""
responses: list[str] = []
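        # fork_rng snapshots the CPU RNG state and the current CUDA device's
        # RNG state and restores both on exit, so the per-seed reseeding
        # below cannot leak into the caller's RNG state.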
        with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
            for seed in tqdm(seeds, desc="PoE generations"):
                # manual_seed seeds both the CPU and CUDA generators, fixing
                # the draft-head pick and every multinomial draw.
                torch.manual_seed(seed)
                poe_out = self._poe.generate_with_cache(
                    prompt_text=prompt, is_mcq=False, temperature=temperature
                )
                # Skip output_parts[0] (the prompt-side part) and join the
                # generated parts into one string.
                response = "".join(poe_out.output_parts[1:])
                if eos_surface and response.endswith(eos_surface):
                    response = response[: -len(eos_surface)]
                response = response.strip()
                if verbose:
                    print(f"\n--- Response {len(responses) + 1} (seed={seed}) ---")
                    print(response)
                responses.append(response)
        return responses