Parameter space

Fit a PyMC model, extract posterior draws with parameter_draws(), and plot them

parameter_draws() is the entry point for parameter-space plots. It pulls posterior draws out of an ArviZ DataTree into a tidy Polars DataFrame with one row per chain × draw × coordinate and one column per variable. This example starts with simulated observed data, fits a PyMC model, and then uses the resulting posterior draws for densities, intervals, contrasts, and cross-parameter plots.

The string spec

A spec names a variable and, in brackets, the dimensions to spread over. The bracketed names must match coordinate names in the DataTree.
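To make the grammar concrete, here is a minimal sketch of how such a string could be split into a variable name and its dimensions. It is an illustration only, not tidydraws' own parser: parse_spec and row_dims are hypothetical names, and the comma as a separator for several dimensions is an assumption.

```python
import re

# Hypothetical helpers for illustration; tidydraws may parse specs differently.
_SPEC = re.compile(r"^\s*(?P<name>\w+)\s*(?:\[(?P<dims>[^\]]*)\])?\s*$")


def parse_spec(spec: str) -> tuple[str, list[str]]:
    """Split 'name[dim_a, dim_b]' into ('name', ['dim_a', 'dim_b'])."""
    match = _SPEC.match(spec)
    if match is None:
        raise ValueError(f"not a valid spec: {spec!r}")
    dims = match.group("dims")
    # Assumption: several dimensions would be separated by commas.
    names = [d.strip() for d in dims.split(",") if d.strip()] if dims else []
    return match.group("name"), names


def row_dims(spec: str) -> list[str]:
    """Dimensions that index one row of the tidy frame for this spec."""
    _, dims = parse_spec(spec)
    return ["chain", "draw", *dims]


for spec in ["sigma", "beta[groups]", "intercept[groups]"]:
    print(spec, "->", " × ".join(row_dims(spec)))
```

A scalar such as "sigma" yields no extra dimensions, so its rows are indexed by chain and draw alone.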

Spec                 Meaning                    Rows
"sigma"              scalar parameter           chain × draw
"beta[groups]"       one-dimensional array      chain × draw × groups
"intercept[groups]"  another group-level array  chain × draw × groups

Request several variables in one call and tidydraws joins them: variables sharing a dimension are inner-joined; a scalar is broadcast across an array’s dimensions.

Full PyMC workflow

The data are generated from four groups with deliberately different intercepts and slopes, including one negative slope. Because the groups are well separated, the posterior plots show real group differences rather than noise around one common line.

from pathlib import Path
import sys

import numpy as np
import polars as pl
import pymc as pm
import tidydraws as td
from lets_plot import (
    LetsPlot,
    aes,
    facet_wrap,
    geom_density,
    geom_hline,
    geom_jitter,
    geom_line,
    geom_linerange,
    geom_point,
    geom_pointrange,
    geom_vline,
    ggplot,
    labs,
    scale_color_gradient,
)

for parent in [Path.cwd(), *Path.cwd().parents]:
    helper_dir = parent / "docs" / "examples"
    if (helper_dir / "_pymc_workflow.py").exists():
        sys.path.insert(0, str(helper_dir))
        break

from _pymc_workflow import interval_summary, simulate_grouped_regression

LetsPlot.setup_html()

workflow = simulate_grouped_regression(seed=2026)
observed = workflow.observed
truth = workflow.truth

The observed data and true generating lines show the group separation before fitting.

(
    ggplot(observed.sort(["groups", "x"]).to_pandas(), aes("x", "y"))
    + geom_point(aes(color="groups"), alpha=0.65, size=2.0)
    + geom_line(aes(y="mu_true", color="groups"), size=1.0)
    + labs(
        x="x", y="y", color="group", title="Observed data from known group differences"
    )
)

Simulated observed data with true group-specific regression lines.

Next, build the PyMC model and fit it.

coords = {
    "groups": workflow.group_names,
    "obs_ind": observed.get_column("obs_ind").to_numpy(),
}

with pm.Model(coords=coords) as model:
    x = pm.Data("x", observed.get_column("x").to_numpy(), dims="obs_ind")
    group_idx = pm.Data(
        "group_idx",
        observed.get_column("group_idx").to_numpy().astype("int64"),
        dims="obs_ind",
    )
    intercept = pm.Normal("intercept", mu=0.0, sigma=2.0, dims="groups")
    beta = pm.Normal("beta", mu=0.0, sigma=1.5, dims="groups")
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    mu = pm.Deterministic(
        "mu",
        intercept[group_idx] + beta[group_idx] * x,
        dims="obs_ind",
    )
    pm.Normal(
        "y",
        mu=mu,
        sigma=sigma,
        observed=observed.get_column("y").to_numpy(),
        dims="obs_ind",
    )
    dt = pm.sample(
        draws=400,
        tune=400,
        random_seed=2026,
    )
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [intercept, beta, sigma]

Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 1 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
dt
<xarray.DataTree>
Group: /
├── Group: /posterior
│       Dimensions:    (chain: 2, draw: 400, groups: 4, obs_ind: 96)
│       Coordinates:
│         * chain      (chain) int64 16B 0 1
│         * draw       (draw) int64 3kB 0 1 2 3 4 5 6 7 ... 393 394 395 396 397 398 399
│         * groups     (groups) <U5 80B 'North' 'South' 'East' 'West'
│         * obs_ind    (obs_ind) int64 768B 0 1 2 3 4 5 6 7 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           intercept  (chain, draw, groups) float64 26kB -1.366 0.2531 ... 1.261 2.061
│           beta       (chain, draw, groups) float64 26kB 0.3564 0.6474 ... -0.5574
│           sigma      (chain, draw) float64 6kB 0.4225 0.4352 0.4174 ... 0.4691 0.4101
│           mu         (chain, draw, obs_ind) float64 614kB -2.101 -2.01 ... 1.037 0.855
│       Attributes:
│           created_at:                 2026-06-27T21:02:59.806189+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
│           sampling_time:              0.5730710029602051
│           tuning_steps:               400
├── Group: /sample_stats
│       Dimensions:                (chain: 2, draw: 400)
│       Coordinates:
│         * chain                  (chain) int64 16B 0 1
│         * draw                   (draw) int64 3kB 0 1 2 3 4 5 ... 395 396 397 398 399
│       Data variables: (12/18)
│           n_steps                (chain, draw) float64 6kB 3.0 3.0 7.0 ... 7.0 7.0 3.0
│           index_in_trajectory    (chain, draw) int64 6kB 1 -3 3 2 -3 4 ... -2 1 3 -6 2
│           tree_depth             (chain, draw) int64 6kB 2 2 3 3 3 3 2 ... 2 2 3 3 3 2
│           divergences            (chain, draw) int64 6kB 0 0 0 0 0 0 0 ... 0 0 0 0 0 0
│           energy_error           (chain, draw) float64 6kB 0.4233 -0.2305 ... -0.5857
│           energy                 (chain, draw) float64 6kB 75.28 75.22 ... 71.42 69.1
│           ...                     ...
│           smallest_eigval        (chain, draw) float64 6kB nan nan nan ... nan nan nan
│           step_size              (chain, draw) float64 6kB 0.9906 0.9906 ... 0.7741
│           acceptance_rate        (chain, draw) float64 6kB 0.5714 0.9075 ... 1.0
│           lp                     (chain, draw) float64 6kB -70.64 -68.08 ... -64.54
│           reached_max_treedepth  (chain, draw) bool 800B False False ... False False
│           perf_counter_diff      (chain, draw) float64 6kB 0.0003232 ... 0.0002932
│       Attributes:
│           created_at:                 2026-06-27T21:02:59.816046+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                ['chain', 'draw']
│           sampling_time:              0.5730710029602051
│           tuning_steps:               400
├── Group: /observed_data
│       Dimensions:  (obs_ind: 96)
│       Coordinates:
│         * obs_ind  (obs_ind) int64 768B 0 1 2 3 4 5 6 7 8 ... 88 89 90 91 92 93 94 95
│       Data variables:
│           y        (obs_ind) float64 768B -2.106 -1.795 -2.253 ... 1.393 1.081 0.7312
│       Attributes:
│           created_at:                 2026-06-27T21:02:59.820629+00:00
│           creation_library:           ArviZ
│           creation_library_version:   1.2.0
│           creation_library_language:  Python
│           inference_library:          pymc
│           inference_library_version:  6.0.1
│           sample_dims:                []
└── Group: /constant_data
        Dimensions:    (obs_ind: 96)
        Coordinates:
          * obs_ind    (obs_ind) int64 768B 0 1 2 3 4 5 6 7 ... 88 89 90 91 92 93 94 95
        Data variables:
            x          (obs_ind) float64 768B -2.063 -1.807 -1.804 ... 1.55 1.835 2.163
            group_idx  (obs_ind) int32 384B 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
        Attributes:
            created_at:                 2026-06-27T21:02:59.821832+00:00
            creation_library:           ArviZ
            creation_library_version:   1.2.0
            creation_library_language:  Python
            inference_library:          pymc
            inference_library_version:  6.0.1
            sample_dims:                []
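The sampler log above warns that two chains are few for convergence diagnostics and that rhat exceeds 1.01 for some parameters. To show what that statistic measures, here is a minimal NumPy sketch of the classic split R-hat for a single scalar parameter. It is a simplified version, not the rank-normalized variant from the linked paper, so its values can differ from the ones the sampler reports; split_rhat is a hypothetical name, and real diagnostics should come from ArviZ.

```python
import numpy as np


def split_rhat(draws) -> float:
    """Classic split R-hat for an array of shape (chain, draw)."""
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain into a first and a second half (a middle draw is
    # dropped when the number of draws is odd).
    halves = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    within = halves.var(axis=1, ddof=1).mean()          # W
    between_over_n = halves.mean(axis=1).var(ddof=1)    # B / n
    var_plus = (half - 1) / half * within + between_over_n
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(0)
well_mixed = rng.normal(size=(2, 400))            # chains agree
stuck = well_mixed + np.array([[0.0], [3.0]])     # second chain is shifted
print(split_rhat(well_mixed), split_rhat(stuck))
```

Values close to 1 mean the chains agree with each other; a value well above 1 means the between-chain spread is large relative to the within-chain spread, which is a reason to run more chains or more draws.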

One parameter_draws() call extracts group slopes and intercepts from the fitted posterior:

beta_df = td.parameter_draws(dt, "beta[groups]", "intercept[groups]")
beta_df.head()
shape: (5, 5)
chain  draw  groups   beta       intercept
i64    i64   str      f64        f64
0      0     "North"  0.356443   -1.365803
0      0     "South"  0.647371   0.253135
0      0     "East"   1.384258   1.345093
0      0     "West"   -0.538669  2.264232
0      1     "North"  0.505888   -1.138709