arviz_stats.loo_expectations
- arviz_stats.loo_expectations(data, var_name=None, group='posterior_predictive', sample_dims=None, log_likelihood_var_name=None, kind='mean', probs=None, log_weights=None, pareto_k=None)
Compute weighted expectations using the PSIS-LOO-CV method.
For each observation \(i\), the function approximates
\[\mathbb{E}_{p(\theta \mid y_{-i})}[g(\theta)] \approx \sum_s w_i^s \, g(\theta^s),\]
where \(w_i^s\) are PSIS-smoothed importance weights and \(g(\theta^s)\) is any scalar quantity associated with draw \(\theta^s\).
If \(g(\theta^s)\) corresponds to posterior predictive samples \(y_i^s \sim p(y_i \mid \theta^s)\), the result is the LOO prediction for observation \(i\). If it corresponds to posterior parameters or derived quantities, the result is the expectation of that quantity under the LOO posterior \(p(\theta \mid y_{-i})\).
These expectations are reliable only when the PSIS approximation works well, which the returned Pareto k-hat diagnostics help assess. The PSIS-LOO-CV method is described in [1] and [2].
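As a rough illustration of the weighted sum above, the following minimal sketch computes the estimate for a single observation from log weights and per-draw values of \(g\). It is not the arviz_stats implementation: the helper name `loo_weighted_mean` and the numbers are hypothetical, the log weights stand in for PSIS-smoothed values, and the sketch assumes the weights are normalized to sum to one.

```python
import math


def loo_weighted_mean(log_weights, g_values):
    """Self-normalized weighted mean, sum_s w^s * g(theta^s), for one observation."""
    # Normalize in log space (log-sum-exp) so the weights sum to 1 without overflow.
    max_lw = max(log_weights)
    log_norm = max_lw + math.log(sum(math.exp(lw - max_lw) for lw in log_weights))
    weights = [math.exp(lw - log_norm) for lw in log_weights]
    return sum(w * g for w, g in zip(weights, g_values))


# Hypothetical values for a single observation i with S = 4 draws.
log_w = [-1.2, -0.7, -2.3, -1.0]  # stand-in for PSIS-smoothed log weights
g = [0.5, 1.5, -0.3, 0.9]         # stand-in for g(theta^s)
estimate = loo_weighted_mean(log_w, g)
print(estimate)
```

With equal log weights the estimate reduces to the ordinary mean over draws; unequal weights shift it toward the draws that are more plausible once observation \(i\) is left out.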
- Parameters:
- data: DataTree or InferenceData
It should contain the selected group and log_likelihood.
- var_name: str, optional
The name of the variable to compute the expectations for.
- group: str
Group from which to compute weighted expectations. Defaults to posterior_predictive.
- sample_dims: str or sequence of hashable, optional
Defaults to rcParams["data.sample_dims"].
- log_likelihood_var_name: str, optional
The name of the variable in the log_likelihood group to use for loo computation. When log_likelihood contains more than one variable and group is posterior, this must be provided.
- kind: str, optional
The kind of expectation to compute. Available options are:
‘mean’. Default.
‘median’.
‘var’.
‘sd’.
‘quantile’.
‘circular_mean’.
‘circular_var’.
‘circular_sd’.
- probs: float or list of float, optional
The quantile(s) to compute when kind is ‘quantile’.
- log_weights: xarray.DataArray, optional
Pre-computed smoothed log weights from PSIS. Must be provided together with pareto_k. If not provided, PSIS will be computed internally.
- pareto_k: xarray.DataArray, optional
Pre-computed Pareto k-hat diagnostic values. Must be provided together with log_weights.
- Returns:
- loo_expec: xarray.DataArray or xarray.Dataset
The LOO-weighted expectations, one value per observation.
- khat: xarray.DataArray or xarray.Dataset
Function-specific Pareto k-hat diagnostics for each observation.
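One way to act on the returned k-hat values is to flag observations whose diagnostic exceeds a reliability threshold. The sketch below uses the sample-size-dependent rule \(\min(1 - 1/\log_{10} S, 0.7)\) recommended in [2]; whether arviz_stats applies exactly this rule to the function-specific diagnostics is an assumption, and the helper names and k-hat values are hypothetical.

```python
import math


def khat_threshold(n_draws):
    """Threshold min(1 - 1/log10(S), 0.7) for S posterior draws, following [2]."""
    return min(1 - 1 / math.log10(n_draws), 0.7)


def flag_unreliable(khat_values, n_draws):
    """Return the indices of observations whose k-hat exceeds the threshold."""
    threshold = khat_threshold(n_draws)
    return [i for i, k in enumerate(khat_values) if k > threshold]


# Hypothetical k-hat values for four observations, with S = 4000 draws.
khat_values = [0.04, 0.26, 0.85, -0.01]
flagged = flag_unreliable(khat_values, n_draws=4000)
print(flagged)
```

With the `xarray.DataArray` returned by `loo_expectations`, the same comparison can be written elementwise as `khat > threshold`, which keeps the observation coordinates attached to the result.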
References
[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544
[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646
Examples
Calculate predictive 0.25 and 0.75 quantiles and the function-specific Pareto k-hat diagnostics:
In [1]: from arviz_stats import loo_expectations
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: loo_expec, khat = loo_expectations(dt, kind="quantile", probs=[0.25, 0.75])
   ...: loo_expec
   ...: 
Out[1]: 
<xarray.DataArray 'y' (quantile: 2, obs_id: 919)> Size: 15kB
array([[-0.21290305,  0.51392885,  0.50130127, ...,  1.00688738,
         1.19475757,  1.22002843],
       [ 0.82123138,  1.49772647,  1.49586019, ...,  1.96598198,
         2.22126888,  2.23358734]], shape=(2, 919))
Coordinates:
  * quantile  (quantile) float64 16B 0.25 0.75
  * obs_id    (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
In [2]: khat
Out[2]: 
<xarray.DataArray 'y' (obs_id: 919)> Size: 7kB
array([ 4.46700951e-02,  2.63854121e-01, -1.07236853e-02,  3.67186708e-01,
        1.42367028e-01,  9.03724670e-02, -9.64916708e-02, -3.23023070e-02,
       -7.63117649e-04, -4.39632041e-02, -4.39632041e-02, -2.77011228e-02,
       -6.87404075e-02,  7.33058816e-03, -4.68570345e-02,  1.45208213e-01,
       -2.77011228e-02, -5.06999352e-02,  1.37056713e-01,  1.39714839e-01,
       -4.65341418e-02,  1.51855305e-01,  1.94804519e-01, -5.13443296e-02,
        2.47162116e-01,  1.42367028e-01,  2.31482545e-01,  2.44197052e-01,
        2.30050539e-01,  1.19981944e-01,  1.45208213e-01,  1.40019803e-01,
        9.59767451e-02,  1.93899206e-01, -4.91620621e-02,  1.27165696e-01,
        1.51855305e-01,  1.29457903e-01, -7.32512043e-02,  1.15699421e-01,
       -6.87404075e-02,  1.39714839e-01,  1.15699421e-01,  2.18827244e-01,
        1.18644098e-01,  7.32251179e-02, -1.97781752e-02, -7.40230544e-02,
        1.33748127e-02, -4.77234266e-02,  1.15699421e-01, -7.40230544e-02,
       -4.77234266e-02, -6.87404075e-02, -5.13443296e-02, -4.65341418e-02,
        1.54028621e-02,  3.24915289e-01,  2.24899198e-01, -2.68898698e-02,
       -2.69178091e-02, -3.40208521e-02,  1.08271366e-01,  3.75426166e-01,
        3.86313391e-02,  5.97667243e-02,  2.09480407e-01,  1.21369876e-01,
        2.44276637e-02,  1.48939483e-01,  1.71784207e-01,  1.96219227e-01,
        2.61662899e-01,  1.05695073e-01, -7.53712062e-02, -1.75064052e-01,
       -6.89317647e-02, -8.41709220e-02, -2.77060430e-03, -1.48189148e-01,
       ...
        2.06512895e-01,  2.12507258e-01,  1.40228769e-01,  8.34683490e-03,
        1.16030234e-01,  2.30149220e-01,  2.22596841e-02,  9.44759706e-02,
        1.16030234e-01,  1.80182948e-02,  1.45663627e-02, -1.00858590e-01,
        2.75357811e-03,  7.84219623e-02, -1.00858590e-01,  3.23275565e-02,
        1.24061678e-01,  1.26321601e-02,  3.61749643e-02, -9.25393105e-02,
        5.15885509e-03, -6.05826532e-02,  2.08509036e-01,  6.36732947e-02,
        1.68220083e-01,  1.45663627e-02,  1.24409684e-01, -8.44404703e-02,
        7.33142436e-02,  7.17368725e-02, -6.09886444e-02,  2.09901251e-01,
        2.75357811e-03,  2.66859194e-03,  3.61749643e-02, -1.02255701e-01,
        2.05758883e-01,  7.93527385e-02, -1.14139620e-02, -8.44404703e-02,
        3.99240568e-03,  2.90854855e-02,  3.99240568e-03,  2.96225227e-02,
        1.24409684e-01,  1.82664183e-01,  8.13822914e-02,  1.98159623e-01,
        1.93679498e-01,  2.19122933e-01,  2.41194751e-01,  5.92631532e-02,
        1.10015080e-01,  1.23326731e-02,  8.62210686e-02,  1.25478290e-01,
        1.20398193e-01,  2.25184849e-01,  1.25478290e-01,  2.53770715e-02,
        1.83820998e-02,  7.13161230e-02,  6.05499641e-03,  2.68097760e-01,
        1.69499433e-01,  6.25661205e-02,  1.53513527e-01,  2.24842604e-01,
        8.67214216e-02,  2.82213136e-02,  1.99025978e-01,  1.60546472e-01,
        3.28540934e-02,  3.41122021e-02,  1.49708207e-01,  1.52697372e-01,
        2.55589210e-03,  3.53607817e-02,  5.51508714e-02])
Coordinates:
  * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
Compute LOO posterior mean for the parameter mu:
In [3]: dt = load_arviz_data("centered_eight")
   ...: loo_expec, khat = loo_expectations(
   ...:     dt, group="posterior", var_name="mu")
   ...: loo_expec
   ...: 
Out[3]: 
<xarray.DataArray 'mu' (school: 8)> Size: 64B
array([3.18776346, 3.77895567, 4.43554356, 3.92413193, 4.80000974,
       4.30116324, 2.93111653, 3.90852475])
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'