prediction_draws()
Join posterior predictive draws to a covariate DataFrame.
Usage
prediction_draws(
dt,
newdata,
var_name,
idata_group="predictions",
constant_data_group="predictions_constant_data",
join_on="obs_ind"
)
The newdata frame is the left table, and the draws are attached to it. This avoids the denormalisation problem in which coefficient values are duplicated for every obs_ind, because the join runs from the covariate grid to the draws rather than the other way round.
Parameters
dt: xr.DataTree
ArviZ DataTree object (xarray.DataTree) containing prediction samples.
newdata: pl.DataFrame | pd.DataFrame | None
Covariate grid. If None, it is read from dt[constant_data_group]; if that group is not found, a clear error is raised directing the user to pass newdata explicitly.
var_name: str
Name of the predictive variable to extract (e.g., “mu”). If the variable has multiple dimensions, a nested specification such as “mu[time, group]” is supported.
idata_group: str = "predictions"
DataTree group containing the predictive draws (“predictions”, “posterior_predictive”, or a custom name). Default “predictions”.
constant_data_group: str = "predictions_constant_data"
DataTree group holding the covariate grid that aligns with the prediction draws. Default “predictions_constant_data” (the ArviZ 1.x convention for the constant data paired with a “predictions” group). Set this parameter if your DataTree uses a different naming convention.
join_on: str | list[str] = "obs_ind"
Column(s) on which newdata is joined to the draws. Default “obs_ind”.
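One plausible way to split a var_name specification such as “mu[time, group]” into a variable name and its dimensions is a small regular expression. parse_var_spec is hypothetical, and the exact syntax prediction_draws accepts may be stricter or looser than this sketch.

```python
import re

# Name, optional whitespace, then an optional bracketed, comma-separated dim list.
_SPEC = re.compile(r"^\s*(\w+)\s*(?:\[\s*([^\]]*)\])?\s*$")


def parse_var_spec(spec: str) -> tuple[str, list[str]]:
    """Hypothetical parser: 'mu[time, group]' -> ('mu', ['time', 'group'])."""
    match = _SPEC.match(spec)
    if match is None:
        raise ValueError(f"Cannot parse variable spec {spec!r}")
    name, dims = match.group(1), match.group(2)
    if not dims:
        # Plain name such as "mu", or an empty bracket "mu[]".
        return name, []
    return name, [d.strip() for d in dims.split(",") if d.strip()]


print(parse_var_spec("mu[time, group]"))  # ('mu', ['time', 'group'])
```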
Returns
pl.DataFrame
Tidy DataFrame with columns chain, draw, [join_on cols], [covariate cols], var_name. One row per (chain, draw, join_on key), i.e. per (chain, draw, obs_ind) with the default join_on.
Examples
Basic usage — newdata read from dt
pred_df = prediction_draws(dt, newdata=None, var_name="mu")
# -> columns: chain, draw, obs_ind, x, group, mu
# -> 4 x 1000 x 80 = 320,000 rows
Filter before plotting — only group 0
import polars as pl
pred_df.filter(pl.col("group") == 0)
# -> 4 x 1000 x 20 = 80,000 rows
Provide custom newdata (e.g., a finer grid)
import numpy as np
import polars as pl
fine_grid = pl.DataFrame({"x": np.linspace(0, 20, 200), "group": …})
prediction_draws(dt, newdata=fine_grid, var_name="mu")