Source code for olmo_tap.benchmarks.plotting

"""
Plot the three-config benchmark output produced by
:mod:`olmo_tap.benchmarks.inference`.

A single figure with three subplots:

1. **TTFT distribution**     — KDE + histogram of prefill (time-to-first-token)
   timings for ``baseline`` and ``hydra_naive``. PoE's full-generation call
   bundles prefill, so PoE appears in this panel only when its results carry a
   separate prefill-only ``ttft`` entry.
2. **Decode latency vs KV position** — per-step decode latency (ms) over the
   benchmarked KV positions. PoE is rendered as horizontal dashed lines, one
   per γ, at its **per-token equivalent latency** (call median /
   accepted tokens) so it sits on the same axis as the per-step rows.
3. **TPS vs KV position**    — tokens per second, same shape. PoE rendered as
   horizontal dashed lines at its effective TPS per γ.

Colour convention: orange = baseline, blue = naive Hydra, green family = PoE
(darker greens for larger γ).

Usually invoked indirectly — :func:`olmo_tap.benchmarks.inference.main` calls
:func:`plot_results` after writing ``results.json``.
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde


def plot_histogram_kde(ax, timings, label, color):
    """Draw a density histogram of *timings* on *ax* with a KDE curve on top.

    Degenerate inputs are tolerated instead of raising: an empty sequence
    draws nothing, and a sample with fewer than two points or zero spread
    draws the histogram only, because a Gaussian KDE cannot be fitted to it.
    The axis labels are set in every case.

    :param ax: Axes-like object to draw on; ``hist``, ``plot``,
        ``set_xlabel`` and ``set_ylabel`` are called on it.
    :param timings: Sequence of latency samples (the x axis is labelled in
        milliseconds).
    :param label: Legend prefix; the artists are labelled ``"<label> hist"``
        and ``"<label> KDE"``.
    :param color: Colour applied to both the histogram and the KDE line.
    """
    data = np.asarray(timings, dtype=float)

    if data.size > 0:
        ax.hist(
            data,
            bins=30,
            alpha=0.3,
            color=color,
            density=True,
            label=f"{label} hist",
        )
        lo, hi = data.min(), data.max()
        # gaussian_kde needs at least two samples with non-zero variance;
        # otherwise its covariance estimate is undefined or singular.
        if data.size >= 2 and hi > lo:
            kde = gaussian_kde(data)
            x = np.linspace(lo, hi, 200)
            ax.plot(x, kde(x), color=color, label=f"{label} KDE")

    ax.set_xlabel("Latency (ms)")
    ax.set_ylabel("Density")
def plot_decode_curve(ax_latency, ax_tps, decode_results, label, color):
    """Plot per-position decode latency and throughput for one configuration.

    Positions are the keys of ``decode_results["per_position"]``, ordered by
    their integer value. The latency axes receive the median line plus a
    shaded p20–p80 band; the TPS axes receive the ``tps`` line.

    :param ax_latency: Axes-like object for the latency curve and band.
    :param ax_tps: Axes-like object for the tokens-per-second curve.
    :param decode_results: Mapping with a ``"per_position"`` entry whose
        values each provide ``median_ms``, ``p20_ms``, ``p80_ms`` and ``tps``.
    :param label: Legend label used for both lines.
    :param color: Colour used for both lines and the band.
    """
    per_position = decode_results["per_position"]
    ordered_keys = sorted(per_position, key=int)
    rows = [per_position[k] for k in ordered_keys]

    def column(field):
        # One value per position, in ascending position order.
        return [row[field] for row in rows]

    median_ms = column("median_ms")
    low_ms = column("p20_ms")
    high_ms = column("p80_ms")
    tokens_per_sec = column("tps")
    xs = [int(k) for k in ordered_keys]

    # Latency panel: median line with the p20–p80 spread shaded around it.
    ax_latency.plot(xs, median_ms, color=color, label=label)
    ax_latency.fill_between(xs, low_ms, high_ms, alpha=0.2, color=color)
    ax_latency.set_xlabel("KV Cache Position")
    ax_latency.set_ylabel("Latency (ms)")

    # Throughput panel.
    ax_tps.plot(xs, tokens_per_sec, color=color, label=label)
    ax_tps.set_xlabel("KV Cache Position")
    ax_tps.set_ylabel("Tokens/sec")
# Line/fill colour for each benchmark configuration, keyed by the config's
# top-level name in the results dict.
PALETTE = dict(
    baseline="tab:orange",
    hydra_naive="tab:blue",
    hydra_poe="tab:green",
)

# Human-readable legend label for each benchmark configuration, same keys
# as PALETTE.
LABELS = dict(
    baseline="Baseline OLMo",
    hydra_naive="Hydra (naive avg)",
    hydra_poe="Hydra + PoE",
)
def _legend_if_labelled(ax):
    """Add a legend to *ax* only when it holds at least one labelled artist."""
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend()


def plot_results(results, output_dir):
    """Render the three-config comparison figure to ``graph.png``.

    The figure has three panels: TTFT distribution, decode latency vs KV
    position, and TPS vs KV position. The figure is always closed before
    returning, including when plotting or saving raises.

    :param results: Benchmark results dict as written to ``results.json`` by
        :func:`olmo_tap.benchmarks.inference.main`. Recognised top-level keys:
        ``baseline``, ``hydra_naive``, ``hydra_poe``. Missing keys are skipped
        silently — a partial run still plots, and a panel that ends up with
        no labelled artists simply gets no legend.
    :param output_dir: Directory to write ``graph.png`` into, as a ``str`` or
        path-like object. Must exist.
    """
    out_path = Path(output_dir) / "graph.png"

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        ax_ttft, ax_latency, ax_tps = axes

        # TTFT distribution + per-position decode curve for the two non-PoE
        # configs.
        for key in ("baseline", "hydra_naive"):
            if key not in results:
                continue
            color = PALETTE[key]
            label = LABELS[key]
            plot_histogram_kde(
                ax_ttft, results[key]["ttft"]["filtered_ms"], label, color
            )
            plot_decode_curve(
                ax_latency, ax_tps, results[key]["decode"], label, color
            )

        poe = results.get("hydra_poe", {})

        # PoE prefill TTFT (residual_forward over all heads, no draft loop).
        if "ttft" in poe:
            plot_histogram_kde(
                ax_ttft,
                poe["ttft"]["filtered_ms"],
                f"{LABELS['hydra_poe']} (prefill)",
                PALETTE["hydra_poe"],
            )

        # PoE: per-gamma effective TPS + per-token equivalent latency.
        # Per-token latency = call median_ms / avg_accepted_tokens_per_call
        # (apples-to-apples vs the other rows' per-step decode latency).
        if "per_gamma" in poe:
            per_gamma = poe["per_gamma"]
            cmap = plt.get_cmap("Greens")
            # Spread shades over [0.4, 0.9] of the colormap; the max() guard
            # avoids dividing by zero when only one gamma is present.
            shade_span = max(len(per_gamma) - 1, 1)
            ordered = sorted(per_gamma.items(), key=lambda kv: int(kv[0]))
            for i, (g_str, p) in enumerate(ordered):
                color = cmap(0.4 + 0.5 * (i / shade_span))
                accepted = p["avg_accepted_tokens_per_call"]
                # With no accepted tokens the ratio is undefined; the line is
                # drawn at 0.0 in that case.
                per_tok_ms = p["median_ms"] / accepted if accepted > 0 else 0.0
                label_g = f"PoE γ={g_str}"
                ax_tps.axhline(
                    p["effective_tps"],
                    color=color,
                    linestyle="--",
                    label=f"{label_g} ({p['effective_tps']} tok/s, resample={p['resample_rate']})",
                )
                ax_latency.axhline(
                    per_tok_ms,
                    color=color,
                    linestyle="--",
                    label=f"{label_g} ({per_tok_ms:.1f} ms/tok eq.)",
                )

        for ax, title in (
            (ax_ttft, "TTFT Distribution"),
            (ax_latency, "Decode Latency vs Position"),
            (ax_tps, "TPS vs Position"),
        ):
            ax.set_title(title)
            # Skip the legend on empty panels so partial runs stay quiet.
            _legend_if_labelled(ax)

        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

    print(f"Plot saved to {out_path}")