title: "tidydraws"

Developer: Benjamin Vincent
License: MIT
Requires: Python >=3.12
Provides-Extra: letsplot, plotnine, all, dev


A tidybayes-inspired data layer for declarative Bayesian visualisation in Python.

tidydraws turns MCMC output (ArviZ DataTree) into tidy Polars frames that are ready to plot — one .to_pandas() away from any ggplot-like backend. It does no plotting itself. Three functions, three spaces:

| Function | Space | Plot archetype |
|---|---|---|
| parameter_draws() | parameter | density, forest, scatter |
| prediction_draws() | prediction | ribbon + line, fit + data |
| compare_draws() | comparison | prior vs posterior, intervals |
## fmt: off
#| code-fold: true
#| cache: true
## fmt: on
from pathlib import Path
import sys

import polars as pl
import pymc as pm
import tidydraws as td
from lets_plot import (
    LetsPlot,
    aes,
    as_discrete,
    facet_wrap,
    geom_density,
    geom_hline,
    geom_line,
    geom_pointrange,
    geom_ribbon,
    ggsave,
    gggrid,
    ggplot,
    labs,
)

for parent in [Path.cwd(), *Path.cwd().parents]:
    helper_dir = parent / "docs" / "examples"
    if (helper_dir / "_pymc_workflow.py").exists():
        sys.path.insert(0, str(helper_dir))
        break

from _pymc_workflow import simulate_grouped_regression

LetsPlot.setup_html()

workflow = simulate_grouped_regression(seed=2026)
observed = workflow.observed.with_columns(pl.col("groups").alias("group"))

coords = {
    "groups": workflow.group_names,
    "obs_ind": observed.get_column("obs_ind").to_numpy(),
}

with pm.Model(coords=coords) as model:
    x = pm.Data("x", observed.get_column("x").to_numpy(), dims="obs_ind")
    group_idx = pm.Data(
        "group_idx",
        observed.get_column("group_idx").to_numpy().astype("int64"),
        dims="obs_ind",
    )
    intercept = pm.Normal("intercept", mu=0.0, sigma=2.0, dims="groups")
    beta = pm.Normal("beta", mu=0.0, sigma=1.5, dims="groups")
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    mu = pm.Deterministic(
        "mu",
        intercept[group_idx] + beta[group_idx] * x,
        dims="obs_ind",
    )
    pm.Normal(
        "y",
        mu=mu,
        sigma=sigma,
        observed=observed.get_column("y").to_numpy(),
        dims="obs_ind",
    )
    prior = pm.sample_prior_predictive(
        draws=800,
        random_seed=2029,
        var_names=["intercept", "beta", "sigma"],
    )
    dt = pm.sample(
        draws=400,
        tune=400,
        random_seed=2026,
        progressbar=False,
        compute_convergence_checks=False,
    )
    pm.sample_posterior_predictive(
        dt,
        var_names=["mu"],
        predictions=True,
        extend_inferencedata=True,
        random_seed=2028,
        progressbar=False,
    )

dt.update(prior)
# Output:
# Sampling: [beta, intercept, sigma]
# Initializing NUTS using jitter+adapt_diag...
# Multiprocess sampling (2 chains in 2 jobs)
# NUTS: [intercept, beta, sigma]
# Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 0 seconds.
# Sampling: []
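# Tidy frame of beta draws with a `groups` column.
beta_df = td.parameter_draws(dt, "beta[groups]")
# 89% interval per group: 5.5% and 94.5% quantiles around the median.
beta_summary = (
    beta_df
    .group_by("groups")
    .agg(
        pl.col("beta").quantile(0.055).alias("lower"),
        pl.col("beta").median().alias("median"),
        pl.col("beta").quantile(0.945).alias("upper"),
    )
    .sort("groups")
)
# Prior and posterior draws of beta, distinguished by a `source` column.
compare = td.compare_draws(dt, "beta[groups]")
# Draws of mu alongside the columns of `observed` (obs_ind, x, group).
pred = td.prediction_draws(dt, newdata=observed, var_name="mu")
# Same 89% interval, now per observation, sorted by x for the ribbon.
pred_summary = (
    pred
    .group_by("obs_ind", "x", "group")
    .agg(
        pl.col("mu").quantile(0.055).alias("lower"),
        pl.col("mu").median().alias("median"),
        pl.col("mu").quantile(0.945).alias("upper"),
    )
    .sort("x")
)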

p_forest = (
    ggplot(beta_summary.to_pandas(), aes("groups", "median"))
    + geom_pointrange(aes(ymin="lower", ymax="upper"), size=0.8)
    + geom_hline(yintercept=0, linetype="dashed", color="#888888")
    + labs(x="group", y="beta", title="Forest plot")
)
p_density = (
    ggplot(beta_df.to_pandas(), aes("beta", fill="groups"))
    + geom_density(alpha=0.5)
    + labs(x="beta", y="density", fill="group", title="Posterior density")
)
p_compare = (
    ggplot(compare.to_pandas(), aes("beta", fill="source"))
    + geom_density(alpha=0.5)
    + facet_wrap(facets="groups", ncol=2)
    + labs(x="beta", y="density", fill="source", title="Prior vs posterior")
)
p_pred = (
    ggplot(pred_summary.to_pandas(), aes("x"))
    + geom_ribbon(
        aes(ymin="lower", ymax="upper", fill=as_discrete("group")), alpha=0.25
    )
    + geom_line(aes(y="median", color=as_discrete("group")), size=0.8)
    + labs(x="x", y="mu", color="group", fill="group", title="Predictive fit")
)

g = gggrid([p_forest, p_density, p_compare, p_pred], ncol=2)
ggsave(g, "index-plot.png", path="../docs/assets", scale=1.5)
g

Install

With uv

uv add tidydraws

With pip

pip install tidydraws

Why tidydraws?

Plotting MCMC output in Python usually means manually slicing xarray dimensions, iterating over groups, and aligning coordinates: imperative, verbose, and error-prone. R’s tidybayes solved this with a data layer that distinguishes parameter space from prediction space. tidydraws brings that approach to Python, built on Polars.

Note

tidydraws is plotting-backend-agnostic: it returns Polars DataFrames, and .to_pandas() bridges to lets-plot, plotnine, or any library that takes pandas. The examples use lets-plot; see the parameter_draws() page for the same plot in plotnine.


Inspired by tidybayes for R.