----------------------------------------------------------------------
This is the API documentation for the tidydraws library.
----------------------------------------------------------------------

## Functions

Extract and join tidy draws from ArviZ DataTrees.

parameter_draws(dt: xarray.core.datatree.DataTree, *var_specs: str, group: str = 'posterior', chain_dim: str = 'chain', draw_dim: str = 'draw') -> polars.dataframe.frame.DataFrame

Extract draws for one or more variables (from the posterior by default) into a tidy Polars DataFrame.

Parameters
----------
dt : xr.DataTree
    ArviZ DataTree object (xarray.DataTree) from PyMC sampling.
*var_specs : str
    Variable specifications in the form "var_name" for scalar variables,
    or "var_name[dim1, dim2, ...]" for array variables. Multi-dimensional
    specifications such as "gamma[time, group]" are supported. The bracketed
    names must match the variable's dimension (coordinate) names in the
    selected group.
group : str
    Which DataTree group to extract from (e.g., "posterior", "prior").
    Default "posterior".
chain_dim, draw_dim : str
    Names of the chain and draw dimensions. Defaults "chain" and "draw".

Returns
-------
pl.DataFrame
    Tidy DataFrame with columns: chain, draw, [named dims...], [var_names...]
    One row per unique (chain, draw, [dim combo]).
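The Returns contract above (one row per unique (chain, draw, [dim combo]), with scalar variables repeated across named dims) can be illustrated with a small NumPy sketch. It is not tidydraws' implementation: `parse_var_spec` and `tidy_draws` are hypothetical helpers, the draws are plain arrays in a dict rather than a DataTree, each array is assumed to have its axes ordered (chain, draw, *dims as written in the spec), and the result is a dict of columns instead of a Polars DataFrame.

```python
import re

import numpy as np

# Hypothetical parser for "var_name" / "var_name[dim1, dim2, ...]" specs.
_SPEC = re.compile(r"^\s*(\w+)\s*(?:\[([^\]]*)\])?\s*$")


def parse_var_spec(spec):
    """Return (name, [dims]) for a spec such as "gamma[time, group]"."""
    match = _SPEC.match(spec)
    if match is None:
        raise ValueError(f"invalid variable spec: {spec!r}")
    name, inner = match.group(1), match.group(2) or ""
    return name, [d.strip() for d in inner.split(",") if d.strip()]


def tidy_draws(arrays, coords, *specs):
    """Hypothetical sketch: long-format columns, one row per (chain, draw, dim combo).

    Assumes every array in `arrays` has axes (chain, draw, *dims in spec order).
    """
    parsed = [parse_var_spec(s) for s in specs]
    # Row grid = chain x draw x the union of named dims (in order of first use).
    grid_dims = []
    for _, dims in parsed:
        for d in dims:
            if d not in grid_dims:
                grid_dims.append(d)
    n_chain, n_draw = arrays[parsed[0][0]].shape[:2]
    shape = (n_chain, n_draw) + tuple(len(coords[d]) for d in grid_dims)
    # Integer index of every row along every grid axis (C order, last dim fastest).
    idx = np.indices(shape).reshape(len(shape), -1)
    columns = {"chain": idx[0], "draw": idx[1]}
    for axis, d in enumerate(grid_dims, start=2):
        columns[d] = np.asarray(coords[d])[idx[axis]]
    for name, dims in parsed:
        # Index each variable only by its own dims; it repeats across the others.
        key = (idx[0], idx[1]) + tuple(idx[2 + grid_dims.index(d)] for d in dims)
        columns[name] = arrays[name][key]
    return columns


# Toy draws: 4 chains x 1000 draws, beta over 4 groups, scalar sigma.
rng = np.random.default_rng(0)
arrays = {
    "beta": rng.normal(size=(4, 1000, 4)),
    "sigma": rng.exponential(size=(4, 1000)),
}
coords = {"groups": ["g0", "g1", "g2", "g3"]}

tidy = tidy_draws(arrays, coords, "beta[groups]", "sigma")
print(len(tidy["chain"]))  # -> 16000
```

Indexing each variable by its own dims alone is what makes a scalar such as sigma repeat once per group rather than once per observation.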
Examples
--------
Row counts assume 4 chains x 1,000 draws and 4 groups.

    # Scalar parameter (no duplication)
    parameter_draws(dt, "sigma")
    # -> columns: chain, draw, sigma
    # -> 4 x 1000 = 4,000 rows

    # Array parameter spread over a named dim
    parameter_draws(dt, "beta[groups]", "intercept[groups]")
    # -> columns: chain, draw, groups, beta, intercept
    # -> 4 x 1000 x 4 = 16,000 rows (NOT 320,000, i.e. values are not repeated for every obs_ind)

    # Mix of scalar and array (sigma broadcast-joined to group-level params)
    parameter_draws(dt, "beta[groups]", "sigma")
    # -> columns: chain, draw, groups, beta, sigma
    # -> 4 x 1000 x 4 = 16,000 rows; sigma is repeated for each group (explicit and expected)

    # A different DataTree group (prior instead of posterior)
    parameter_draws(dt, "beta[groups]", group="prior")
    # -> prior draws for beta

    # Multiple dimensions
    parameter_draws(dt, "gamma[time, group]")
    # -> columns: chain, draw, time, group, gamma

prediction_draws(dt: xarray.core.datatree.DataTree, newdata, var_name, idata_group='predictions', constant_data_group='predictions_constant_data', join_on='obs_ind') -> polars.dataframe.frame.DataFrame

Join posterior predictive draws to a covariate DataFrame.

The newdata frame is the *left* table and the draws are attached to it, so
each covariate row carries its own draws. This avoids the denormalisation
problem in which coefficient values are duplicated for every obs_ind.

Parameters
----------
dt : xr.DataTree
    ArviZ DataTree object (xarray.DataTree) containing prediction samples.
newdata : pl.DataFrame | pd.DataFrame | None
    Covariate grid. If None, it is read from dt[constant_data_group]; if that
    group is not found, an error is raised asking the user to pass newdata
    explicitly.
var_name : str
    Name of the predictive variable to extract (e.g., "mu"). Specifications
    such as "mu[time, group]" are supported when the variable has several
    dimensions.
idata_group : str
    DataTree group containing the predictive draws ("predictions",
    "posterior_predictive", or a custom name). Default "predictions".
constant_data_group : str
    DataTree group holding the covariate grid that aligns with the prediction
    draws. Default "predictions_constant_data" (the ArviZ group name for the
    constant data paired with a "predictions" group). Set this parameter if
    your DataTree uses a different naming convention.
join_on : str | list[str]
    Column(s) on which newdata is joined to the draws. Default "obs_ind".

Returns
-------
pl.DataFrame
    Tidy DataFrame with columns: chain, draw, [join_on cols], [covariate cols], var_name
    One row per (chain, draw, join_on key); with the default join_on, one row
    per (chain, draw, obs_ind).

Examples
--------
Row counts assume 4 chains x 1,000 draws and 80 observations, 20 of them in group 0.

    import numpy as np
    import polars as pl

    # Basic usage: newdata is read from dt
    pred_df = prediction_draws(dt, newdata=None, var_name="mu")
    # -> columns: chain, draw, obs_ind, x, group, mu
    # -> 4 x 1000 x 80 = 320,000 rows

    # Filter before plotting: only group 0
    pred_df.filter(pl.col("group") == 0)
    # -> 4 x 1000 x 20 = 80,000 rows

    # Provide custom newdata (e.g., a finer grid); its rows are matched to the
    # draws via join_on (default "obs_ind")
    fine_grid = pl.DataFrame({"x": np.linspace(0, 20, 200), "group": ...})
    prediction_draws(dt, newdata=fine_grid, var_name="mu")

compare_draws(dt: xarray.core.datatree.DataTree, *var_specs: str, groups: list[str] = ['posterior', 'prior'], group_name: str = 'source') -> polars.dataframe.frame.DataFrame

Extract and stack draws from multiple groups (e.g., posterior and prior).

Calls parameter_draws() for each group, adds a column identifying the
source group, and concatenates the results into a single DataFrame for
easy comparison.

Parameters
----------
dt : xr.DataTree
    ArviZ DataTree object.
*var_specs : str
    Variable specifications (as for parameter_draws()).
groups : list[str]
    Which groups to extract and stack. Default ["posterior", "prior"].
group_name : str
    Name of the column identifying the source group. Default "source".

Returns
-------
pl.DataFrame
    Stacked draws with an additional column (group_name) indicating the source group.
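compare_draws() is described as calling parameter_draws() once per group, adding a column that names the source group, and concatenating the results. A minimal plain-Python sketch of that stacking step follows; `stack_groups` is a hypothetical helper that works on dicts of column lists rather than Polars DataFrames (with Polars, the analogous step would be a `pl.concat` over frames that each carry a literal source column), and the draw values are toy data.

```python
def stack_groups(per_group, group_name="source"):
    """Concatenate {group: {column: list}} into one column dict with a source column.

    Hypothetical sketch of the stacking done by compare_draws(); every group must
    provide the same columns so that rows stay aligned.
    """
    stacked = {}
    expected = None
    for group, columns in per_group.items():
        if expected is None:
            expected = set(columns)
        elif set(columns) != expected:
            raise ValueError(f"group {group!r} has different columns")
        if group_name in columns:
            raise ValueError(f"column {group_name!r} already exists")
        n_rows = len(next(iter(columns.values())))
        for name, values in columns.items():
            stacked.setdefault(name, []).extend(values)
        # Tag each row from this group with the group it came from.
        stacked.setdefault(group_name, []).extend([group] * n_rows)
    return stacked


# Toy per-group draws for one parameter (2 posterior draws, 3 prior draws).
posterior = {"chain": [0, 0], "draw": [0, 1], "beta": [0.4, 0.5]}
prior = {"chain": [0, 0, 0], "draw": [0, 1, 2], "beta": [-1.2, 0.3, 2.1]}

compare = stack_groups({"posterior": posterior, "prior": prior})
print(compare["source"])
# ['posterior', 'posterior', 'prior', 'prior', 'prior']
```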
Examples
--------

    # Extract posterior and prior for side-by-side forest plots
    compare_df = compare_draws(dt, "beta[groups]", groups=["posterior", "prior"])
    # -> columns: chain, draw, groups, beta, source
    # -> source in {"posterior", "prior"}
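The join direction described for prediction_draws() (newdata as the left table, with draws attached to each of its rows) can be sketched with NumPy. This is not the library's code: `attach_draws` is a hypothetical helper, the draws are a plain array with axes (chain, draw, obs) whose obs axis is labelled by `obs_coord`, newdata is a dict of equal-length columns, and, unlike a true left join, a newdata key with no draws raises a KeyError instead of producing nulls.

```python
import numpy as np


def attach_draws(newdata, draws, obs_coord, var_name="mu", join_on="obs_ind"):
    """Hypothetical sketch: one row per (chain, draw, newdata row).

    `draws` has axes (chain, draw, obs); `obs_coord[i]` is the join_on key of obs i.
    """
    n_chain, n_draw, _ = draws.shape
    position = {key: i for i, key in enumerate(obs_coord)}
    # Map each newdata row to its obs position in the draws (KeyError if absent).
    obs_pos = np.array([position[key] for key in newdata[join_on]], dtype=np.intp)
    chain, draw, row = np.indices((n_chain, n_draw, len(obs_pos))).reshape(3, -1)
    out = {"chain": chain, "draw": draw}
    for name, values in newdata.items():
        # Covariates are repeated once per (chain, draw); nothing else is duplicated.
        out[name] = np.asarray(values)[row]
    out[var_name] = draws[chain, draw, obs_pos[row]]
    return out


# Toy example: 2 chains x 5 draws of mu over 3 observations.
rng = np.random.default_rng(1)
mu = rng.normal(size=(2, 5, 3))
newdata = {"obs_ind": [0, 1, 2], "x": [0.0, 0.5, 1.0], "group": [0, 0, 1]}

pred = attach_draws(newdata, mu, obs_coord=[0, 1, 2])
print(len(pred["mu"]))  # -> 30, i.e. 2 x 5 x 3
print(int(np.sum(pred["group"] == 0)))  # -> 20, i.e. 2 x 5 x 2
```

Matching rows by key value rather than by position means newdata can list observations in any order, as long as every key has draws.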