Showcase

Stress-test tidydraws across model archetypes — one canonical plot per model shape

tidydraws' test suite runs against synthetic DataTrees. This page stress-tests the full pipeline (PyMC model → DataTree → parameter_draws / prediction_draws / compare_draws → plot) on real sampler output, across eight model archetypes that each exercise a different code path. If this page builds, the data layer survived every shape; if it breaks, the failing section identifies the archetype responsible.

from pathlib import Path
import sys

import numpy as np
import polars as pl
import tidydraws as td
from lets_plot import (
    LetsPlot,
    aes,
    as_discrete,
    facet_wrap,
    geom_density,
    geom_density2d,
    geom_hline,
    geom_jitter,
    geom_line,
    geom_point,
    geom_pointrange,
    geom_ribbon,
    ggplot,
    labs,
    theme_minimal,
)

for parent in [Path.cwd(), *Path.cwd().parents]:
    helper_dir = parent / "docs" / "examples"
    if (helper_dir / "_showcase_models.py").exists():
        sys.path.insert(0, str(helper_dir))
        break

from _showcase_models import (
    simple_regression,
    varying_intercepts,
    varying_slopes,
    varying_both,
    multiple_regression,
    logistic,
    simple_regression_1chain,
    interval_summary,
)

LetsPlot.setup_html()

1. Simple linear regression

Scalar parameters only — no group dimensions. Exercises the simplest extraction path.

dt, obs = simple_regression(seed=2026)
draws = td.parameter_draws(dt, "alpha", "beta", "sigma")
pred = td.prediction_draws(dt, newdata=obs, var_name="mu")
pred_summary = interval_summary(pred, "mu", ["obs_ind", "x", "y"], [0.50, 0.80, 0.95])
Sampling: [alpha, beta, sigma, y]
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, beta, sigma]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 0 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Sampling: []
Cross-join detected between frame 0 and 1. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].
Cross-join detected between frame 1 and 2. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].
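The `interval_summary` helper lives in `_showcase_models.py`, and its source is not shown on this page. The plots only rely on its output having `prob`, `lower`, `median`, and `upper` columns for each grouping key. Below is a minimal standard-library sketch of such a summary. It assumes equal-tailed (quantile) intervals rather than HDIs, and the name `interval_summary_sketch` is hypothetical, not the helper's actual code.

```python
from collections import defaultdict


def quantile(sorted_vals, q):
    # Linear interpolation between order statistics (numpy's default "linear" rule).
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (pos - lo) * (sorted_vals[hi] - sorted_vals[lo])


def interval_summary_sketch(rows, value, by, probs):
    """Hypothetical stand-in for interval_summary: one output row per (group key, prob).

    rows  : list of dicts, one per draw (and per group, if any)
    value : name of the column to summarise
    by    : list of column names identifying a group
    probs : central interval masses, e.g. [0.5, 0.95]
    """
    grouped = defaultdict(list)
    for row in rows:
        grouped[tuple(row[k] for k in by)].append(row[value])

    out = []
    for key, vals in grouped.items():
        vals.sort()
        for p in probs:
            tail = (1 - p) / 2  # cut equal probability mass from each tail
            out.append({
                **dict(zip(by, key)),
                "prob": p,
                "lower": quantile(vals, tail),
                "median": quantile(vals, 0.5),
                "upper": quantile(vals, 1 - tail),
            })
    return out


# Usage: 101 evenly spaced "draws" for a single group.
draws = [{"group": "a", "mu": i / 100} for i in range(101)]
summary = interval_summary_sketch(draws, "mu", ["group"], [0.5, 0.9])
```

With these draws, the 50% row spans 0.25 to 0.75 and the 90% row spans 0.05 to 0.95 (up to floating-point rounding). Filtering on `prob`, as the plots do with `pl.col("prob") == 0.95`, picks one interval width per key.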

Bivariate posterior of intercept and slope

parameter_draws returns alpha and beta as columns on the same chain × draw index. Each row therefore pairs values from the same posterior iteration, so a 2D density contour can use them directly.

(
    ggplot(draws.to_pandas(), aes("alpha", "beta"))
    + geom_density2d(bins=8, alpha=0.8)
    + geom_point(alpha=0.15, size=1.0, color="#444444")
    + labs(x="intercept (α)", y="slope (β)", title="Bivariate posterior")
    + theme_minimal()
)

Posterior predictive fit

(
    ggplot(pred_summary.filter(pl.col("prob") == 0.95).to_pandas(), aes("x"))
    + geom_ribbon(aes(ymin="lower", ymax="upper"), alpha=0.25)
    + geom_line(aes(y="median"), color="#2166ac", size=0.8)
    + geom_point(aes(y="y"), data=obs.to_pandas(), alpha=0.6, size=1.5)
    + labs(x="x", y="y", title="Predictive fit (95% ribbon)")
    + theme_minimal()
)

2. Varying intercepts

One-dimensional alpha[group] with string coordinate labels. Exercises _datatree_group_to_df with non-numeric coordinates and checks that forest-plot sorting and filtering work on string dims.

dt, obs = varying_intercepts(seed=2027)
draws = td.parameter_draws(dt, "alpha[group]")
forest = interval_summary(draws, "alpha", "group", [0.50, 0.89])
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, beta, sigma]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 0 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

Forest plot with string labels

(
    ggplot(forest.filter(pl.col("prob") == 0.89).to_pandas(), aes("group", "median"))
    + geom_pointrange(aes(ymin="lower", ymax="upper"), size=0.8)
    + geom_hline(yintercept=0, linetype="dashed", color="#888888")
    + labs(x="group", y="intercept (α)", title="Varying intercepts (89% intervals)")
    + theme_minimal()
)
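String labels typically sort alphabetically, which is rarely the most readable order for a forest plot. One option is to order the labels by posterior median before plotting. The sketch below assumes each summary row carries a `group` label and a `median` column, as `forest` does. The helper name is hypothetical.

```python
def order_groups_by_median(summary_rows, label="group", stat="median"):
    """Return the distinct group labels sorted by their summary statistic.

    Ties are broken by the label itself so the order is deterministic.
    Repeated rows for a group (e.g. one per prob) share the same median,
    so overwriting the dict entry is harmless.
    """
    best = {}
    for row in summary_rows:
        best[row[label]] = row[stat]
    return sorted(best, key=lambda g: (best[g], g))


rows = [
    {"group": "north", "median": 0.8},
    {"group": "east", "median": -0.2},
    {"group": "south", "median": 0.8},
    {"group": "west", "median": 0.1},
]
order = order_groups_by_median(rows)
# → ['east', 'west', 'north', 'south']
```

Inside the plot itself, lets-plot's `as_discrete` (already imported above) accepts an `order_by` argument, for example `aes(as_discrete("group", order_by="median"), "median")`. That may achieve the same ordering without a helper; check the lets-plot documentation for your version.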

3. Varying slopes — cross-dim broadcast

Extracts beta[group] (1-d) and scalar sigma together. sigma is broadcast across all groups; this verifies that the cross-dim join produces exactly one row per (chain, draw, group) and no phantom rows.

dt, obs = varying_slopes(seed=2028)
draws = td.parameter_draws(dt, "beta[group]", "sigma")
Sampling: [alpha, beta, sigma]
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, beta, sigma]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 1 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
Cross-join detected between frame 0 and 1. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].
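The "no phantom rows" check can be stated as a row count. Suppose `beta[group]` is keyed by chain, draw, and group, and scalar `sigma` is keyed by chain and draw. Joining them on the shared `(chain, draw)` keys should then give exactly chains × draws × groups rows. Each row should carry the `sigma` from its own iteration. A keyless cartesian product would instead give (chains × draws × groups) × (chains × draws) rows. The sketch below uses made-up numbers and a hypothetical `broadcast_join`. It illustrates the expected result and is not tidydraws' implementation.

```python
from itertools import product


def broadcast_join(left, right, keys):
    """Attach right-hand columns to every left row that shares the key values.

    right must have exactly one row per key (true for a scalar parameter
    keyed by chain and draw), so the result has len(left) rows.
    A left key missing from right raises KeyError.
    """
    index = {}
    for row in right:
        k = tuple(row[c] for c in keys)
        if k in index:
            raise ValueError(f"duplicate key {k} in right-hand frame")
        index[k] = row
    return [{**row, **index[tuple(row[c] for c in keys)]} for row in left]


chains, draws, groups = range(2), range(3), ["a", "b", "c", "d"]
beta = [
    {"chain": c, "draw": d, "group": g, "beta": 10 * c + d + i / 10}
    for c, d in product(chains, draws)
    for i, g in enumerate(groups)
]
sigma = [{"chain": c, "draw": d, "sigma": 1 + c + d / 10} for c, d in product(chains, draws)]

joined = broadcast_join(beta, sigma, keys=["chain", "draw"])
```

Here `joined` has 2 × 3 × 4 = 24 rows, one per (chain, draw, group). Every row's `sigma` is the value recorded for its own chain and draw.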

Beta vs sigma scatter

Through the broadcast, every group-level beta draw is paired with the sigma value from the same chain and draw.

(
    ggplot(draws.to_pandas(), aes("sigma", "beta"))
    + geom_point(aes(color="group"), alpha=0.35, size=2.0)
    + labs(
        x="sigma (σ)",
        y="beta (β)",
        color="group",
        title="Cross-dim broadcast: β[group] × σ",
    )
    + theme_minimal()
)

4. Bivariate group-level posterior

Two 1-d arrays on the same dimension: alpha[group] and beta[group]. parameter_draws aligns them row for row by chain, draw, and group, exercising the same-dims multi-variable path.

dt, obs = varying_both(seed=2029)
draws = td.parameter_draws(dt, "alpha[group]", "beta[group]")
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, beta, sigma]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 1 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

2D density by group

(
    ggplot(draws.to_pandas(), aes("alpha", "beta"))
    + geom_density2d(aes(color="group"), bins=5, alpha=0.7)
    + geom_point(alpha=0.10, size=0.8)
    + labs(
        x="intercept (α)",
        y="slope (β)",
        color="group",
        title="Group-level bivariate posterior",
    )
    + theme_minimal()
)

5. Multiple regression

Three scalar predictors and varying intercepts per group. parameter_draws extracts b1, b2, b3, and alpha[group]: four variables, three scalar and one 1-d. Exercises many-variable alignment.

dt, obs = multiple_regression(seed=2030)
draws = td.parameter_draws(dt, "b1", "b2", "b3", "alpha[group]")

# Long format: one row per draw, group, and predictor for the scalar slopes.
coefs = pl.concat([
    draws.select(
        pl.col("group"),
        pl.col(name).alias("value"),
        pl.lit(name).alias("predictor"),
    )
    for name in ["b1", "b2", "b3"]
])
coef_forest = interval_summary(coefs, "value", ["predictor", "group"], [0.89])
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, b1, b2, b3, sigma]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 1 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
Cross-join detected between frame 0 and 1. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].
Cross-join detected between frame 1 and 2. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].
Cross-join detected between frame 2 and 3. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].

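Faceted coefficient forest

b1, b2, and b3 are scalars broadcast across groups, so every group within a facet should show an identical interval. Any difference between groups would point to a misaligned join.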

(
    ggplot(
        coef_forest.filter(pl.col("prob") == 0.89).to_pandas(), aes("group", "median")
    )
    + geom_pointrange(aes(ymin="lower", ymax="upper"), size=0.7)
    + geom_hline(yintercept=0, linetype="dashed", color="#888888")
    + facet_wrap(facets="predictor", ncol=3)
    + labs(
        x="group",
        y="coefficient",
        title="Multiple regression coefficients (89% intervals)",
    )
    + theme_minimal()
)

6. Logistic regression

Bernoulli outcome with group-level intercepts. Predictions are computed directly from the parameter draws by joining alpha[group] and beta with the prediction grid, so no sample_posterior_predictive call is needed. Exercises parameter_draws with a non-Gaussian model and demonstrates manual prediction computation.

dt, obs, grid = logistic(seed=2031)
draws = td.parameter_draws(dt, "alpha[group]", "beta")

# Compute predictions: join draws to the grid on group, then apply the inverse logit (logit_p → p)
pred_draws = draws.join(grid.select(["obs_ind", "group", "x"]), on="group")
pred_draws = pred_draws.with_columns(
    (1 / (1 + (-pl.col("alpha") - pl.col("beta") * pl.col("x")).exp())).alias("p"),
)
pred_summary = interval_summary(
    pred_draws, "p", ["obs_ind", "x", "group"], [0.80, 0.95]
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, beta]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 0 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Cross-join detected between frame 0 and 1. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].
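The inverse-logit step above, p = 1 / (1 + exp(-(α_group + β·x))), is evaluated once per draw and grid point. Written literally with Python's `math.exp`, it raises `OverflowError` for strongly negative linear predictors, because `math.exp` overflows above roughly 709. Polars' vectorised `.exp()` should return `inf` instead, which the division maps to 0. A scalar standard-library version needs the usual branch on the sign. The sketch below uses hypothetical names and made-up draws, and mirrors the inner join on `group` done in the polars code.

```python
import math


def inv_logit(eta):
    """Numerically stable logistic function 1 / (1 + exp(-eta))."""
    if eta >= 0:
        return 1.0 / (1.0 + math.exp(-eta))
    z = math.exp(eta)  # eta < 0, so this cannot overflow
    return z / (1.0 + z)


def predict_p(draws, grid):
    """Per-draw success probability for each grid point of the matching group.

    draws: dicts with chain, draw, group, alpha, beta
    grid:  dicts with obs_ind, group, x
    """
    out = []
    for d in draws:
        for g in grid:
            if g["group"] == d["group"]:  # inner join on group
                eta = d["alpha"] + d["beta"] * g["x"]
                out.append({**g, "chain": d["chain"], "draw": d["draw"], "p": inv_logit(eta)})
    return out


draws = [
    {"chain": 0, "draw": 0, "group": "a", "alpha": 0.0, "beta": 1.0},
    {"chain": 0, "draw": 1, "group": "a", "alpha": -2.0, "beta": 0.5},
]
grid = [{"obs_ind": i, "group": "a", "x": x} for i, x in enumerate([-1000.0, 0.0, 4.0])]
pred = predict_p(draws, grid)
```

The result has one row per draw and grid point (2 × 3 = 6 here). Summarising `p` over draws for each `obs_ind` gives the ribbon plotted below.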

Probability ribbon with observed data

(
    ggplot(
        pred_summary.filter(pl.col("prob") == 0.95).to_pandas(),
        aes("x"),
    )
    + geom_ribbon(aes(ymin="lower", ymax="upper", fill="group"), alpha=0.25)
    + geom_line(aes(y="median", color="group"), size=0.8)
    + geom_jitter(
        aes(y="y"),
        data=obs.with_columns(pl.col("y").cast(pl.Float64)).to_pandas(),
        alpha=0.5,
        height=0.02,
        size=1.5,
    )
    + labs(
        x="x",
        y="P(y=1)",
        color="group",
        fill="group",
        title="Logistic: probability ribbon + observed 0/1",
    )
    + theme_minimal()
)

7. Prior vs posterior

compare_draws() stacks prior and posterior into one frame with a source column. Exercises the stacked-concat path with real PyMC prior samples.

dt, _obs = varying_slopes(seed=2028)  # same model and seed as §3; the DataTree includes a prior group
compare = td.compare_draws(dt, "beta[group]", groups=["prior", "posterior"])
Sampling: [alpha, beta, sigma]
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, beta, sigma]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 0 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
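Conceptually, the stacked frame is a vertical concatenation of the per-group draw tables plus a label column. The sketch below reproduces that shape with plain dicts. It assumes the output layout implied by `fill="source"` in the plot, with a `source` column holding the group name. It is not tidydraws' code, and `stack_sources` is a hypothetical name. Stacking rather than joining matters here: prior and posterior draws are unrelated, so pairing them by (chain, draw) would be meaningless. A stack also tolerates different draw counts.

```python
def stack_sources(frames):
    """Stack {source_name: rows} into one list with a 'source' column.

    Row order is preserved within each source; sources appear in dict order.
    """
    stacked = []
    for source, rows in frames.items():
        for row in rows:
            if "source" in row:
                raise ValueError("input rows already carry a 'source' column")
            stacked.append({**row, "source": source})
    return stacked


prior = [{"chain": 0, "draw": d, "group": "a", "beta": 0.1 * d} for d in range(5)]
posterior = [
    {"chain": c, "draw": d, "group": "a", "beta": 1.0 + 0.01 * d}
    for c in range(2)
    for d in range(3)
]
compare = stack_sources({"prior": prior, "posterior": posterior})
```

The result has 5 prior rows and 6 posterior rows, which is enough for `geom_density` to draw one filled curve per source.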

Overlaid prior and posterior densities

(
    ggplot(compare.to_pandas(), aes("beta", fill="source"))
    + geom_density(alpha=0.45)
    + facet_wrap(facets="group", ncol=2)
    + labs(
        x="beta (β)", y="density", fill="source", title="Prior vs posterior by group"
    )
    + theme_minimal()
)

8. Single chain

Same simple regression as §1, but sampled with chains=1. Verifies that nothing in the extraction or alignment assumes chain >= 2.

dt, obs = simple_regression_1chain(seed=2032)
draws = td.parameter_draws(dt, "alpha", "beta", "sigma")
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [alpha, beta, sigma]
Sampling 1 chain for 400 tune and 400 draw iterations (400 + 400 draws total) took 0 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks
Cross-join detected between frame 0 and 1. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].
Cross-join detected between frame 1 and 2. Broadcasting scalar or differently-dimensioned variable on dims ['draw', 'chain'].

Density with one chain

n_chains = draws.select(pl.col("chain").n_unique()).item()
n_draws = draws.select(pl.col("draw").n_unique()).item()
(
    ggplot(draws.to_pandas(), aes("beta"))
    + geom_density(fill="#2166ac", alpha=0.5, color="#2166ac")
    + labs(
        x="beta (β)",
        y="density",
        title=f"Single chain: {n_chains} chain × {n_draws} draws",
    )
    + theme_minimal()
)
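Between-chain statistics are one common place where a `chain >= 2` assumption hides. The between-chain variance in the classic (non-split) R-hat, B = n/(m − 1) · Σⱼ (θ̄ⱼ − θ̄)², divides by m − 1, so it is undefined for m = 1 chain. That is the kind of convergence check PyMC's single-chain warning above alludes to. Tidy extraction and plotting should not depend on it at all. The sketch below uses hypothetical names: it groups draws by chain and returns `None` rather than dividing by zero when only one chain is present.

```python
from collections import defaultdict
from statistics import fmean


def between_chain_variance(rows, value):
    """Gelman-Rubin between-chain variance B = n / (m - 1) * sum_j (mean_j - mean)^2.

    Returns None when only one chain is present, because m - 1 == 0.
    Assumes every chain has the same number of draws n.
    """
    by_chain = defaultdict(list)
    for row in rows:
        by_chain[row["chain"]].append(row[value])
    m = len(by_chain)
    if m < 2:
        return None  # undefined for a single chain; callers must handle this
    n = len(next(iter(by_chain.values())))
    chain_means = [fmean(v) for v in by_chain.values()]
    grand_mean = fmean(chain_means)
    return n / (m - 1) * sum((cm - grand_mean) ** 2 for cm in chain_means)


one_chain = [{"chain": 0, "draw": d, "beta": float(d)} for d in range(4)]
two_chains = one_chain + [{"chain": 1, "draw": d, "beta": float(d) + 2.0} for d in range(4)]
```

`between_chain_variance(two_chains, "beta")` gives 8.0 (chain means 1.5 and 3.5, n = 4). The one-chain input returns `None` instead of raising.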