olmo_tap.benchmarks.inference

Inference latency / throughput benchmark for vanilla OLMo, naive-averaging Hydra, and Hydra + PoE (Product of Experts speculative decode).

Three configurations are timed back-to-back on the same GPU so the numbers are directly comparable:

  1. baseline — vanilla OLMo (1B or 7B), random weights. Fastest per token; sets the upper bound any Hydra variant can approach.

  2. hydra_naive — HydraTransformer run with all heads in series and logits averaged (the default forward_and_sample codepath). Random weights. Strictly more compute than (1) — every step pays trunk + N head forwards.

  3. hydra_poe — Hydra wrapped in olmo_tap.inference.poe.PoE, running γ-step speculative decode with one draft head + (N-1)-head verifier jury (Product of Experts: log-probs are summed across verifier heads, i.e. multiplied in probability space). Real merged-LoRA weights via olmo_tap.inference.loading_weights.load_ensemble() so acceptance rate is meaningful. Optionally swept across γ values.
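
The verifier-jury combination in (3) is just a log-prob sum followed by renormalization. A minimal sketch, assuming per-head logits stacked into one tensor (shapes and names hypothetical):

    import torch

    n_verifiers, vocab = 4, 50_304                      # hypothetical sizes
    head_logits = torch.randn(n_verifiers, vocab)       # one row of logits per verifier head
    log_probs = torch.log_softmax(head_logits, dim=-1)
    # Summing log-probs multiplies the per-head distributions (Product of Experts);
    # the softmax renormalizes the product into the jury distribution.
    jury_probs = torch.softmax(log_probs.sum(dim=0), dim=-1)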

Knobs live in config.json next to this file. The baseline and hydra_naive rows use random weights because per-step latency is architecture-bound — only PoE needs real weights (acceptance rate depends on the actual model behavior).
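
For orientation, a hypothetical shape of the dict load_config() returns; only the poe_* keys are confirmed by build_poe() below, the rest are illustrative guesses:

    cfg = {
        "poe_gamma": 4,                # γ draft steps per verifier pass (swept per call)
        "poe_beta": 0.7,               # illustrative value
        "poe_max_new_tokens": 64,
        "gamma_sweep": [2, 3, 4],      # hypothetical key for the optional γ sweep
        "gen_length": 128,             # hypothetical; mirrors benchmark_decode's default
    }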

Usage:

pixi run -e cuda python -m olmo_tap.benchmarks.inference

Output lands in olmo_tap/benchmarks/results/<YYYY-MM-DD>_run<NN>/:

  • results.json — raw timings + per-position decode stats + PoE per-γ stats.

  • graph.png — TTFT KDE, per-position decode latency, per-position TPS.

NOTE: PoE is reported as two numbers per γ. ttft is the prefill cost — one residual_forward over all draft + verifier heads, the same call PoE issues at the top of every generate_with_cache. per_gamma[γ] is the end-to-end call timing: median_ms is the wall time of a full max_new_tokens-token generation, and effective_tps is 1000 * accepted_tokens / median_ms, i.e. draft-acceptance throughput. Under non-zero resample rate the user-visible TPS is slightly higher because resampled positions are still real output tokens; see benchmark_poe() docstring for exact definitions.

Functions

benchmark_decode(model, prompt_ids[, ...])

benchmark_poe(poe, prompt_text[, warmup_ms, ...])

Time end-to-end PoE.generate_with_cache() and compute throughput + acceptance statistics.

benchmark_poe_ttft(poe, prompt_ids[, ...])

Time the prefill stage that PoE pays at the start of every generation.

benchmark_ttft(model, prompt_ids[, ...])

build_baseline_model([dtype, device, size])

Build a vanilla single-tower OLMo Transformer (random weights).

build_hydra_model([n_heads, heads_depth, ...])

Build a HydraTransformer for naive-averaging benchmarks (random weights).

build_poe(cfg[, dtype])

Build a PoE (Product of Experts speculative decoder) with real merged-LoRA weights from load_ensemble().

forward_and_sample(model, input_ids)

Run a forward pass and argmax-sample the next token.

get_all_kv_cache_managers(model)

init_kv_cache(model, batch_size, max_seq_len)

load_config()

main()

make_output_dir()

reset_kv_cache_position(managers, position)

run_benchmarks(model, prompt_ids, cfg, label)

olmo_tap.benchmarks.inference.benchmark_decode(model, prompt_ids, gen_length=128, step_interval=8, warmup_ms=100.0, rep_ms=1000.0)[source]
olmo_tap.benchmarks.inference.benchmark_poe(poe, prompt_text, warmup_ms=100.0, rep_ms=1000.0)[source]

Time end-to-end PoE.generate_with_cache() and compute throughput + acceptance statistics.

PoE is timed at call granularity (one call = poe.max_new_tokens output tokens), not per step, because the γ-step draft loop and the verifier pass don’t decompose into a single repeating “fn” the way naive decode does. PoE.generate_with_cache() re-inits its own KV cache on every call, so the harness setup callback is left unset.
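
Schematically, the timed unit looks like this (harness.benchmark()'s exact signature is an assumption based on the description above):

    # One timed iteration = one full generation call; no setup callback is passed
    # because generate_with_cache() re-inits its own KV cache every call.
    stats = harness.benchmark(
        lambda: poe.generate_with_cache(prompt_text),
        warmup_ms=warmup_ms,
        rep_ms=rep_ms,
    )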

The reported metrics for a benchmark window:

  • median_ms — median wall time of one full call.

  • avg_accepted_tokens_per_call — mean over calls of output_tokens - resampled_tokens (tokens the draft head got right).

  • resample_rate — total resampled / total output tokens; the fraction of output positions where the verifier jury rejected the draft and a corrected token was sampled. Lower is better.

  • effective_tps — 1000 * avg_accepted / median_ms; draft-acceptance throughput, not user-visible throughput. Under a non-zero resample_rate the user-visible TPS is slightly higher because resampled positions are still real output tokens; the difference is resample_rate * effective_tps / (1 - resample_rate) (see the worked sketch after this list).

  • n_calls — number of timed full-generation calls (estimate + warmup + measurement).
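
How the metrics relate, as a worked sketch (counter values are illustrative):

    output_tokens, resampled_tokens, median_ms = 64, 8, 3200.0

    accepted = output_tokens - resampled_tokens          # 56 draft tokens the jury kept
    resample_rate = resampled_tokens / output_tokens     # 0.125
    effective_tps = 1000 * accepted / median_ms          # 17.5, draft-acceptance throughput
    user_visible_tps = 1000 * output_tokens / median_ms  # 20.0, every output position counts
    # The gap matches the formula above: resample_rate * effective_tps / (1 - resample_rate)
    assert abs((user_visible_tps - effective_tps)
               - resample_rate * effective_tps / (1 - resample_rate)) < 1e-9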

Parameters:
  • poe – PoE instance. poe.gamma may be mutated by the caller between invocations to sweep γ; everything else stays put.

  • prompt_text – A single-turn user message string. Will be wrapped with apply_chat_template inside generate_with_cache().

  • warmup_ms – Warmup budget passed to harness.benchmark().

  • rep_ms – Measurement budget. Should comfortably exceed one call’s wall time so multiple measurements land — a 64-token PoE call on 7B is ~3-5 s, so use rep_ms >= 8000 for stable medians.

Returns:

Stats dict with the keys described above plus raw_ms and filtered_ms (post-IQR-filtered timings).

olmo_tap.benchmarks.inference.benchmark_poe_ttft(poe, prompt_ids, warmup_ms=100.0, rep_ms=1000.0)[source]

Time the prefill stage that PoE pays at the start of every generation.

Specifically, times HydraTransformer.residual_forward() with last_token_only=True over all draft + verifier heads — the exact call issued at the top of PoE.generate_with_cache(). KV-cache pointers are reset to zero in setup each iteration so every measurement is a fresh prefill. Uses the same random token IDs as the other TTFT rows for apples-to-apples comparison; token values don’t affect prefill cost since it’s architecture-bound at this granularity.
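
A minimal sketch of that measurement loop, using this module's cache helpers (the harness setup kwarg and the residual_forward signature here are assumptions):

    managers = get_all_kv_cache_managers(poe.model)

    def setup():
        reset_kv_cache_position(managers, 0)  # rewind so every iteration is a fresh prefill

    stats = harness.benchmark(
        lambda: poe.model.residual_forward(prompt_ids, last_token_only=True),
        setup=setup,
        warmup_ms=warmup_ms,
        rep_ms=rep_ms,
    )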

Parameters:
  • poe – PoE instance whose .model carries an initialised KV cache (caller must have called poe.model.init_kv_cache(...)).

  • prompt_ids – (1, prompt_length) token tensor on CUDA.

  • warmup_ms – Warmup budget passed to harness.benchmark().

  • rep_ms – Measurement budget.

Returns:

TTFT stats dict with the same shape as benchmark_ttft().

olmo_tap.benchmarks.inference.benchmark_ttft(model, prompt_ids, warmup_ms=100.0, rep_ms=1000.0)[source]
olmo_tap.benchmarks.inference.build_baseline_model(dtype=torch.bfloat16, device='cuda', size='1b')[source]

Build a vanilla single-tower OLMo Transformer (random weights).

Parameters:
  • dtype – Compute dtype.

  • device – Target device.

  • size"1b" or "7b".

Returns:

A built, .eval()-mode Transformer using the FlashAttention-2 backend, with random weights (latency is architecture-bound).
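
Typical usage for one baseline row (prompt length, vocab size, and cache headroom are illustrative):

    import torch

    model = build_baseline_model(dtype=torch.bfloat16, device="cuda", size="1b")
    prompt_ids = torch.randint(0, 50_000, (1, 512), device="cuda")  # random IDs are fine here
    init_kv_cache(model, batch_size=1, max_seq_len=512 + 128)
    ttft_stats = benchmark_ttft(model, prompt_ids)
    decode_stats = benchmark_decode(model, prompt_ids, gen_length=128, step_interval=8)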

olmo_tap.benchmarks.inference.build_hydra_model(n_heads=5, heads_depth=3, dtype=torch.bfloat16, device='cuda', size='1b')[source]

Build a HydraTransformer for naive-averaging benchmarks (random weights).

Parameters:
  • n_heads – Number of parallel heads in the Hydra branching.

  • heads_depth – Layers per head (trunk gets n_layers - heads_depth).

  • dtype – Compute dtype.

  • device – Target device.

  • size"1b" or "7b" — picks the OLMo factory.

Returns:

A built, .eval()-mode HydraTransformer with random weights.

NOTE: weights are random — latency is architecture-bound at this granularity. Only the PoE row needs real weights (for acceptance-rate realism); see build_poe().
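
The naive-Hydra row runs through the same harness; a sketch reusing prompt_ids / cfg from above (whether run_benchmarks() initialises the KV cache itself is not documented here, so the init call may be redundant):

    hydra = build_hydra_model(n_heads=5, heads_depth=3, size="1b")
    init_kv_cache(hydra, batch_size=1, max_seq_len=512 + 128)
    hydra_stats = run_benchmarks(hydra, prompt_ids, cfg, label="hydra_naive")
    del hydra  # free VRAM before building the next config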

olmo_tap.benchmarks.inference.build_poe(cfg, dtype=torch.bfloat16)[source]

Build a PoE (Product of Experts speculative decoder) with real merged-LoRA weights from load_ensemble().

Unlike the baseline / naive-Hydra rows (which use random weights), PoE is built on the deployed 7B Hydra with security + robustness LoRAs merged per head plus the uncertainty head, because PoE’s acceptance rate depends on the actual logit distributions the verifier heads produce. Random weights would give meaningless accept/reject behavior.

Parameters:
  • cfg – Parsed config.json dict. Reads poe_gamma (overridden per-call when sweeping), poe_beta, poe_max_new_tokens.

  • dtype – Compute dtype to cast the loaded ensemble to.

Returns:

(poe, model) — the PoE wrapper plus the underlying HydraTransformer (returned so the caller can del it for VRAM cleanup before the next config).
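
A sketch of the γ sweep this return shape supports (γ values illustrative; rep_ms per the 7B guidance in benchmark_poe()):

    import torch

    poe, model = build_poe(cfg, dtype=torch.bfloat16)
    per_gamma = {}
    for gamma in (2, 3, 4):
        poe.gamma = gamma  # mutate in place between invocations; everything else stays put
        per_gamma[gamma] = benchmark_poe(poe, prompt_text, rep_ms=8000.0)
    del poe, model         # release the 7B ensemble's VRAM before the next config
    torch.cuda.empty_cache()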

olmo_tap.benchmarks.inference.forward_and_sample(model, input_ids)[source]

Run a forward pass and argmax-sample the next token.

Skips per-position lm_head projection on prefill (last_token_only on Hydra, logits_to_keep=1 on vanilla Transformer) so TTFT timings are apples-to-apples with PoE’s residual_forward path. Decode steps have seq_len=1, so the flag is a no-op there.
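
In outline (the branch condition, the stacked-logits shape, and the call signatures are assumptions; only the two flags are confirmed above):

    import torch

    @torch.inference_mode()
    def forward_and_sample_sketch(model, input_ids):
        if hasattr(model, "residual_forward"):  # Hydra path: average the per-head logits
            logits = model.residual_forward(input_ids, last_token_only=True).mean(dim=0)
        else:                                   # vanilla Transformer path
            logits = model(input_ids, logits_to_keep=1)
        return logits[:, -1].argmax(dim=-1)     # greedy next token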

olmo_tap.benchmarks.inference.get_all_kv_cache_managers(model) list[KVCacheManager | None][source]
olmo_tap.benchmarks.inference.init_kv_cache(model, batch_size, max_seq_len)[source]
olmo_tap.benchmarks.inference.load_config()[source]
olmo_tap.benchmarks.inference.main()[source]
olmo_tap.benchmarks.inference.make_output_dir()[source]
olmo_tap.benchmarks.inference.reset_kv_cache_position(managers: list[KVCacheManager | None], position)[source]
olmo_tap.benchmarks.inference.run_benchmarks(model, prompt_ids, cfg, label)[source]