prediction_draws()

Join posterior predictive draws to a covariate DataFrame.

Usage

prediction_draws(
    dt,
    newdata,
    var_name,
    idata_group="predictions",
    constant_data_group="predictions_constant_data",
    join_on="obs_ind"
)

The newdata frame is the left table, and the draws are attached to it. Joining in this direction avoids the denormalisation problem in which coefficient values are duplicated for every obs_ind.

Parameters

dt: xr.DataTree

ArviZ DataTree object (xarray.DataTree) containing prediction samples.

newdata: pl.DataFrame | pd.DataFrame | None

Covariate grid. If None, the grid is read from dt[constant_data_group]. If that group is not found, a clear error is raised directing the user to pass newdata explicitly.

var_name: str

Name of the predictive variable to extract (e.g., “mu”). Supports nested specifications like “mu[time, group]” if the variable has multiple dimensions.
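This page does not specify how a nested specification is parsed. The sketch below assumes a simple `name[dim1, dim2]` grammar. `parse_var_spec` is a hypothetical helper that splits the specification into a variable name and a list of dimension names.

```python
import re

# Assumed grammar: an identifier, optionally followed by [dim, dim, ...].
_SPEC = re.compile(r"^\s*(\w+)\s*(?:\[\s*([^\]]*)\])?\s*$")


def parse_var_spec(spec):
    """Hypothetical helper: 'mu[time, group]' -> ('mu', ['time', 'group'])."""
    match = _SPEC.match(spec)
    if match is None:
        raise ValueError(f"Cannot parse variable spec {spec!r}")
    name, dims = match.group(1), match.group(2)
    # Split on commas and drop surrounding whitespace and empty entries.
    dim_list = [d.strip() for d in dims.split(",") if d.strip()] if dims else []
    return name, dim_list
```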

idata_group: str = "predictions"

InferenceData group containing the predictive draws (“predictions”, “posterior_predictive”, or a custom group name). Default “predictions”.

constant_data_group: str = "predictions_constant_data"

InferenceData group name for the covariate grid that aligns with the prediction draws. Default “predictions_constant_data” (ArviZ 1.x convention for the constant data paired with a “predictions” group). Set this parameter if your DataTree uses a different naming convention.

join_on: str | list[str] = "obs_ind"

Column(s) to join newdata to the draws on. Default “obs_ind”.

Returns

pl.DataFrame

Tidy DataFrame with columns chain, draw, [join_on cols], [covariate cols], and var_name. One row per (chain, draw, join_on key), i.e. one row per (chain, draw, obs_ind) with the default join_on.

Examples

Basic usage: newdata read from dt

import numpy as np
import polars as pl

pred_df = prediction_draws(dt, newdata=None, var_name="mu")
# -> columns: chain, draw, obs_ind, x, group, mu
# -> 4 x 1000 x 80 = 320,000 rows

Filter before plotting: only group 0

pred_df.filter(pl.col("group") == 0)
# -> 4 x 1000 x 20 = 80,000 rows

Provide custom newdata (e.g., a finer grid)

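fine_grid = pl.DataFrame({"x": np.linspace(0, 20, 200), "group": …})
# With the default join_on="obs_ind", fine_grid must also carry an obs_ind
# column that matches the obs_ind values of the draws in dt.
prediction_draws(dt, newdata=fine_grid, var_name="mu")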