Prediction space

Fit a PyMC model, join posterior expectations to covariates with prediction_draws(), and plot them

prediction_draws() joins posterior prediction draws such as mu to covariates, producing one tidy row per chain × draw × observation. Parameters are deliberately excluded: prediction space stays separate from parameter space, so coefficients are never duplicated across observations.
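To make the row layout concrete, here is a minimal sketch of the "one row per chain × draw × observation" shape using only the standard library. It is not how tidydraws is implemented; the toy values and the covariates mapping are assumptions made up for illustration.

```python
from itertools import product

# Toy posterior expectations indexed as mu[chain][draw][obs]:
# 2 chains x 3 draws x 2 observations (made-up numbers).
mu = [
    [[0.1, 1.1], [0.2, 1.2], [0.3, 1.3]],
    [[0.4, 1.4], [0.5, 1.5], [0.6, 1.6]],
]

# Covariates keyed by observation index (hypothetical values).
covariates = {
    0: {"x": -2.0, "groups": "North"},
    1: {"x": 1.5, "groups": "South"},
}

n_chain, n_draw, n_obs = len(mu), len(mu[0]), len(mu[0][0])

# One row per chain x draw x observation; covariates are looked up by obs_ind,
# so they repeat across draws while parameters never enter the frame.
rows = [
    {"chain": c, "draw": d, "obs_ind": i, "mu": mu[c][d][i], **covariates[i]}
    for c, d, i in product(range(n_chain), range(n_draw), range(n_obs))
]

print(len(rows))  # n_chain * n_draw * n_obs
print(rows[0])
```

By the same arithmetic, a fit with 2 chains, 400 draws, and 96 observations yields 2 × 400 × 96 = 76,800 rows.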

Pass a DataFrame when you want explicit control over the covariates used for plotting. With newdata=None, prediction_draws() reads covariates from a constant-data group in the DataTree and fails loudly if that group is missing.
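The fallback rule can be pictured with a small stand-in. The function below is a hypothetical sketch, not the tidydraws implementation: resolve_covariates, its group_name parameter, and the plain dict standing in for the DataTree are all assumptions chosen for illustration.

```python
def resolve_covariates(groups, newdata=None, group_name="constant_data"):
    """Return explicit covariates if given, else fall back to a stored group.

    `groups` is a plain dict standing in for the groups of a DataTree.
    """
    if newdata is not None:
        # Explicit covariates always win.
        return newdata
    if group_name not in groups:
        # Fail loudly instead of silently returning draws without covariates.
        raise KeyError(
            f"no '{group_name}' group found; pass newdata explicitly. "
            f"Available groups: {sorted(groups)}"
        )
    return groups[group_name]


stored = {"posterior": {"mu": [0.1, 0.2]}, "constant_data": {"x": [-2.0, 1.5]}}

explicit = resolve_covariates(stored, newdata={"x": [0.0, 1.0]})
fallback = resolve_covariates(stored)

try:
    resolve_covariates({"posterior": {"mu": [0.1, 0.2]}})
    missing_raised = False
except KeyError:
    missing_raised = True
```

Raising instead of returning an empty frame keeps a missing covariate group from going unnoticed until the plotting step.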

Full PyMC workflow

The workflow starts from observed data with strong group differences, fits a PyMC model, asks PyMC for posterior expectations mu[obs_ind], and then uses prediction_draws() to attach those draws to x, groups, and observed y.

from pathlib import Path
import sys

import polars as pl
import pymc as pm
import tidydraws as td
from lets_plot import (
    LetsPlot,
    aes,
    as_discrete,
    geom_line,
    geom_point,
    geom_ribbon,
    ggplot,
    labs,
)

for parent in [Path.cwd(), *Path.cwd().parents]:
    helper_dir = parent / "docs" / "examples"
    if (helper_dir / "_pymc_workflow.py").exists():
        sys.path.insert(0, str(helper_dir))
        break

from _pymc_workflow import interval_summary, simulate_grouped_regression

LetsPlot.setup_html()

workflow = simulate_grouped_regression(seed=2026)
observed = workflow.observed

coords = {
    "groups": workflow.group_names,
    "obs_ind": observed.get_column("obs_ind").to_numpy(),
}

Build the PyMC model and fit it.

with pm.Model(coords=coords) as model:
    x = pm.Data("x", observed.get_column("x").to_numpy(), dims="obs_ind")
    group_idx = pm.Data(
        "group_idx",
        observed.get_column("group_idx").to_numpy().astype("int64"),
        dims="obs_ind",
    )
    intercept = pm.Normal("intercept", mu=0.0, sigma=2.0, dims="groups")
    beta = pm.Normal("beta", mu=0.0, sigma=1.5, dims="groups")
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    mu = pm.Deterministic(
        "mu",
        intercept[group_idx] + beta[group_idx] * x,
        dims="obs_ind",
    )
    pm.Normal(
        "y",
        mu=mu,
        sigma=sigma,
        observed=observed.get_column("y").to_numpy(),
        dims="obs_ind",
    )
    dt = pm.sample(
        draws=400,
        tune=400,
        random_seed=2027,
    )
    pm.sample_posterior_predictive(
        dt,
        var_names=["mu", "y"],
        predictions=True,
        extend_inferencedata=True,
        random_seed=2028,
        progressbar=False,
    )
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [intercept, beta, sigma]
/home/runner/work/tidydraws/tidydraws/.venv/lib/python3.12/site-packages/pymc/step_methods/hmc/quadpotential.py:321: RuntimeWarning: overflow encountered in dot
  return 0.5 * np.dot(x, v_out)

Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 1 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
Sampling: [y]
dt
<xarray.DataTree>
Group: /
├── Group: /posterior
│       Dimensions:    (chain: 2, draw: 400, groups: 4, obs_ind: 96)
│       Coordinates:
│         * chain      (chain) int64 16B 0 1
│         * draw       (draw) int64 3kB 0 1 2 3 4 5 6 7 ... 393 394 395 396 397 398 399
│         * groups     (groups) <U5 80B 'North' 'South' 'East' 'West'
│         * obs_ind    (obs_ind) int64 768B 0 1 2 3 4 5 6 7 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           intercept  (chain, draw, groups) float64 26kB -1.162 0.1217 ... 1.421 2.132
│           beta       (chain, draw, groups) float64 26kB 0.425 0.8192 ... 1.452 -0.5716
│           sigma      (chain, draw) float64 6kB 0.4078 0.4078 0.4084 ... 0.4723 0.4386
│           mu         (chain, draw, obs_ind) float64 614kB -2.039 -1.93 ... 0.8962
│       Attributes:
│           created_at:                 2026-06-27T21:02:49.957324+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
│           sampling_time:              0.5894467830657959
│           tuning_steps:               400
├── Group: /sample_stats
│       Dimensions:                (chain: 2, draw: 400)
│       Coordinates:
│         * chain                  (chain) int64 16B 0 1
│         * draw                   (draw) int64 3kB 0 1 2 3 4 5 ... 395 396 397 398 399
│       Data variables: (12/18)
│           divergences            (chain, draw) int64 6kB 0 0 0 0 0 0 0 ... 0 0 0 0 0 0
│           index_in_trajectory    (chain, draw) int64 6kB 7 0 -1 2 1 ... -4 2 -2 -4 -2
│           acceptance_rate        (chain, draw) float64 6kB 0.9216 0.7231 ... 0.9916
│           step_size              (chain, draw) float64 6kB 0.9673 0.9673 ... 0.813
│           smallest_eigval        (chain, draw) float64 6kB nan nan nan ... nan nan nan
│           tree_depth             (chain, draw) int64 6kB 3 3 2 2 2 3 3 ... 3 3 2 2 3 3
│           ...                     ...
│           perf_counter_start     (chain, draw) float64 6kB 318.6 318.6 ... 318.9 318.9
│           reached_max_treedepth  (chain, draw) bool 800B False False ... False False
│           step_size_bar          (chain, draw) float64 6kB 0.7812 0.7812 ... 0.7753
│           perf_counter_diff      (chain, draw) float64 6kB 0.0005814 ... 0.0003255
│           largest_eigval         (chain, draw) float64 6kB nan nan nan ... nan nan nan
│           diverging              (chain, draw) bool 800B False False ... False False
│       Attributes:
│           created_at:                 2026-06-27T21:02:49.967720+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
│           sampling_time:              0.5894467830657959
│           tuning_steps:               400
├── Group: /observed_data
│       Dimensions:  (obs_ind: 96)
│       Coordinates:
│         * obs_ind  (obs_ind) int64 768B 0 1 2 3 4 5 6 7 8 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           y        (obs_ind) float64 768B -2.106 -1.795 -2.253 ... 1.393 1.081 0.7312
│       Attributes:
│           created_at:                 2026-06-27T21:02:49.971562+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                []
├── Group: /constant_data
│       Dimensions:    (obs_ind: 96)
│       Coordinates:
│         * obs_ind    (obs_ind) int64 768B 0 1 2 3 4 5 6 7 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           x          (obs_ind) float64 768B -2.063 -1.807 -1.804 ... 1.55 1.835 2.163
│           group_idx  (obs_ind) int32 384B 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
│       Attributes:
│           created_at:                 2026-06-27T21:02:49.972732+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                []
├── Group: /predictions
│       Dimensions:  (chain: 2, draw: 400, obs_ind: 96)
│       Coordinates:
│         * chain    (chain) int64 16B 0 1
│         * draw     (draw) int64 3kB 0 1 2 3 4 5 6 7 ... 393 394 395 396 397 398 399
│         * obs_ind  (obs_ind) int64 768B 0 1 2 3 4 5 6 7 8 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           y        (chain, draw, obs_ind) float64 614kB -1.592 -1.999 ... 0.9266 0.599
│           mu       (chain, draw, obs_ind) float64 614kB -2.039 -1.93 ... 1.083 0.8962
│       Attributes:
│           created_at:                 2026-06-27T21:02:51.092763+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
└── Group: /predictions_constant_data
        Dimensions:    (obs_ind: 96)
        Coordinates:
          * obs_ind    (obs_ind) int64 768B 0 1 2 3 4 5 6 7 ... 88 89 90 91 92 93 94 95
        Data variables:
            x          (obs_ind) float64 768B -2.063 -1.807 -1.804 ... 1.55 1.835 2.163
            group_idx  (obs_ind) int32 384B 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
        Attributes:
            created_at:                 2026-06-27T21:02:51.095575+00:00
            creation_library:           ArviZ
            creation_library_version:   1.2.0
            creation_library_language:  Python
            inference_library:          pymc
            inference_library_version:  6.0.1
            sample_dims:                []

One call joins mu to the observed covariates and responses:

pred = td.prediction_draws(dt, newdata=observed, var_name="mu")
pred.head()
shape: (5, 9)
chain  draw  obs_ind  mu         groups   group_idx  x          mu_true    y
i64    i64   i64      f64        str      i64        f64        f64        f64
0      0     0        -2.038775  "North"  0          -2.06345   -1.972207  -2.106032
0      0     1        -1.929716  "North"  0          -1.806841  -1.882394  -1.794667
0      0     2        -1.928457  "North"  0          -1.80388   -1.881358  -2.253447
0      0     3        -1.742612  "North"  0          -1.366599  -1.72831   -2.094702
0      0     4        -1.694453  "North"  0          -1.253284  -1.688649  -2.373679

Note the columns: the index columns chain, draw, and obs_ind, the draw column mu, and the columns carried over from observed (groups, group_idx, x, mu_true, and y). There is no beta and no intercept; parameters stay in their own frame.
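A long frame like this is usually collapsed to one row per observation before drawing a ribbon. The sketch below shows that step on toy rows with the standard library only. It is not the interval_summary helper imported earlier, whose implementation is not shown here; the 90% interval and the made-up values are assumptions.

```python
from collections import defaultdict
from statistics import median, quantiles

# Toy long-format rows: 101 draws for each of two observations (made-up values).
rows = [
    {"obs_ind": i, "mu": float(v + 100 * i)}
    for i in range(2)
    for v in range(101)
]

# Collect the draws that belong to each observation.
draws_by_obs = defaultdict(list)
for row in rows:
    draws_by_obs[row["obs_ind"]].append(row["mu"])

# Summarize each observation with a median and a central 90% interval.
summary = []
for obs_ind, draws in sorted(draws_by_obs.items()):
    cuts = quantiles(draws, n=20, method="inclusive")  # cut points at 5%, 10%, ..., 95%
    summary.append(
        {"obs_ind": obs_ind, "lower": cuts[0], "median": median(draws), "upper": cuts[-1]}
    )

for line in summary:
    print(line)
```

Each summary row then maps onto a ribbon (lower, upper) and a line (median) at that observation's x value.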