olmo_tap.benchmarks.inference¶
Inference latency / throughput benchmark for vanilla OLMo, naive-averaging Hydra, and Hydra + PoE (Product of Experts speculative decode).
Three configurations are timed back-to-back on the same GPU so the numbers are directly comparable:
1. `baseline` — vanilla OLMo (1B or 7B), random weights. Fastest per token; sets the upper bound any Hydra variant can approach.
2. `hydra_naive` — HydraTransformer run with all heads in series and logits averaged (the default `forward_and_sample` codepath). Random weights. Strictly more compute than (1) — every step pays trunk + N head forwards.
3. `hydra_poe` — Hydra wrapped in `olmo_tap.inference.poe.PoE`, running γ-step speculative decode with one draft head + an (N-1)-head verifier jury (Product of Experts: log-probs are summed across verifier heads, i.e. multiplied in probability space). Real merged-LoRA weights via `olmo_tap.inference.loading_weights.load_ensemble()` so the acceptance rate is meaningful. Optionally swept across γ values.
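For intuition, here is a minimal sketch of the Product-of-Experts jury rule described in (3). The tensor shapes and the β-margin acceptance test are illustrative assumptions, not the exact `PoE` internals:

```python
import torch
import torch.nn.functional as F

def jury_logprobs(verifier_logits: torch.Tensor) -> torch.Tensor:
    """(n_verifier_heads, vocab) logits -> combined (vocab,) log-probs.

    Summing per-head log-probs is the Product of Experts: it multiplies
    the heads' probability distributions (up to normalization).
    """
    return F.log_softmax(verifier_logits, dim=-1).sum(dim=0)

def accept_draft(verifier_logits: torch.Tensor, draft_token: int, beta: float) -> bool:
    # Hypothetical acceptance rule: keep the draft token if the jury scores
    # it within `beta` of its own best token (the real criterion may differ).
    lp = jury_logprobs(verifier_logits)
    return bool(lp[draft_token] >= lp.max() - beta)
```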
Knobs live in config.json next to this file. The baseline and
hydra_naive rows use random weights because per-step latency is
architecture-bound — only PoE needs real weights (acceptance rate depends on
the actual model behavior).
Usage:
pixi run -e cuda python -m olmo_tap.benchmarks.inference
Output lands in olmo_tap/benchmarks/results/<YYYY-MM-DD>_run<NN>/:
- `results.json` — raw timings + per-position decode stats + PoE per-γ stats.
- `graph.png` — TTFT KDE, per-position decode latency, per-position TPS.
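A hypothetical reader for `results.json`, assuming the per-γ schema described in the NOTE below; the top-level `"hydra_poe"` / `"per_gamma"` key names are guesses:

```python
import json
from pathlib import Path

run_dir = sorted(Path("olmo_tap/benchmarks/results").iterdir())[-1]  # latest run
results = json.loads((run_dir / "results.json").read_text())
for gamma, stats in results["hydra_poe"]["per_gamma"].items():  # key names assumed
    print(f"γ={gamma}: {stats['median_ms']:.1f} ms, {stats['effective_tps']:.1f} tok/s")
```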
NOTE:
PoE is reported as two kinds of numbers. `ttft` is the prefill cost — one `residual_forward` over all draft + verifier heads, the same call PoE issues at the top of every `generate_with_cache`. `per_gamma[γ]` is the end-to-end call timing for each swept γ: `median_ms` is the wall time of a full `max_new_tokens`-token generation, and `effective_tps` is `1000 * accepted_tokens / median_ms`, i.e. draft-acceptance throughput. Under a non-zero resample rate the user-visible TPS is slightly higher, because resampled positions are still real output tokens; see the `benchmark_poe()` docstring for exact definitions.
Functions

| Function | Description |
| --- | --- |
| `benchmark_decode(model, prompt_ids, ...)` | |
| `benchmark_poe(poe, prompt_text, ...)` | Time end-to-end `PoE.generate_with_cache()` and compute throughput + acceptance statistics. |
| `benchmark_poe_ttft(poe, prompt_ids, ...)` | Time the prefill stage that PoE pays at the start of every generation. |
| `benchmark_ttft(model, prompt_ids, ...)` | |
| `build_baseline_model(dtype, device, size)` | Build a vanilla single-tower OLMo Transformer (random weights). |
| `build_hydra_model(n_heads, heads_depth, ...)` | Build a HydraTransformer for naive-averaging benchmarks (random weights). |
| `build_poe(cfg, dtype)` | Build a `PoE` (Product of Experts speculative decoder) with real merged-LoRA weights from `load_ensemble()`. |
| `forward_and_sample(model, input_ids)` | Run a forward pass and argmax-sample the next token. |
- olmo_tap.benchmarks.inference.benchmark_decode(model, prompt_ids, gen_length=128, step_interval=8, warmup_ms=100.0, rep_ms=1000.0)[source]¶
- olmo_tap.benchmarks.inference.benchmark_poe(poe, prompt_text, warmup_ms=100.0, rep_ms=1000.0)[source]¶
Time end-to-end `PoE.generate_with_cache()` and compute throughput + acceptance statistics.

PoE is timed at call granularity (one call = `poe.max_new_tokens` output tokens), not per step, because the γ-step draft loop and the verifier pass don't decompose into a single repeating "fn" the way naive decode does. `PoE.generate_with_cache()` re-inits its own KV cache on every call, so the harness `setup` callback is left unset.

The reported metrics for a benchmark window:

- `median_ms` — median wall time of one full call.
- `avg_accepted_tokens_per_call` — mean over calls of `output_tokens − resampled_tokens` (tokens the draft head got right).
- `resample_rate` — total resampled / total output tokens; the fraction of output positions where the verifier jury rejected the draft and a corrected token was sampled. Lower is better.
- `effective_tps` — `1000 * avg_accepted / median_ms`; draft-acceptance throughput, not user-visible throughput. Under a non-zero `resample_rate` the user-visible TPS is slightly higher because resampled positions are still real output tokens; the difference is `resample_rate * effective_tps / (1 - resample_rate)`.
- `n_calls` — number of timed full-generation calls (estimate + warmup + measurement).
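A worked example of these definitions on illustrative (not measured) per-call counts:

```python
# Illustrative per-call counts; the harness collects the real ones.
calls = [
    {"output_tokens": 64, "resampled_tokens": 12},
    {"output_tokens": 64, "resampled_tokens": 9},
]
median_ms = 4200.0  # median wall time of one full call, in ms

avg_accepted = sum(c["output_tokens"] - c["resampled_tokens"] for c in calls) / len(calls)  # 53.5
resample_rate = sum(c["resampled_tokens"] for c in calls) / sum(c["output_tokens"] for c in calls)  # ~0.164
effective_tps = 1000 * avg_accepted / median_ms          # ~12.7 accepted tokens/s
user_visible_tps = effective_tps / (1 - resample_rate)   # ~15.2; counts resampled positions too
```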
- Parameters:
poe – `PoE` instance. `poe.gamma` may be mutated by the caller between invocations to sweep γ; everything else stays put.
prompt_text – A single-turn user message string. Will be wrapped with `apply_chat_template` inside `generate_with_cache()`.
warmup_ms – Warmup budget passed to `harness.benchmark()`.
rep_ms – Measurement budget. Should comfortably exceed one call's wall time so multiple measurements land — a 64-token PoE call on 7B is ~3-5 s, so use `rep_ms >= 8000` for stable medians.
- Returns:
Stats dict with the keys described above plus `raw_ms` and `filtered_ms` (post-IQR-filtered timings).
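A hypothetical γ sweep using the mutation the `poe` parameter description permits; it assumes `poe` came from `build_poe()` below and `benchmark_poe` is imported from this module, and the swept values are arbitrary:

```python
per_gamma = {}
for gamma in (2, 3, 4, 6):    # arbitrary sweep values
    poe.gamma = gamma          # documented as safe to mutate between calls
    stats = benchmark_poe(poe, "Explain speculative decoding.", rep_ms=8000.0)
    per_gamma[gamma] = stats
    print(gamma, stats["effective_tps"], stats["resample_rate"])
```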
- olmo_tap.benchmarks.inference.benchmark_poe_ttft(poe, prompt_ids, warmup_ms=100.0, rep_ms=1000.0)[source]¶
Time the prefill stage that PoE pays at the start of every generation.
Specifically, times `HydraTransformer.residual_forward()` with `last_token_only=True` over all draft + verifier heads — the exact call issued at the top of `PoE.generate_with_cache()`. KV-cache pointers are reset to zero in `setup` each iteration so every measurement is a fresh prefill. Uses the same random token IDs as the other TTFT rows for an apples-to-apples comparison; token values don't affect prefill cost since it's architecture-bound at this granularity.
- Parameters:
poe – `PoE` instance whose `.model` carries an initialised KV cache (caller must have called `poe.model.init_kv_cache(...)`).
prompt_ids – `(1, prompt_length)` token tensor on CUDA.
warmup_ms – Warmup budget passed to `harness.benchmark()`.
rep_ms – Measurement budget.
- Returns:
TTFT stats dict with the same shape as `benchmark_ttft()`.
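A hedged sketch of the measurement pattern this describes; the `harness.benchmark` callback signature, its import path, and the KV-cache attribute names are assumptions standing in for the module's internals:

```python
from olmo_tap.benchmarks import harness  # assumed import path

def setup():
    # Assumed cache layout: reset every layer's write pointer so the next
    # residual_forward is a fresh prefill rather than an append.
    for cache in poe.model.kv_caches:  # hypothetical attribute name
        cache.position = 0

def fn():
    poe.model.residual_forward(prompt_ids, last_token_only=True)

stats = harness.benchmark(fn, setup=setup, warmup_ms=100.0, rep_ms=1000.0)
```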
- olmo_tap.benchmarks.inference.benchmark_ttft(model, prompt_ids, warmup_ms=100.0, rep_ms=1000.0)[source]¶
- olmo_tap.benchmarks.inference.build_baseline_model(dtype=torch.bfloat16, device='cuda', size='1b')[source]¶
Build a vanilla single-tower OLMo Transformer (random weights).
- Parameters:
dtype – Compute dtype.
device – Target device.
size – `"1b"` or `"7b"`.
- Returns:
A built, `.eval()`-mode Transformer using the FlashAttention-2 backend, with random weights (latency is architecture-bound).
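A hypothetical driver for the baseline row, using only the signatures documented on this page; the prompt length and vocab bound are arbitrary:

```python
import torch
from olmo_tap.benchmarks.inference import (
    build_baseline_model, benchmark_ttft, benchmark_decode,
)

model = build_baseline_model(dtype=torch.bfloat16, device="cuda", size="1b")
# Random IDs are fine here: latency is architecture-bound, not content-bound.
prompt_ids = torch.randint(0, 50_000, (1, 512), device="cuda")
ttft = benchmark_ttft(model, prompt_ids, warmup_ms=100.0, rep_ms=1000.0)
decode = benchmark_decode(model, prompt_ids, gen_length=128, step_interval=8)
```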
- olmo_tap.benchmarks.inference.build_hydra_model(n_heads=5, heads_depth=3, dtype=torch.bfloat16, device='cuda', size='1b')[source]¶
Build a HydraTransformer for naive-averaging benchmarks (random weights).
- Parameters:
n_heads – Number of parallel heads in the Hydra branching.
heads_depth – Layers per head (trunk gets `n_layers - heads_depth`).
dtype – Compute dtype.
device – Target device.
size – `"1b"` or `"7b"` — picks the OLMo factory.
- Returns:
A built, `.eval()`-mode HydraTransformer with random weights.
NOTE: weights are random — latency is architecture-bound at this granularity. Only the PoE row needs real weights (for acceptance-rate realism); see `build_poe()`.
- olmo_tap.benchmarks.inference.build_poe(cfg, dtype=torch.bfloat16)[source]¶
Build a `PoE` (Product of Experts speculative decoder) with real merged-LoRA weights from `load_ensemble()`.

Unlike the baseline / naive-Hydra rows (which use random weights), PoE is built on the deployed 7B Hydra with security + robustness LoRAs merged per head plus the uncertainty head, because PoE's acceptance rate depends on the actual logit distributions the verifier heads produce. Random weights would give meaningless accept/reject behavior.
- Parameters:
cfg – Parsed `config.json` dict. Reads `poe_gamma` (overridden per-call when sweeping), `poe_beta`, `poe_max_new_tokens`.
dtype – Compute dtype to cast the loaded ensemble to.
- Returns:
`(poe, model)` — the `PoE` wrapper plus the underlying `HydraTransformer` (returned so the caller can `del` it for VRAM cleanup before the next config).
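A hypothetical end-to-end use, including the VRAM cleanup the second return value exists for; the config path follows "Knobs live in config.json next to this file", and any keys beyond those documented are unknown:

```python
import gc
import json
import torch
from olmo_tap.benchmarks.inference import build_poe, benchmark_poe

with open("olmo_tap/benchmarks/config.json") as f:  # assumed location, "next to this file"
    cfg = json.load(f)

poe, model = build_poe(cfg, dtype=torch.bfloat16)
stats = benchmark_poe(poe, "Summarize speculative decoding.", rep_ms=8000.0)

del poe, model            # drop the 7B ensemble before building the next config
gc.collect()
torch.cuda.empty_cache()
```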
- olmo_tap.benchmarks.inference.forward_and_sample(model, input_ids)[source]¶
Run a forward pass and argmax-sample the next token.
Skips per-position `lm_head` projection on prefill (`last_token_only` on Hydra, `logits_to_keep=1` on vanilla Transformer) so TTFT timings are apples-to-apples with PoE's `residual_forward` path. Decode steps have `seq_len=1`, so the flag is a no-op there.
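A minimal sketch of the behavior described, assuming keyword-style model calls and a `(batch, kept_positions, vocab)` logits tensor; the real implementation's dispatch between Hydra and vanilla models may differ:

```python
import torch

@torch.no_grad()
def forward_and_sample_sketch(model, input_ids, is_hydra: bool):
    if is_hydra:
        logits = model(input_ids, last_token_only=True)   # skip per-position lm_head
    else:
        logits = model(input_ids, logits_to_keep=1)       # vanilla-Transformer equivalent
    return logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy next token
```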