## fmt: off
#| code-fold: true
#| cache: true
## fmt: on
from pathlib import Path
import sys
import polars as pl
import pymc as pm
import tidydraws as td
from lets_plot import (
LetsPlot,
aes,
as_discrete,
facet_wrap,
geom_density,
geom_hline,
geom_line,
geom_pointrange,
geom_ribbon,
ggsave,
gggrid,
ggplot,
labs,
)
# Make the shared example helpers (docs/examples/_pymc_workflow.py) importable.
# Walk from the current working directory up through its ancestors; the first
# one containing docs/examples/_pymc_workflow.py wins. If none matches,
# sys.path is left unchanged (the import below may still resolve, e.g. when
# the working directory is docs/examples itself).
for parent in [Path.cwd(), *Path.cwd().parents]:
    helper_dir = parent / "docs" / "examples"
    if (helper_dir / "_pymc_workflow.py").exists():
        helper_entry = str(helper_dir)
        # Prepend (highest import priority) without duplicating: a cached or
        # re-executed cell would otherwise stack copies onto sys.path.
        if helper_entry in sys.path:
            sys.path.remove(helper_entry)
        sys.path.insert(0, helper_entry)
        break
from _pymc_workflow import simulate_grouped_regression
# Turn on inline HTML rendering for lets-plot figures.
LetsPlot.setup_html()

# Seeded simulation so the example data is reproducible.
workflow = simulate_grouped_regression(seed=2026)

# Copy the "groups" column under the singular name "group"
# (presumably for plotting code further down; confirm against later cells).
observed = workflow.observed.with_columns(group=pl.col("groups"))

# Named dimensions for the PyMC model: one coordinate per group label and
# one per observation index.
coords = dict(
    groups=workflow.group_names,
    obs_ind=observed["obs_ind"].to_numpy(),
)
# Grouped linear regression: every group has its own intercept and slope,
# and a single noise scale `sigma` is shared by all observations.
with pm.Model(coords=coords) as model:
    # Predictor registered as a pm.Data container along the obs_ind dimension.
    x = pm.Data("x", observed.get_column("x").to_numpy(), dims="obs_ind")
    # Integer group index per observation; selects each row's
    # group-specific intercept and slope below.
    group_idx = pm.Data(
        "group_idx",
        observed.get_column("group_idx").to_numpy().astype("int64"),
        dims="obs_ind",
    )

    # Priors. NOTE(review): declaration order fixes the model's variable
    # order (see "NUTS: [intercept, beta, sigma]"); the seeded draws likely
    # depend on it, so keep it stable.
    intercept = pm.Normal("intercept", mu=0.0, sigma=2.0, dims="groups")
    beta = pm.Normal("beta", mu=0.0, sigma=1.5, dims="groups")
    sigma = pm.HalfNormal("sigma", sigma=1.0)

    # Per-observation linear predictor, kept as a Deterministic so it is
    # stored with the draws.
    mu = pm.Deterministic(
        "mu",
        intercept[group_idx] + beta[group_idx] * x,
        dims="obs_ind",
    )

    # Gaussian likelihood for the observed outcome.
    pm.Normal(
        "y",
        mu=mu,
        sigma=sigma,
        observed=observed.get_column("y").to_numpy(),
        dims="obs_ind",
    )

    # Prior draws of the parameters only; y is not sampled here.
    prior = pm.sample_prior_predictive(
        draws=800,
        random_seed=2029,
        var_names=["intercept", "beta", "sigma"],
    )

    # Pin the chain count. With chains=None PyMC uses max(2, cores), and
    # cores defaults to the host CPU count (capped at 4), so an unpinned run
    # samples a different number of chains on a different machine even with
    # a fixed random_seed. 2 matches the recorded run ("2 chains in 2 jobs").
    dt = pm.sample(
        draws=400,
        tune=400,
        chains=2,
        random_seed=2026,
        progressbar=False,
        compute_convergence_checks=False,
    )

    # Evaluate mu for every posterior draw and attach it to `dt` in its
    # predictions group. Only the Deterministic is requested, so no random
    # variables are resampled ("Sampling: []").
    pm.sample_posterior_predictive(
        dt,
        var_names=["mu"],
        predictions=True,
        extend_inferencedata=True,
        random_seed=2028,
        progressbar=False,
    )

# Merge the prior draws into the same result object.
dt.update(prior)
Sampling: [beta, intercept, sigma]
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [intercept, beta, sigma]
Sampling 2 chains for 400 tune and 400 draw iterations (800 + 800 draws total) took 0 seconds.
Sampling: []