arviz_stats.loo_expectations



arviz_stats.loo_expectations(data, var_name=None, group='posterior_predictive', sample_dims=None, log_likelihood_var_name=None, kind='mean', probs=None, log_weights=None, pareto_k=None)

Compute weighted expectations using the PSIS-LOO-CV method.

For each observation \(i\), approximates

\[\mathbb{E}_{p(\theta \mid y_{-i})}[g(\theta)] \approx \sum_s w_i^s \, g(\theta^s),\]

where \(w_i^s\) are the PSIS-smoothed importance weights for observation \(i\), normalized so that \(\sum_s w_i^s = 1\), and \(g(\theta^s)\) is any scalar quantity associated with draw \(\theta^s\).

If \(g(\theta^s)\) corresponds to posterior predictive samples \(y_i^s \sim p(y_i \mid \theta^s)\), the result is the LOO prediction for observation \(i\). If it corresponds to posterior parameters or derived quantities, the result is the expectation of that quantity under the LOO posterior \(p(\theta \mid y_{-i})\).
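To make the weighted sum concrete, here is a minimal NumPy sketch of the same estimator using plain (unsmoothed) importance sampling: the raw LOO ratios are \(r_i^s = 1/p(y_i \mid \theta^s)\), self-normalized over the draws. The Pareto smoothing step that PSIS adds is deliberately omitted, and the function name `is_loo_expectation` and its (draws × observations) array layout are assumptions for illustration, not part of arviz_stats.

```python
import numpy as np


def is_loo_expectation(log_lik, g):
    """Plain importance-sampling LOO expectation for every observation.

    log_lik: array (S, N), log p(y_i | theta^s) for S draws and N observations.
    g: array (S, N) (e.g. predictive draws y_i^s) or (S,) (e.g. a parameter).
    Returns an array of length N.
    NOTE: hypothetical helper; no Pareto smoothing is applied.
    """
    log_lik = np.asarray(log_lik, dtype=float)
    # Raw LOO importance ratios r_i^s = 1 / p(y_i | theta^s), on the log scale
    log_ratios = -log_lik
    # Self-normalize over draws with a numerically stable log-sum-exp
    max_lr = log_ratios.max(axis=0, keepdims=True)
    log_norm = max_lr + np.log(np.exp(log_ratios - max_lr).sum(axis=0, keepdims=True))
    weights = np.exp(log_ratios - log_norm)  # each column sums to 1
    g = np.asarray(g, dtype=float)
    if g.ndim == 1:
        g = g[:, None]  # same scalar per draw, shared by all observations
    return (weights * g).sum(axis=0)


# Toy example: 3 draws, 2 observations
log_lik = np.log(np.array([[0.5, 1.0], [0.5, 1.0], [0.25, 1.0]]))
g = np.array([1.0, 2.0, 3.0])
loo_mean = is_loo_expectation(log_lik, g)
```

Here `loo_mean` is close to `[2.25, 2.0]`: the third draw fits the first observation worst, so it receives half of that observation's weight, while the second observation has the same likelihood under every draw and its weights stay uniform.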

These expectations are reliable only when the PSIS approximation works well; the returned Pareto k-hat values diagnose this for each observation. The PSIS-LOO-CV method is described in [1] and [2].

Parameters:
data: DataTree or InferenceData

It should contain the selected group and the log_likelihood group.

var_name: str, optional

The name of the variable to compute the expectations for.

group: str

Group from which to compute weighted expectations. Defaults to posterior_predictive.

sample_dims: str or sequence of hashable, optional

Defaults to rcParams["data.sample_dims"].

log_likelihood_var_name: str, optional

The name of the variable in the log_likelihood group to use for the LOO computation. It must be provided when the log_likelihood group contains more than one variable and group is posterior.

kind: str, optional

The kind of expectation to compute. Available options are:

  • ‘mean’. Default.

  • ‘median’.

  • ‘var’.

  • ‘sd’.

  • ‘quantile’.

  • ‘circular_mean’.

  • ‘circular_var’.

  • ‘circular_sd’.

probs: float or list of float, optional

The quantile(s) to compute when kind is ‘quantile’.
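For kind='quantile', each probability in probs is mapped to a quantile of the weighted draws of one observation. One simple rule, sketched below as an assumption rather than the arviz_stats implementation, inverts the weighted empirical CDF without interpolation; under this rule a weighted median is the quantile at 0.5. The function `weighted_quantile` is hypothetical.

```python
import numpy as np


def weighted_quantile(x, w, probs):
    """Inverse of the weighted empirical CDF (no interpolation).

    Hypothetical helper; probs may be a float or a list of floats.
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    order = np.argsort(x)
    x_sorted = x[order]
    cum_w = np.cumsum(w[order]) / w.sum()  # weighted ECDF at each sorted draw
    # Index of the first sorted draw whose cumulative weight reaches each probability
    idx = np.searchsorted(cum_w, np.atleast_1d(probs), side="left")
    # Guard against rounding leaving the last cumulative weight slightly below 1
    return x_sorted[np.minimum(idx, x.size - 1)]
```

For draws [3, 1, 2] with weights [0.5, 0.25, 0.25], the probabilities [0.25, 0.5, 0.75] map to 1, 2 and 3.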

log_weights: xarray.DataArray, optional

Pre-computed smoothed log weights from PSIS. Must be provided together with pareto_k. If not provided, PSIS will be computed internally.

pareto_k: xarray.DataArray, optional

Pre-computed Pareto k-hat diagnostic values. Must be provided together with log_weights.
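The pairing rule for log_weights and pareto_k can be illustrated with a hypothetical helper, `resolve_log_weights`, that reuses precomputed values when both are given, rejects a lone argument, and otherwise computes weights itself. The fallback below uses unsmoothed importance ratios as a stand-in for PSIS and returns NaN for k-hat, so it is not equivalent to what loo_expectations computes internally; the (draws × observations) layout is also an assumption.

```python
import numpy as np


def resolve_log_weights(log_lik, log_weights=None, pareto_k=None):
    """Hypothetical helper mirroring the documented pairing rule.

    Returns (log_weights, pareto_k) with shapes (S, N) and (N,).
    """
    if (log_weights is None) != (pareto_k is None):
        raise ValueError("log_weights and pareto_k must be provided together")
    if log_weights is not None:
        # Reuse the caller's precomputed values unchanged
        return np.asarray(log_weights, dtype=float), np.asarray(pareto_k, dtype=float)
    # Stand-in for PSIS: raw log ratios -log_lik, self-normalized per observation.
    # No Pareto smoothing is done, so k-hat is reported as NaN.
    log_ratios = -np.asarray(log_lik, dtype=float)
    max_lr = log_ratios.max(axis=0, keepdims=True)
    log_norm = max_lr + np.log(np.exp(log_ratios - max_lr).sum(axis=0, keepdims=True))
    return log_ratios - log_norm, np.full(log_ratios.shape[1], np.nan)
```

Computing the weights once and passing them back in lets several expectations (for example a mean and a set of quantiles) share one weight computation.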

Returns:
loo_expec: xarray.DataArray or xarray.Dataset

The LOO-weighted expectations, one value per observation.

khat: xarray.DataArray or xarray.Dataset

Function-specific Pareto k-hat diagnostics for each observation.

References

[1]

Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544.

[2]

Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646.

Examples

Calculate the LOO predictive 0.25 and 0.75 quantiles and the function-specific Pareto k-hat diagnostics:

In [1]: from arviz_stats import loo_expectations
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: loo_expec, khat = loo_expectations(dt, kind="quantile", probs=[0.25, 0.75])
   ...: loo_expec
   ...: 
Out[1]: 
<xarray.DataArray 'y' (quantile: 2, obs_id: 919)> Size: 15kB
array([[-0.21290305,  0.51392885,  0.50130127, ...,  1.00688738,
         1.19475757,  1.22002843],
       [ 0.82123138,  1.49772647,  1.49586019, ...,  1.96598198,
         2.22126888,  2.23358734]], shape=(2, 919))
Coordinates:
  * quantile  (quantile) float64 16B 0.25 0.75
  * obs_id    (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
In [2]: khat
Out[2]: 
<xarray.DataArray 'y' (obs_id: 919)> Size: 7kB
array([ 4.46700951e-02,  2.63854121e-01, -1.07236853e-02,  3.67186708e-01,
        1.42367028e-01,  9.03724670e-02, -9.64916708e-02, -3.23023070e-02,
       -7.63117649e-04, -4.39632041e-02, -4.39632041e-02, -2.77011228e-02,
       -6.87404075e-02,  7.33058816e-03, -4.68570345e-02,  1.45208213e-01,
       -2.77011228e-02, -5.06999352e-02,  1.37056713e-01,  1.39714839e-01,
       -4.65341418e-02,  1.51855305e-01,  1.94804519e-01, -5.13443296e-02,
        2.47162116e-01,  1.42367028e-01,  2.31482545e-01,  2.44197052e-01,
        2.30050539e-01,  1.19981944e-01,  1.45208213e-01,  1.40019803e-01,
        9.59767451e-02,  1.93899206e-01, -4.91620621e-02,  1.27165696e-01,
        1.51855305e-01,  1.29457903e-01, -7.32512043e-02,  1.15699421e-01,
       -6.87404075e-02,  1.39714839e-01,  1.15699421e-01,  2.18827244e-01,
        1.18644098e-01,  7.32251179e-02, -1.97781752e-02, -7.40230544e-02,
        1.33748127e-02, -4.77234266e-02,  1.15699421e-01, -7.40230544e-02,
       -4.77234266e-02, -6.87404075e-02, -5.13443296e-02, -4.65341418e-02,
        1.54028621e-02,  3.24915289e-01,  2.24899198e-01, -2.68898698e-02,
       -2.69178091e-02, -3.40208521e-02,  1.08271366e-01,  3.75426166e-01,
        3.86313391e-02,  5.97667243e-02,  2.09480407e-01,  1.21369876e-01,
        2.44276637e-02,  1.48939483e-01,  1.71784207e-01,  1.96219227e-01,
        2.61662899e-01,  1.05695073e-01, -7.53712062e-02, -1.75064052e-01,
       -6.89317647e-02, -8.41709220e-02, -2.77060430e-03, -1.48189148e-01,
...
        2.06512895e-01,  2.12507258e-01,  1.40228769e-01,  8.34683490e-03,
        1.16030234e-01,  2.30149220e-01,  2.22596841e-02,  9.44759706e-02,
        1.16030234e-01,  1.80182948e-02,  1.45663627e-02, -1.00858590e-01,
        2.75357811e-03,  7.84219623e-02, -1.00858590e-01,  3.23275565e-02,
        1.24061678e-01,  1.26321601e-02,  3.61749643e-02, -9.25393105e-02,
        5.15885509e-03, -6.05826532e-02,  2.08509036e-01,  6.36732947e-02,
        1.68220083e-01,  1.45663627e-02,  1.24409684e-01, -8.44404703e-02,
        7.33142436e-02,  7.17368725e-02, -6.09886444e-02,  2.09901251e-01,
        2.75357811e-03,  2.66859194e-03,  3.61749643e-02, -1.02255701e-01,
        2.05758883e-01,  7.93527385e-02, -1.14139620e-02, -8.44404703e-02,
        3.99240568e-03,  2.90854855e-02,  3.99240568e-03,  2.96225227e-02,
        1.24409684e-01,  1.82664183e-01,  8.13822914e-02,  1.98159623e-01,
        1.93679498e-01,  2.19122933e-01,  2.41194751e-01,  5.92631532e-02,
        1.10015080e-01,  1.23326731e-02,  8.62210686e-02,  1.25478290e-01,
        1.20398193e-01,  2.25184849e-01,  1.25478290e-01,  2.53770715e-02,
        1.83820998e-02,  7.13161230e-02,  6.05499641e-03,  2.68097760e-01,
        1.69499433e-01,  6.25661205e-02,  1.53513527e-01,  2.24842604e-01,
        8.67214216e-02,  2.82213136e-02,  1.99025978e-01,  1.60546472e-01,
        3.28540934e-02,  3.41122021e-02,  1.49708207e-01,  1.52697372e-01,
        2.55589210e-03,  3.53607817e-02,  5.51508714e-02])
Coordinates:
  * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
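A common next step is to flag observations whose k-hat is too large for the PSIS estimate to be trusted. The sketch below uses the sample-size-dependent rule of thumb \(\hat{k} > \min(1 - 1/\log_{10} S, 0.7)\) discussed in [2], where \(S\) is the total number of posterior draws (chains × draws). `flag_high_khat` is a hypothetical helper, and the threshold is stated as an assumption rather than as the rule arviz_stats applies.

```python
import numpy as np


def flag_high_khat(khat, n_draws):
    """Boolean mask of observations whose k-hat exceeds the PSIS threshold.

    Hypothetical helper; n_draws is the total number of posterior draws S.
    """
    # Sample-size-dependent threshold, capped at 0.7
    threshold = min(1.0 - 1.0 / np.log10(n_draws), 0.7)
    return np.asarray(khat, dtype=float) > threshold
```

With the khat array above, `flag_high_khat(khat.values, S)` would return a boolean mask aligned with obs_id, with S set to the number of posterior draws in dt.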

Compute the LOO posterior mean of the parameter mu, one value per left-out observation (here, per school):

In [3]: dt = load_arviz_data("centered_eight")
   ...: loo_expec, khat = loo_expectations(
   ...:     dt, group="posterior", var_name="mu")
   ...: loo_expec
   ...: 
Out[3]: 
<xarray.DataArray 'mu' (school: 8)> Size: 64B
array([3.18776346, 3.77895567, 4.43554356, 3.92413193, 4.80000974,
       4.30116324, 2.93111653, 3.90852475])
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'