Prior/posterior comparison

Fit a PyMC model, compare prior and posterior draws with compare_draws(), and plot them

compare_draws() is the prior/posterior comparison helper: it calls parameter_draws() once per selected group in the ArviZ DataTree, tags each result with a source column, and stacks them into a single DataFrame. By default it compares posterior against prior, which is the workflow most users need after fitting a PyMC model.

Full PyMC workflow

This example uses the same grouped regression data as the other examples, samples from the model prior before fitting, then compares those prior draws with the fitted posterior. That gives compare_draws() real prior draws and real posterior draws from the same PyMC model.

from pathlib import Path
import sys

import polars as pl
import pymc as pm
import tidydraws as td
from lets_plot import (
    LetsPlot,
    aes,
    facet_wrap,
    geom_density,
    geom_hline,
    geom_pointrange,
    geom_vline,
    ggplot,
    labs,
    position_dodge,
)

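# Walk up from the working directory to find docs/examples, which holds the
# shared simulation helper, and put it on the import path.
for parent in [Path.cwd(), *Path.cwd().parents]:
    helper_dir = parent / "docs" / "examples"
    if (helper_dir / "_pymc_workflow.py").exists():
        sys.path.insert(0, str(helper_dir))
        break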

from _pymc_workflow import simulate_grouped_regression

LetsPlot.setup_html()

workflow = simulate_grouped_regression(seed=2026)
observed = workflow.observed
truth = workflow.truth

Build the PyMC model, draw from its prior, then fit it.
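Written out, the model in the next cell is a linear regression with a separate intercept $\alpha_g$ and slope $\beta_g$ for each of the four groups, each with its own independent prior (no hierarchical pooling). Here $g[i]$ is `group_idx`, and because PyMC's `sigma` arguments are standard deviations they appear squared as variances:

```latex
$$
\begin{aligned}
\alpha_g &\sim \mathcal{N}(0,\, 2^2), \qquad \beta_g \sim \mathcal{N}(0,\, 1.5^2), \qquad g \in \{\text{North}, \text{South}, \text{East}, \text{West}\} \\
\sigma &\sim \text{HalfNormal}(1) \\
\mu_i &= \alpha_{g[i]} + \beta_{g[i]}\, x_i \\
y_i &\sim \mathcal{N}(\mu_i,\, \sigma^2), \qquad i = 0, \dots, 95
\end{aligned}
$$
```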

coords = {
    "groups": workflow.group_names,
    "obs_ind": observed.get_column("obs_ind").to_numpy(),
}

with pm.Model(coords=coords) as model:
    x = pm.Data("x", observed.get_column("x").to_numpy(), dims="obs_ind")
    group_idx = pm.Data(
        "group_idx",
        observed.get_column("group_idx").to_numpy().astype("int64"),
        dims="obs_ind",
    )
    intercept = pm.Normal("intercept", mu=0.0, sigma=2.0, dims="groups")
    beta = pm.Normal("beta", mu=0.0, sigma=1.5, dims="groups")
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    mu = pm.Deterministic(
        "mu",
        intercept[group_idx] + beta[group_idx] * x,
        dims="obs_ind",
    )
    pm.Normal(
        "y",
        mu=mu,
        sigma=sigma,
        observed=observed.get_column("y").to_numpy(),
        dims="obs_ind",
    )
    prior = pm.sample_prior_predictive(
        draws=800,
        random_seed=2029,
        var_names=["intercept", "beta", "sigma"],
    )
    dt = pm.sample(
        draws=400,
        random_seed=2030,
    )

dt.update(prior)
Sampling: [beta, intercept, sigma]
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [intercept, beta, sigma]

Sampling 2 chains for 1_000 tune and 400 draw iterations (2_000 + 800 draws total) took 1 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
dt
<xarray.DataTree>
Group: /
├── Group: /posterior
│       Dimensions:    (chain: 2, draw: 400, groups: 4, obs_ind: 96)
│       Coordinates:
│         * chain      (chain) int64 16B 0 1
│         * draw       (draw) int64 3kB 0 1 2 3 4 5 6 7 ... 393 394 395 396 397 398 399
│         * groups     (groups) <U5 80B 'North' 'South' 'East' 'West'
│         * obs_ind    (obs_ind) int64 768B 0 1 2 3 4 5 6 7 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           intercept  (chain, draw, groups) float64 26kB -1.425 0.1027 ... 1.099 2.101
│           beta       (chain, draw, groups) float64 26kB 0.3943 0.782 ... 1.594 -0.6559
│           sigma      (chain, draw) float64 6kB 0.424 0.4397 0.4762 ... 0.3836 0.5073
│           mu         (chain, draw, obs_ind) float64 614kB -2.239 -2.138 ... 0.6829
│       Attributes:
│           created_at:                 2026-06-27T21:01:58.946255+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
│           sampling_time:              0.9663512706756592
│           tuning_steps:               1000
├── Group: /sample_stats
│       Dimensions:                (chain: 2, draw: 400)
│       Coordinates:
│         * chain                  (chain) int64 16B 0 1
│         * draw                   (draw) int64 3kB 0 1 2 3 4 5 ... 395 396 397 398 399
│       Data variables: (12/18)
│           process_time_diff      (chain, draw) float64 6kB 0.000266 ... 0.0003367
│           n_steps                (chain, draw) float64 6kB 3.0 7.0 3.0 ... 7.0 7.0 7.0
│           largest_eigval         (chain, draw) float64 6kB nan nan nan ... nan nan nan
│           tree_depth             (chain, draw) int64 6kB 2 3 2 3 3 2 3 ... 2 3 2 3 3 3
│           step_size              (chain, draw) float64 6kB 0.9588 0.9588 ... 0.7891
│           smallest_eigval        (chain, draw) float64 6kB nan nan nan ... nan nan nan
│           ...                     ...
│           perf_counter_start     (chain, draw) float64 6kB 267.7 267.7 ... 268.0 268.0
│           divergences            (chain, draw) int64 6kB 0 0 0 0 0 0 0 ... 0 0 0 0 0 0
│           max_energy_error       (chain, draw) float64 6kB -0.8613 0.4267 ... 0.8733
│           reached_max_treedepth  (chain, draw) bool 800B False False ... False False
│           diverging              (chain, draw) bool 800B False False ... False False
│           energy                 (chain, draw) float64 6kB 74.81 76.41 ... 70.49 74.05
│       Attributes:
│           created_at:                 2026-06-27T21:01:58.957996+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
│           sampling_time:              0.9663512706756592
│           tuning_steps:               1000
├── Group: /observed_data
│       Dimensions:  (obs_ind: 96)
│       Coordinates:
│         * obs_ind  (obs_ind) int64 768B 0 1 2 3 4 5 6 7 8 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           y        (obs_ind) float64 768B -2.106 -1.795 -2.253 ... 1.393 1.081 0.7312
│       Attributes:
│           created_at:                 2026-06-27T21:01:49.078002+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                []
├── Group: /constant_data
│       Dimensions:    (obs_ind: 96)
│       Coordinates:
│         * obs_ind    (obs_ind) int64 768B 0 1 2 3 4 5 6 7 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           x          (obs_ind) float64 768B -2.063 -1.807 -1.804 ... 1.55 1.835 2.163
│           group_idx  (obs_ind) int32 384B 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
│       Attributes:
│           created_at:                 2026-06-27T21:01:49.079111+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                []
├── Group: /prior
│       Dimensions:    (chain: 1, draw: 800, groups: 4)
│       Coordinates:
│         * chain      (chain) int64 8B 0
│         * draw       (draw) int64 6kB 0 1 2 3 4 5 6 7 ... 793 794 795 796 797 798 799
│         * groups     (groups) <U5 80B 'North' 'South' 'East' 'West'
│       Data variables:
│           intercept  (chain, draw, groups) float64 26kB -1.457 1.446 ... -0.5506
│           beta       (chain, draw, groups) float64 26kB -2.307 -0.09826 ... 3.153
│           sigma      (chain, draw) float64 6kB 1.121 1.051 0.8453 ... 0.1778 0.07849
│       Attributes:
│           created_at:                 2026-06-27T21:01:49.075214+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
└── Group: /prior_predictive
        Attributes:
            created_at:                 2026-06-27T21:01:49.077115+00:00
            creation_library:           ArviZ
            creation_library_version:   1.2.0
            creation_library_language:  Python
            inference_library:          pymc
            inference_library_version:  6.0.1
            sample_dims:                ['chain', 'draw']
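The two groups use different layouts: the posterior has 2 chains × 400 draws, while the prior has 1 chain × 800 draws, so both hold 800 draws per parameter. Assuming compare_draws() returns one row per chain, draw, and group label (as the head() output below suggests), the expected row counts for `beta[groups]` follow from the dimension sizes printed above:

```python
import math

# Dimension sizes copied from the DataTree repr above for `beta`.
posterior_dims = {"chain": 2, "draw": 400, "groups": 4}
prior_dims = {"chain": 1, "draw": 800, "groups": 4}

# A long table with one row per (chain, draw, group) combination has as many
# rows as the product of the dimension sizes.
expected = {
    "posterior": math.prod(posterior_dims.values()),
    "prior": math.prod(prior_dims.values()),
}
print(expected)                # {'posterior': 3200, 'prior': 3200}
print(sum(expected.values()))  # 6400
```

If that assumption holds, `compare.height` should come out at 6400, split evenly between the two sources.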

One call stacks both sources:

compare = td.compare_draws(dt, "beta[groups]")
compare.head()
shape: (5, 5)
chain  draw  groups   beta       source
i64    i64   str      f64        str
0      0     "North"  0.394303   "posterior"
0      0     "South"  0.782      "posterior"
0      0     "East"   1.384254   "posterior"
0      0     "West"   -0.407587  "posterior"
0      1     "North"  0.473318   "posterior"