"""
Inference latency / throughput benchmark for vanilla OLMo, naive-averaging
Hydra, and Hydra + PoE (Product of Experts speculative decode).
Three configurations are timed back-to-back on the same GPU so the numbers are
directly comparable:
1. ``baseline`` — vanilla OLMo (1B or 7B), random weights. Fastest per
token; sets the upper bound any Hydra variant can approach.
2. ``hydra_naive`` — HydraTransformer run with all heads in series and logits
averaged (the default ``forward_and_sample`` codepath). Random weights.
Strictly more compute than (1) — every step pays trunk + N head forwards.
3. ``hydra_poe`` — Hydra wrapped in :class:`olmo_tap.inference.poe.PoE`,
running γ-step speculative decode with one draft head + (N-1)-head verifier
jury (Product of Experts: log-probs are summed across verifier heads,
i.e. multiplied in probability space). Real merged-LoRA weights via
:func:`olmo_tap.inference.loading_weights.load_ensemble` so acceptance
rate is meaningful. Optionally swept across γ values.
Knobs live in ``config.json`` next to this file. The ``baseline`` and
``hydra_naive`` rows use random weights because per-step latency is
architecture-bound — only PoE needs real weights (acceptance rate depends on
the actual model behavior).
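An illustrative ``config.json`` sketch (these are the keys this script reads;
the values shown are placeholders, not recommended settings)::

    {
      "size": "1b", "dtype": "bfloat16",
      "n_heads": 5, "heads_depth": 3,
      "prompt_length": 256, "generation_length": 128, "decode_step_interval": 8,
      "warmup_ms": 100.0, "rep_ms": 1000.0,
      "baseline": true, "poe": false,
      "poe_gamma": 4, "poe_gammas": [2, 4, 8], "poe_beta": 0.5,
      "poe_max_new_tokens": 64, "poe_rep_ms": 8000.0,
      "poe_prompt": "Explain KV caching in one paragraph."
    }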
Usage::
pixi run -e cuda python -m olmo_tap.benchmarks.inference
Output lands in ``olmo_tap/benchmarks/results/<YYYY-MM-DD>_run<NN>/``:
* ``results.json`` — raw timings + per-position decode stats + PoE per-γ stats.
* ``graph.png`` — TTFT KDE, per-position decode latency, per-position TPS.
NOTE:
PoE is reported in two parts. ``ttft`` is the prefill cost, measured once: one
``residual_forward`` over all draft + verifier heads, the same call PoE
issues at the top of every ``generate_with_cache``. ``per_gamma[γ]`` is the
end-to-end call timing: ``median_ms`` is the wall time of a full
``max_new_tokens``-token generation, and ``effective_tps`` is
``1000 * accepted_tokens / median_ms``, i.e. draft-acceptance throughput.
Under a non-zero resample rate the user-visible TPS is slightly higher because
resampled positions are still real output tokens; see the :func:`benchmark_poe`
docstring for exact definitions.
"""
import json
import re
from datetime import datetime
from pathlib import Path
from typing import cast, Any
import torch
from olmo_core.nn.attention import AttentionBackendName, Attention, KVCacheManager
from olmo_core.nn.transformer import Transformer, TransformerBlock
from olmo_core.nn.transformer.config import TransformerConfig
from olmo_tap.constants import VOCAB_SIZE
from olmo_tap.hydra import HydraTransformer, HydraTransformerConfig
def build_hydra_model(
n_heads=5, heads_depth=3, dtype=torch.bfloat16, device="cuda", size="1b"
):
"""Build a HydraTransformer for naive-averaging benchmarks (random weights).
:param n_heads: Number of parallel heads in the Hydra branching.
:param heads_depth: Layers per head (trunk gets ``n_layers - heads_depth``).
:param dtype: Compute dtype.
:param device: Target device.
:param size: ``"1b"`` or ``"7b"`` — picks the OLMo factory.
:returns: A built, ``.eval()``-mode HydraTransformer with random weights.
NOTE: weights are random — latency is architecture-bound at this
granularity. Only the PoE row needs real weights (for acceptance-rate
realism); see :func:`build_poe`.
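    Example (a sketch; assumes a CUDA device with enough free VRAM)::

        model = build_hydra_model(n_heads=5, heads_depth=3, size="1b")
        init_kv_cache(model, batch_size=1, max_seq_len=1024)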
"""
if size == "7b":
config = HydraTransformerConfig.from_olmo2_7B(
n_heads=n_heads, heads_depth=heads_depth
)
else:
config = HydraTransformerConfig.from_olmo2_1B(
n_heads=n_heads, heads_depth=heads_depth
)
model = config.build(init_device="cpu")
model.to(device=device, dtype=dtype)
model.eval()
return model
def init_kv_cache(model, batch_size, max_seq_len):
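    """Initialise KV-cache managers sized for ``batch_size`` x ``max_seq_len``.

    A :class:`HydraTransformer` allocates its trunk + head caches through its
    own :meth:`init_kv_cache`; a vanilla Transformer gets one manager per
    attention block.
    """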
if isinstance(model, HydraTransformer):
model.init_kv_cache(batch_size, max_seq_len)
else:
for block in model.blocks.values():
block.attention.init_kv_cache_manager(batch_size, max_seq_len)
def build_baseline_model(dtype=torch.bfloat16, device="cuda", size="1b"):
"""Build a vanilla single-tower OLMo Transformer (random weights).
:param dtype: Compute dtype.
:param device: Target device.
:param size: ``"1b"`` or ``"7b"``.
:returns: A built, ``.eval()``-mode Transformer using the FlashAttention-2
backend, with random weights (latency is architecture-bound).
"""
if size == "7b":
config = TransformerConfig.olmo2_7B(vocab_size=VOCAB_SIZE)
else:
config = TransformerConfig.olmo2_1B_v2(vocab_size=VOCAB_SIZE)
config.block.sequence_mixer.backend = AttentionBackendName.flash_2 # type: ignore
model = config.build(init_device="cpu")
model.to(device=device, dtype=dtype)
model.eval()
return model
def build_poe(cfg, dtype=torch.bfloat16):
"""Build a :class:`PoE` (Product of Experts speculative decoder) with real
merged-LoRA weights from :func:`load_ensemble`.
Unlike the baseline / naive-Hydra rows (which use random weights), PoE is
built on the deployed 7B Hydra with security + robustness LoRAs merged per
head plus the uncertainty head, because PoE's acceptance rate depends on
the actual logit distributions the verifier heads produce. Random weights
would give meaningless accept/reject behavior.
:param cfg: Parsed ``config.json`` dict. Reads ``poe_gamma`` (overridden
per-call when sweeping), ``poe_beta``, ``poe_max_new_tokens``.
:param dtype: Compute dtype to cast the loaded ensemble to.
:returns: ``(poe, model)`` — the :class:`PoE` wrapper plus the underlying
:class:`HydraTransformer` (returned so the caller can ``del`` it for
VRAM cleanup before the next config).
"""
from olmo_tap.constants import WEIGHTS_DIR
from olmo_tap.inference.loading_weights import load_ensemble
from olmo_tap.inference.poe import PoE
from transformers import AutoTokenizer
model, n_heads = load_ensemble()
model.to(dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(WEIGHTS_DIR)
assert tokenizer is not None, (
f"AutoTokenizer.from_pretrained({WEIGHTS_DIR}) returned None"
)
poe = PoE(
model=model,
tokenizer=tokenizer,
n_llm_heads=n_heads - 1, # last head is uncertainty
gamma=cfg["poe_gamma"],
beta=cfg["poe_beta"],
max_new_tokens=cfg["poe_max_new_tokens"],
)
return poe, model
def get_all_kv_cache_managers(model) -> list[KVCacheManager | None]:
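    """Collect the :class:`KVCacheManager` of every attention block, in order.

    For a :class:`HydraTransformer` this walks the trunk blocks and then every
    head's blocks; for a vanilla :class:`Transformer` it walks all blocks.
    Entries may be ``None`` for blocks whose cache has not been initialised.
    """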
managers = []
if isinstance(model, HydraTransformer):
for block in model.trunk.blocks.values():
attn = cast(TransformerBlock, block).attention
if isinstance(attn, Attention):
managers.append(attn.kv_cache_manager)
for head in model.heads:
for block in cast(Transformer, head).blocks.values():
attn = cast(TransformerBlock, block).attention
if isinstance(attn, Attention):
managers.append(attn.kv_cache_manager)
else:
for block in cast(Transformer, model).blocks.values():
attn = cast(TransformerBlock, block).attention
if isinstance(attn, Attention):
managers.append(attn.kv_cache_manager)
return managers
def reset_kv_cache_position(managers: list[KVCacheManager | None], position):
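    """Rewind every non-``None`` manager to ``position`` by filling its
    ``cache_seqlens`` tensor in place."""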
for m in managers:
if m is not None:
m.cache_seqlens.fill_(position)
def forward_and_sample(model, input_ids):
"""Run a forward pass and argmax-sample the next token.
Skips per-position ``lm_head`` projection on prefill (``last_token_only``
on Hydra, ``logits_to_keep=1`` on vanilla Transformer) so TTFT timings
are apples-to-apples with PoE's ``residual_forward`` path. Decode steps
have ``seq_len=1``, so the flag is a no-op there.
"""
if isinstance(model, HydraTransformer):
logits = model(input_ids, last_token_only=True)
return logits[:, 0, -1, :].mean(dim=0).argmax()
else:
logits = model(input_ids, logits_to_keep=1)
return logits[0, -1, :].argmax()
def benchmark_ttft(model, prompt_ids, warmup_ms=100.0, rep_ms=1000.0):
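    """Time a full-prompt prefill + first-token sample (TTFT).

    The ``setup`` callback rewinds every KV cache to position 0 before each
    repetition, so every measurement is a fresh prefill of ``prompt_ids``.
    Timings are IQR-filtered; the returned dict carries ``raw_ms``,
    ``filtered_ms`` and the :func:`compute_stats` summary (e.g. ``median_ms``).
    """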
from olmo_tap.benchmarks.harness import (
benchmark,
compute_stats,
filter_outliers_iqr,
)
managers = get_all_kv_cache_managers(model)
def setup():
reset_kv_cache_position(managers, 0)
def fn():
forward_and_sample(model, prompt_ids)
raw = benchmark(fn, warmup_ms, rep_ms, flush_l2=True, setup=setup)
filtered = filter_outliers_iqr(raw)
return {
"raw_ms": raw,
"filtered_ms": filtered,
**compute_stats(filtered),
}
def benchmark_decode(
model, prompt_ids, gen_length=128, step_interval=8, warmup_ms=100.0, rep_ms=1000.0
):
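    """Measure single-token decode latency at sampled positions of a generation.

    After one prefill of ``prompt_ids``, decodes ``gen_length`` tokens one at a
    time; every ``step_interval`` steps the current step is benchmarked by
    replaying the same one-token forward with the KV cache rewound to the saved
    position between repetitions. ``results["positions"]`` lists the sampled
    step indices, while ``per_position`` is keyed by the absolute KV-cache
    position (prompt length + step); each entry also carries a derived ``tps``
    (``1000 / median_ms``).
    """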
from olmo_tap.benchmarks.harness import (
benchmark,
compute_stats,
filter_outliers_iqr,
)
managers = get_all_kv_cache_managers(model)
positions = list(range(0, gen_length, step_interval))
results = {"positions": positions, "per_position": {}}
if not managers or managers[0] is None:
        return results  # TODO: warn or raise here; no initialised KV-cache managers to benchmark
for m in managers:
if m is not None:
m.zero_cache()
last_token = forward_and_sample(model, prompt_ids).unsqueeze(0).unsqueeze(0)
for step in range(gen_length):
if step in positions:
saved_pos = managers[0].cache_seqlens.item()
def setup(pos=saved_pos):
reset_kv_cache_position(managers, pos)
def fn(tok=last_token):
forward_and_sample(model, tok)
raw = benchmark(fn, warmup_ms, rep_ms, flush_l2=True, setup=setup)
filtered = filter_outliers_iqr(raw)
stats = compute_stats(filtered)
stats["tps"] = round(1000.0 / stats["median_ms"], 2)
results["per_position"][str(saved_pos)] = {
"raw_ms": raw,
"filtered_ms": filtered,
**stats,
}
reset_kv_cache_position(managers, saved_pos)
last_token = forward_and_sample(model, last_token).unsqueeze(0).unsqueeze(0)
return results
def benchmark_poe_ttft(poe, prompt_ids, warmup_ms=100.0, rep_ms=1000.0):
"""Time the prefill stage that PoE pays at the start of every generation.
Specifically, times :meth:`HydraTransformer.residual_forward` with
``last_token_only=True`` over all draft + verifier heads — the exact call
issued at the top of :meth:`PoE.generate_with_cache`. KV-cache pointers
are reset to zero in ``setup`` each iteration so every measurement is a
fresh prefill. Uses the same random token IDs as the other TTFT rows for
apples-to-apples comparison; token values don't affect prefill cost since
it's architecture-bound at this granularity.
:param poe: :class:`PoE` instance whose ``.model`` carries an initialised
KV cache (caller must have called ``poe.model.init_kv_cache(...)``).
:param prompt_ids: ``(1, prompt_length)`` token tensor on CUDA.
:param warmup_ms: Warmup budget passed to :func:`harness.benchmark`.
:param rep_ms: Measurement budget.
:returns: TTFT stats dict with the same shape as :func:`benchmark_ttft`.
"""
from olmo_tap.benchmarks.harness import (
benchmark,
compute_stats,
filter_outliers_iqr,
)
llm_heads_indices = list(range(poe.n_llm_heads))
def setup():
poe.model.sync_kv_cache(0, omit_last=True)
def fn():
poe.model.residual_forward(
prompt_ids,
last_token_only=True,
head_indices=llm_heads_indices,
hidden_head_indices=llm_heads_indices,
)
raw = benchmark(fn, warmup_ms, rep_ms, flush_l2=True, setup=setup)
filtered = filter_outliers_iqr(raw)
return {"raw_ms": raw, "filtered_ms": filtered, **compute_stats(filtered)}
def benchmark_poe(poe, prompt_text, warmup_ms=100.0, rep_ms=1000.0):
"""Time end-to-end :meth:`PoE.generate_with_cache` and compute throughput
+ acceptance statistics.
PoE is timed at call granularity (one call = ``poe.max_new_tokens`` output
tokens), not per step, because the γ-step draft loop and the verifier
pass don't decompose into a single repeating "fn" the way naive decode
does. :meth:`PoE.generate_with_cache` re-inits its own KV cache on every
call, so the harness ``setup`` callback is left unset.
The reported metrics for a benchmark window:
* ``median_ms`` — median wall time of one full call.
* ``avg_accepted_tokens_per_call`` — mean over calls of
``output_tokens − resampled_tokens`` (tokens the draft head got right).
* ``resample_rate`` — total resampled / total output
tokens; the fraction of output positions where the verifier jury
rejected the draft and a corrected token was sampled. Lower is better.
* ``effective_tps`` — ``1000 * avg_accepted / median_ms``;
**draft-acceptance throughput**, not user-visible throughput. Under
non-zero ``resample_rate`` user-visible TPS is slightly higher because
resampled positions are still real output tokens; the difference is
``resample_rate * effective_tps / (1 - resample_rate)``.
* ``n_calls`` — number of timed full-generation
calls (estimate + warmup + measurement).
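    Worked example (illustrative numbers, not a measurement): a call that emits
    64 output tokens in ``median_ms = 4000`` with 6 resampled positions gives
    ``avg_accepted_tokens_per_call = 58``, ``resample_rate = 6 / 64 ≈ 0.094``,
    ``effective_tps = 14.5``, and a user-visible throughput of
    ``1000 * 64 / 4000 = 16.0`` tokens/sec.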
:param poe: :class:`PoE` instance. ``poe.gamma`` may be mutated by the
caller between invocations to sweep γ; everything else stays put.
:param prompt_text: A single-turn user message string. Will be wrapped
with ``apply_chat_template`` inside :meth:`generate_with_cache`.
:param warmup_ms: Warmup budget passed to :func:`harness.benchmark`.
:param rep_ms: Measurement budget. Should comfortably exceed one call's
wall time so multiple measurements land — a 64-token PoE call on 7B
is ~3-5 s, so use ``rep_ms >= 8000`` for stable medians.
:returns: Stats dict with the keys described above plus ``raw_ms`` and
``filtered_ms`` (post-IQR-filtered timings).
"""
from olmo_tap.benchmarks.harness import (
benchmark,
compute_stats,
filter_outliers_iqr,
)
accepted_total = [0]
resampled_total = [0]
calls = [0]
def fn():
output_parts, _orig, resampled_idxs, _p = poe.generate_with_cache(
prompt_text, is_mcq=False, temperature=None
)
n_tokens = len(output_parts) - 1
accepted_total[0] += n_tokens - len(resampled_idxs)
resampled_total[0] += len(resampled_idxs)
calls[0] += 1
raw = benchmark(fn, warmup_ms, rep_ms, flush_l2=True)
filtered = filter_outliers_iqr(raw)
stats = compute_stats(filtered)
n_calls = max(calls[0], 1)
avg_accepted = accepted_total[0] / n_calls
total_tokens = accepted_total[0] + resampled_total[0]
resample_rate = resampled_total[0] / total_tokens if total_tokens > 0 else 0.0
effective_tps = (
round(1000.0 * avg_accepted / stats["median_ms"], 2)
if stats["median_ms"] > 0
else 0.0
)
return {
"raw_ms": raw,
"filtered_ms": filtered,
**stats,
"avg_accepted_tokens_per_call": round(avg_accepted, 2),
"resample_rate": round(resample_rate, 4),
"effective_tps": effective_tps,
"n_calls": n_calls,
}
def make_output_dir():
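    """Create ``results/<YYYY-MM-DD>_run<NN>/`` next to this file and return its
    :class:`Path`. ``NN`` auto-increments per day so repeated runs never
    overwrite each other."""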
base = Path(__file__).parent / "results"
base.mkdir(exist_ok=True)
today = datetime.now().strftime("%Y-%m-%d")
# find existing runs for today, auto-increment
existing = [d.name for d in base.iterdir() if d.name.startswith(today)]
run_nums = [
int(match.group(1))
for name in existing
if (match := re.search(r"run(\d+)", name)) is not None
]
next_run = max(run_nums, default=0) + 1
out = base / f"{today}_run{next_run:02d}"
out.mkdir()
return out
def load_config():
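    """Read ``config.json`` from this file's directory and return it as a dict."""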
config_path = Path(__file__).parent / "config.json"
with open(config_path) as f:
return json.load(f)
DTYPE_MAP = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
def run_benchmarks(model, prompt_ids, cfg, label):
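    """Run the TTFT and decode benchmarks for one model and print a short summary.

    ``label`` is only used for console output. Returns
    ``{"ttft": ..., "decode": ...}`` with the dicts produced by
    :func:`benchmark_ttft` and :func:`benchmark_decode`.
    """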
print(f"\n--- {label} ---")
print("Benchmarking TTFT...")
ttft = benchmark_ttft(model, prompt_ids, cfg["warmup_ms"], cfg["rep_ms"])
print(f" TTFT median: {ttft['median_ms']:.2f} ms")
print("Benchmarking decode...")
decode = benchmark_decode(
model,
prompt_ids,
cfg["generation_length"],
cfg["decode_step_interval"],
cfg["warmup_ms"],
cfg["rep_ms"],
)
tps_values = [v["tps"] for v in decode["per_position"].values()] # type: ignore
avg_tps = sum(tps_values) / len(tps_values)
print(f" Avg TPS: {avg_tps:.1f} tokens/sec")
return {"ttft": ttft, "decode": decode}
def main():
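    """Entry point: run the configured subset of baseline / hydra_naive /
    hydra_poe back-to-back on one GPU, then write ``results.json`` and
    ``graph.png`` to a fresh run directory.

    The baseline and PoE rows are gated on the ``baseline`` / ``poe`` flags in
    ``config.json``; the naive-averaging Hydra row always runs.
    """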
cfg = load_config()
dtype = DTYPE_MAP[cfg["dtype"]]
size = cfg.get("size", "1b")
max_seq_len = cfg["prompt_length"] + cfg["generation_length"]
torch.manual_seed(42)
prompt_ids = torch.randint(0, VOCAB_SIZE, (1, cfg["prompt_length"]), device="cuda")
results: dict[str, Any] = {}
if cfg.get("baseline", False):
print(f"\nBuilding baseline Transformer ({size})...")
baseline = build_baseline_model(dtype, size=size)
init_kv_cache(baseline, batch_size=1, max_seq_len=max_seq_len)
with torch.no_grad():
results["baseline"] = run_benchmarks(
baseline, prompt_ids, cfg, f"Baseline ({size})"
)
del baseline
torch.cuda.empty_cache()
print(f"\nBuilding HydraTransformer ({size}, naive averaging)...")
model = build_hydra_model(cfg["n_heads"], cfg["heads_depth"], dtype, size=size)
init_kv_cache(model, batch_size=1, max_seq_len=max_seq_len)
with torch.no_grad():
results["hydra_naive"] = run_benchmarks(
model, prompt_ids, cfg, f"Hydra ({size}, naive avg)"
)
del model
torch.cuda.empty_cache()
if cfg.get("poe", False):
print("\nBuilding PoE ensemble (real merged-LoRA weights)...")
poe, poe_model = build_poe(cfg, dtype=dtype)
# Initialise KV cache once for the prefill TTFT and γ sweep that
# follow. PoE.generate_with_cache re-inits internally per call, so
# this only affects the residual_forward TTFT bench below.
poe.model.init_kv_cache(
batch_size=1,
max_seq_len=max_seq_len
+ cfg["poe_max_new_tokens"]
+ max(cfg.get("poe_gammas") or [cfg["poe_gamma"]]),
omit_last=True,
)
with torch.no_grad():
print("\n--- Hydra + PoE (TTFT, prefill only) ---")
poe_ttft = benchmark_poe_ttft(
poe,
prompt_ids,
warmup_ms=cfg["warmup_ms"],
rep_ms=cfg["rep_ms"],
)
print(f" TTFT median: {poe_ttft['median_ms']:.2f} ms")
gammas = cfg.get("poe_gammas") or [cfg["poe_gamma"]]
rep_ms_poe = cfg.get("poe_rep_ms", cfg["rep_ms"])
per_gamma: dict[str, Any] = {}
with torch.no_grad():
for gamma in gammas:
poe.gamma = gamma
print(f"\n--- Hydra + PoE (gamma={gamma}) ---")
stats = benchmark_poe(
poe,
cfg["poe_prompt"],
warmup_ms=cfg["warmup_ms"],
rep_ms=rep_ms_poe,
)
print(
f" call median: {stats['median_ms']:.1f} ms · "
f"avg accepted: {stats['avg_accepted_tokens_per_call']} tok · "
f"effective TPS: {stats['effective_tps']} · "
f"resample: {stats['resample_rate']} · n_calls={stats['n_calls']}"
)
per_gamma[str(gamma)] = stats
results["hydra_poe"] = {"ttft": poe_ttft, "per_gamma": per_gamma}
del poe, poe_model
torch.cuda.empty_cache()
metadata = {
"timestamp": datetime.now().isoformat(),
"gpu": torch.cuda.get_device_name(0),
**{
k: cfg[k]
for k in [
"size",
"dtype",
"n_heads",
"heads_depth",
"prompt_length",
"generation_length",
"decode_step_interval",
]
if k in cfg
},
"poe_gamma": cfg.get("poe_gamma"),
"poe_beta": cfg.get("poe_beta"),
"poe_max_new_tokens": cfg.get("poe_max_new_tokens"),
}
results["metadata"] = cast(Any, metadata)
out_dir = make_output_dir()
with open(out_dir / "results.json", "w") as f:
json.dump(results, f, indent=2)
from olmo_tap.benchmarks.plotting import plot_results
plot_results(results, out_dir)
print(f"\nResults saved to {out_dir}")
if __name__ == "__main__":
main()