"""NumPyro-specific conversion code."""
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
import numpy as np
from xarray import DataTree
from arviz_base.base import dict_to_dataset, requires
from arviz_base.rcparams import rc_context, rcParams
from arviz_base.utils import expand_dims
def _add_dims(dims_a, dims_b):
"""Merge two dimension mappings by concatenating dimension labels.
Used to combine batch dims with event dims by appending the dims of dims_b to dims_a.
Parameters
----------
dims_a : dict of {str: list of str(s)}
Mapping from site name to a list of dimension labels, typically
representing batch dimensions.
dims_b : dict of {str: list of str(s)}
Mapping from site name to a list of dimension labels, typically
representing event dimensions.
Returns
-------
dict of {str: list of str(s)}
Combined mapping where each site name is associated with the
concatenated dimension labels from both inputs.
"""
merged = defaultdict(list, dims_a)
for k, v in dims_b.items():
merged[k].extend(v)
# Convert back to a regular dict
return dict(merged)
def infer_dims(
    model,
    model_args=None,
    model_kwargs=None,
):
    """Infers batch dim names from numpyro model plates.

    Parameters
    ----------
    model : callable
        A numpyro model function.
    model_args : tuple of (Any, ...), optional
        Input args for the numpyro model.
    model_kwargs : dict of {str: Any}, optional
        Input kwargs for the numpyro model.

    Returns
    -------
    dict of {str: list of str(s)}
        Mapping from model site name to list of dimension labels.
    """
    import jax
    from numpyro import distributions as dist
    from numpyro import handlers
    from numpyro.infer.initialization import init_to_sample
    from numpyro.ops.pytree import PytreeTrace

    args = tuple() if model_args is None else model_args
    kwargs = dict() if model_kwargs is None else model_kwargs

    def _base_dist_name(fn):
        # Peel off wrapper distributions until the underlying one is reached.
        while isinstance(
            fn, dist.Independent | dist.ExpandedDistribution | dist.MaskedDistribution
        ):
            fn = fn.base_dist
        return type(fn).__name__

    def _traced():
        # `init_to_sample` works around ImproperUniform distributions,
        # which do not have a `sample` method.
        seeded = handlers.substitute(
            handlers.seed(model, 0),
            substitute_fn=init_to_sample,
        )
        tr = handlers.trace(seeded).get_trace(*args, **kwargs)
        # jax.eval_shape cannot handle distribution outputs
        # (e.g. the function `lambda: dist.Normal(0, 1)`), so drop each
        # site's `fn` and keep only its name in the trace.
        for site in tr.values():
            if site["type"] == "sample":
                site["fn_name"] = _base_dist_name(site.pop("fn"))
            elif site["type"] == "deterministic":
                site["fn_name"] = "Deterministic"
        return PytreeTrace(tr)

    # eval_shape runs the model abstractly so no real array computation happens.
    trace = jax.eval_shape(_traced).trace

    dims_map = {}
    for name, site in trace.items():
        # Plate (batch) dims first, ordered by their (negative) dim position.
        plate_dims = [frame.name for frame in sorted(site["cond_indep_stack"], key=lambda f: f.dim)]
        event_dims = list(site.get("infer", {}).get("event_dims", []))
        if site["type"] in ["sample", "deterministic"] and (plate_dims or event_dims):
            dims_map[name] = plate_dims + event_dims
    return dims_map
class BaseNumPyroConverter(ABC):
    """Base converter with sampler-agnostic logic."""

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        *,
        posterior=None,
        prior=None,
        posterior_predictive=None,
        predictions=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=False,
        index_origin=None,
        coords=None,
        dims=None,
        pred_dims=None,
        extra_event_dims=None,
    ):
        """Convert NumPyro data into an InferenceData object.

        Parameters
        ----------
        posterior : numpyro.infer.mcmc.MCMC | numpyro.infer.svi.SVI | object, optional
            Fitted MCMC or SVI posterior object from NumPyro
        prior : dict, optional
            Prior samples from a NumPyro model
        posterior_predictive : dict, optional
            Posterior predictive samples for the posterior
        predictions : dict, optional
            Out of sample predictions
        constant_data : dict, optional
            Dictionary containing constant data variables mapped to their values.
        predictions_constant_data : dict, optional
            Constant data used for out-of-sample predictions.
        log_likelihood : bool, default False
            Whether to compute pointwise log likelihood values for observed sites.
        index_origin : int, optional
        coords : dict, optional
            Map of dimensions to coordinates
        dims : dict of {str : list of str}, optional
            Map variable names to their coordinates. Will be inferred if they are not provided.
        pred_dims : dict, optional
            Dims for predictions data. Map variable names to their coordinates.
        extra_event_dims : dict, optional
            Maps event dims that couldn't be inferred (i.e. deterministic sites)
            to their coordinates.
        """
        import jax
        import numpyro

        self.posterior = posterior
        self.prior = jax.device_get(prior)
        self.posterior_predictive = jax.device_get(posterior_predictive)
        self.predictions = predictions
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.log_likelihood = log_likelihood
        self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin
        self.coords = coords
        self.dims = dims
        self.pred_dims = pred_dims
        self.extra_event_dims = extra_event_dims
        self.numpyro = numpyro
        # Subclass hooks: sample shape and training args must be available before
        # samples and dims can be extracted/inferred below.
        self.sample_shape = self._infer_sample_shape()
        self._args, self._kwargs = self._get_train_args_kwargs()
        if posterior is not None:
            self._samples = self._get_samples()
            self.dims = self.dims if self.dims is not None else self.infer_dims()
            self.pred_dims = (
                self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
            )
        observations = {}
        if self.model is not None:
            # Trace the model once to collect the observed sites.
            trace = self._get_model_trace(
                self.model,
                model_args=self._args,
                model_kwargs=self._kwargs,
                key=jax.random.PRNGKey(0),
            )
            observations = {
                name: site["value"]
                for name, site in trace.items()
                if site["type"] == "sample" and site["is_observed"]
            }
        self.observations = observations if observations else None

    @property
    @abstractmethod
    def model(self):
        """Return the internal model."""
        pass

    @abstractmethod
    def _infer_sample_shape(self):
        """Return the expected sample shape."""
        pass

    @abstractmethod
    def _get_train_args_kwargs(self):
        """Extract training metadata from posterior object.

        Should return a ``(model_args, model_kwargs)`` tuple used to call the model.
        """
        pass

    @abstractmethod
    def _get_samples(self):
        """Extract samples from the posterior object and return them as a dict."""
        pass

    def sample_stats_to_xarray(self):
        """Extract sampler-specific statistics.

        Returns
        -------
        xarray.Dataset | None
            Sample statistics dataset, or None if not available
        """
        return None

    def _get_model_trace(self, model, model_args, model_kwargs, key):
        """Extract the numpyro model trace."""
        model_args = model_args or tuple()
        model_kwargs = model_kwargs or dict()
        # we need to use an init strategy to generate random samples for ImproperUniform sites
        seeded_model = self.numpyro.handlers.substitute(
            self.numpyro.handlers.seed(model, key),
            substitute_fn=self.numpyro.infer.init_to_sample,
        )
        trace = self.numpyro.handlers.trace(seeded_model).get_trace(*model_args, **model_kwargs)
        return trace

    def _prepare_predictive_data(self, dct: dict) -> dict:
        """Prepare and reshape posterior_predictive/predictions data.

        Parameters
        ----------
        dct : dict
            Dictionary of arrays to prepare

        Returns
        -------
        dict
            Dictionary with properly shaped arrays for this sampler
        """
        expected_size = int(np.prod(self.sample_shape))  # flatten sample dimensions
        data = {}
        for k, ary in dct.items():
            shape = ary.shape
            if shape[: len(self.sample_shape)] == self.sample_shape:
                # Already in desired sample shape
                data[k] = ary
            elif shape[0] == expected_size:
                # Flattened sample dimension: reshape to sample_shape + remaining dims
                data[k] = ary.reshape(self.sample_shape + shape[1:])
            else:
                # Not compatible, expand dims and warn.
                # NOTE: the second string previously lacked the f-prefix, so the
                # sample shape was printed literally as "{self.sample_shape}".
                data[k] = np.expand_dims(ary, axis=0)
                warnings.warn(
                    f"posterior predictive shape {shape} not compatible with "
                    f"sample_shape {self.sample_shape}. "
                    "This can mean that some draws or even whole chains are not represented."
                )
        return data

    @requires("posterior")
    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset."""
        data = self._samples
        return dict_to_dataset(
            data,
            inference_library=self.numpyro,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )

    @requires("posterior")
    @requires("model")
    def log_likelihood_to_xarray(self):
        """Extract log likelihood from NumPyro posterior."""
        if not self.log_likelihood:
            return None
        data = {}
        if self.observations is not None:
            samples = self._get_samples()
            if hasattr(samples, "_asdict"):
                # e.g. namedtuple samples from a custom potential_fn
                samples = samples._asdict()
            log_likelihood_dict = self.numpyro.infer.log_likelihood(
                self.model, samples, *self._args, **self._kwargs
            )
            for obs_name, log_like in log_likelihood_dict.items():
                # Restore the (possibly multi-dim) sample shape in front of the
                # per-observation dims.
                shape = self.sample_shape + log_like.shape[1:]
                data[obs_name] = np.reshape(np.asarray(log_like), shape)
        return dict_to_dataset(
            data,
            inference_library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            index_origin=self.index_origin,
            skip_event_dims=True,
        )

    def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
        """Convert posterior_predictive or prediction samples to xarray."""
        data = self._prepare_predictive_data(dct)
        return dict_to_dataset(
            data,
            inference_library=self.numpyro,
            coords=self.coords,
            dims=dims,
            index_origin=self.index_origin,
        )

    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(
            self.posterior_predictive, self.dims
        )

    @requires("predictions")
    def predictions_to_xarray(self):
        """Convert predictions to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(self.predictions, self.pred_dims)

    def priors_to_xarray(self):
        """Convert prior samples (and if possible prior predictive too) to xarray."""
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}
        if self.posterior is not None:
            # Sites also present in the posterior are "prior"; the rest are
            # treated as prior predictive draws.
            prior_vars = list(self._samples.keys())
            prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
        else:
            prior_vars = self.prior.keys()
            prior_predictive_vars = None
        has_chains = len(self.sample_shape) > 1
        priors_dict = {
            group: (
                None
                if var_names is None
                else dict_to_dataset(
                    {
                        # Prior draws have no chain dim; add one when the
                        # sampler's sample shape includes chains.
                        k: expand_dims(self.prior[k]) if has_chains else self.prior[k]
                        for k in var_names
                    },
                    inference_library=self.numpyro,
                    coords=self.coords,
                    dims=self.dims,
                    index_origin=self.index_origin,
                )
            )
            for group, var_names in zip(
                ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
            )
        }
        return priors_dict

    @requires("observations")
    @requires("model")
    def observed_data_to_xarray(self):
        """Convert observed data to xarray."""
        return dict_to_dataset(
            self.observations,
            inference_library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            sample_dims=[],
            index_origin=self.index_origin,
        )

    @requires("constant_data")
    def constant_data_to_xarray(self):
        """Convert constant_data to xarray."""
        return dict_to_dataset(
            self.constant_data,
            inference_library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            sample_dims=[],
            index_origin=self.index_origin,
        )

    @requires("predictions_constant_data")
    def predictions_constant_data_to_xarray(self):
        """Convert predictions_constant_data to xarray."""
        return dict_to_dataset(
            self.predictions_constant_data,
            inference_library=self.numpyro,
            dims=self.pred_dims,
            coords=self.coords,
            sample_dims=[],
            index_origin=self.index_origin,
        )

    def to_datatree(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created (i.e., there is no `trace`, so
        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
        will not have those groups.
        """
        dicto = {
            "posterior": self.posterior_to_xarray(),
            "sample_stats": self.sample_stats_to_xarray(),
            "log_likelihood": self.log_likelihood_to_xarray(),
            "posterior_predictive": self.posterior_predictive_to_xarray(),
            "predictions": self.predictions_to_xarray(),
            **self.priors_to_xarray(),
            "observed_data": self.observed_data_to_xarray(),
            "constant_data": self.constant_data_to_xarray(),
            "predictions_constant_data": self.predictions_constant_data_to_xarray(),
        }
        return DataTree.from_dict({group: ds for group, ds in dicto.items() if ds is not None})

    @requires("posterior")
    @requires("model")
    def infer_dims(self) -> dict[str, list[str]]:
        """Infers dims for input data."""
        dims = infer_dims(self.model, self._args, self._kwargs)
        if self.extra_event_dims:
            dims = _add_dims(dims, self.extra_event_dims)
        return dims

    @requires("posterior")
    @requires("model")
    @requires("predictions")
    def infer_pred_dims(self) -> dict[str, list[str]]:
        """Infers dims for predictions data."""
        dims = infer_dims(self.model, self._args, self._kwargs)
        if self.extra_event_dims:
            dims = _add_dims(dims, self.extra_event_dims)
        return dims
class MCMCConverter(BaseNumPyroConverter):
    """Converter for numpyro MCMC inference results."""

    def __init__(
        self,
        *,
        posterior=None,
        prior=None,
        posterior_predictive=None,
        predictions=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=False,
        index_origin=None,
        coords=None,
        dims=None,
        pred_dims=None,
        extra_event_dims=None,
        num_chains=1,
    ):
        """Convert NumPyro data into an InferenceData object.

        Parameters
        ----------
        posterior : numpyro.mcmc.MCMC
            Fitted MCMC object from NumPyro
        prior : dict, optional
            Prior samples from a NumPyro model
        posterior_predictive : dict, optional
            Posterior predictive samples for the posterior
        predictions : dict, optional
            Out of sample predictions
        constant_data : dict, optional
            Dictionary containing constant data variables mapped to their values.
        predictions_constant_data : dict, optional
            Constant data used for out-of-sample predictions.
        log_likelihood : bool, default False
            Whether to compute pointwise log likelihood values for observed sites.
        index_origin : int, optional
        coords : dict, optional
            Map of dimensions to coordinates
        dims : dict of {str : list of str}, optional
            Map variable names to their coordinates. Will be inferred if they are not provided.
        pred_dims : dict, optional
            Dims for predictions data. Map variable names to their coordinates.
        extra_event_dims : dict, optional
            Maps event dims that couldn't be inferred (i.e. deterministic sites)
            to their coordinates.
        num_chains : int, optional
            Number of chains used for sampling. Ignored if posterior is present.
        """
        # Must be set before super().__init__, which calls _infer_sample_shape.
        self.nchains = num_chains
        super().__init__(
            posterior=posterior,
            prior=prior,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data=constant_data,
            predictions_constant_data=predictions_constant_data,
            log_likelihood=log_likelihood,
            index_origin=index_origin,
            coords=coords,
            dims=dims,
            pred_dims=pred_dims,
            extra_event_dims=extra_event_dims,
        )

    @property
    def model(self):
        """Return the internal model."""
        if self.posterior is None:
            return None
        return self.posterior.sampler.model

    def _infer_sample_shape(self):
        """Return the expected sample shape as ``(chains, draws)``."""
        if self.posterior is not None:
            return (
                self.posterior.num_chains,
                self.posterior.num_samples // self.posterior.thinning,
            )

        def arbitrary_element(dct):
            return next(iter(dct.values()))

        # No posterior: infer draws from whichever sample dict is available.
        get_from = None
        if self.predictions is not None:
            get_from = self.predictions
        elif self.posterior_predictive is not None:
            get_from = self.posterior_predictive
        elif self.prior is not None:
            get_from = self.prior
        if (
            get_from is None
            and self.constant_data is None
            and self.predictions_constant_data is None
        ):
            raise ValueError(
                "When constructing InferenceData must have at least"
                " one of posterior, prior, posterior_predictive or predictions."
            )
        if get_from is not None:
            aelem = arbitrary_element(get_from)
            self.ndraws = (
                aelem.shape[0] // self.nchains if self.nchains is not None else aelem.shape[0]
            )
        return (
            self.nchains,
            self.ndraws,
        )

    def _get_train_args_kwargs(self):
        """Return the ``(model_args, model_kwargs)`` the MCMC run was fitted with."""
        return (
            (self.posterior._args, self.posterior._kwargs)
            if self.posterior is not None
            else (tuple(), dict())
        )

    def _get_samples(self):
        """Extract samples from MCMC posterior."""
        import jax

        samples = jax.device_get(self.posterior.get_samples(group_by_chain=True))
        if hasattr(samples, "_asdict"):
            # In case it is easy to convert to a dictionary, as in the case of namedtuples
            samples = {k: expand_dims(v) for k, v in samples._asdict().items()}
        if not isinstance(samples, dict):
            # handle the case we run MCMC with a general potential_fn
            # (instead of a NumPyro model) whose args is not a dictionary
            # (e.g. f(x) = x ** 2)
            tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
            samples = {f"Param:{i}": jax.device_get(v) for i, v in enumerate(tree_flatten_samples)}
        return samples

    @requires("posterior")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from NumPyro MCMC posterior."""
        # Map NumPyro field names onto ArviZ's conventional stat names.
        rename_key = {
            "potential_energy": "lp",
            "adapt_state.step_size": "step_size",
            "num_steps": "n_steps",
            "accept_prob": "acceptance_rate",
        }
        data = {}
        for stat, value in self.posterior.get_extra_fields(group_by_chain=True).items():
            if isinstance(value, dict | tuple):
                # Nested fields (e.g. full adapt state) are not array stats; skip.
                continue
            name = rename_key.get(stat, stat)
            value_cp = value.copy()
            data[name] = value_cp
            if stat == "num_steps":
                # NUTS tree depth recovered from the number of leapfrog steps.
                data["tree_depth"] = np.log2(value_cp).astype(int) + 1
        return dict_to_dataset(
            data,
            inference_library=self.numpyro,
            dims=None,
            coords=self.coords,
            index_origin=self.index_origin,
        )
class SVIConverter(BaseNumPyroConverter):
    """Converter for SVI (Stochastic Variational Inference)."""
    def __init__(
        self,
        svi,
        *,
        svi_result,
        model_args=None,
        model_kwargs=None,
        prior=None,
        posterior_predictive=None,
        predictions=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=None,
        index_origin=None,
        coords=None,
        dims=None,
        pred_dims=None,
        extra_event_dims=None,
        num_samples=1000,
    ):
        """Initialize SVI converter.
        Parameters
        ----------
        svi : numpyro.infer.svi.SVI
            Numpyro SVI instance used for fitting the model.
        svi_result : numpyro.infer.svi.SVIRunResult
            SVI results from a fitted model.
        model_args : tuple, optional
            Model arguments, should match those used for fitting the model.
        model_kwargs : dict, optional
            Model keyword arguments, should match those used for fitting the model.
        prior : dict, optional
            Prior samples from a NumPyro model
        posterior_predictive : dict, optional
            Posterior predictive samples for the posterior
        predictions : dict, optional
            Out of sample predictions
        constant_data : dict, optional
            Dictionary containing constant data variables mapped to their values.
        predictions_constant_data : dict, optional
            Constant data used for out-of-sample predictions.
        index_origin : int, optional
        coords : dict, optional
            Map of dimensions to coordinates
        dims : dict of {str : list of str}, optional
            Map variable names to their coordinates. Will be inferred if they are not provided.
        pred_dims : dict, optional
            Dims for predictions data. Map variable names to their coordinates. Default behavior is
            to infer dims if this is not provided
        extra_event_dims : dict, optional
            Extra event dims for deterministic sites. Maps event dims that couldn't be inferred to
            their coordinates.
        num_samples : int, optional
            The number of posterior samples to use.
        """
        # NOTE: attribute assignment order matters here. super().__init__ calls
        # _get_train_args_kwargs and _infer_sample_shape, which read
        # self._args/self._kwargs/self.num_samples — set them all first.
        self.svi = svi
        self.svi_result = svi_result
        self._args = model_args or tuple()
        self._kwargs = model_kwargs or dict()
        self.num_samples = num_samples
        # Pass the SVI object itself as 'posterior' so the base class treats
        # the fit as an available posterior (triggers sample/dims extraction).
        super().__init__(
            posterior=svi,
            prior=prior,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data=constant_data,
            predictions_constant_data=predictions_constant_data,
            log_likelihood=log_likelihood,
            index_origin=index_origin,
            coords=coords,
            dims=dims,
            pred_dims=pred_dims,
            extra_event_dims=extra_event_dims,
        )
    @property
    def model(self):
        """Return the internal model (the guide's model if present, else the SVI model)."""
        return getattr(self.svi.guide, "model", self.svi.model)
    def _infer_sample_shape(self):
        """Return the expected sample shape: a single flat ``sample`` dimension."""
        return (self.num_samples,)
    def _get_train_args_kwargs(self):
        # Args/kwargs were provided by the caller at construction time rather
        # than being recorded on the posterior object (as MCMC does).
        return (self._args, self._kwargs) if self.svi is not None else (tuple(), dict())
    def _get_samples(self):
        """Extract samples from SVI guide."""
        import jax
        # Fixed key: sampling from the fitted guide is made deterministic here.
        key = jax.random.PRNGKey(0)
        if isinstance(self.svi.guide, self.numpyro.infer.autoguide.AutoGuide):
            # AutoGuides expose a dedicated posterior-sampling API.
            return self.svi.guide.sample_posterior(
                key,
                self.svi_result.params,
                *self._args,
                sample_shape=(self.num_samples,),
                **self._kwargs,
            )
        # if a custom guide is provided, sample by hand
        predictive = self.numpyro.infer.Predictive(
            self.svi.guide, params=self.svi_result.params, num_samples=self.num_samples
        )
        return predictive(key, *self._args, **self._kwargs)
def from_numpyro(
    posterior=None,
    *,
    prior=None,
    posterior_predictive=None,
    predictions=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    index_origin=None,
    coords=None,
    dims=None,
    pred_dims=None,
    extra_event_dims=None,
    num_chains=1,
):
    """Convert NumPyro data into a DataTree object.

    For a usage example read :ref:`numpyro_conversion`

    If no dims are provided, batch dim names are inferred from the NumPyro
    model plates. For event dim names, such as with the ZeroSumNormal,
    `infer={"event_dims":dim_names}` can be provided in numpyro.sample, i.e.::

        # equivalent to dims entry, {"gamma": ["groups"]}
        gamma = numpyro.sample(
            "gamma",
            dist.ZeroSumNormal(1, event_shape=(n_groups,)),
            infer={"event_dims":["groups"]}
        )

    The `extra_event_dims` input covers remaining edge cases, for instance
    deterministic sites with event dims (which don't have an `infer` argument
    to provide metadata).

    Parameters
    ----------
    posterior : numpyro.infer.mcmc.MCMC
        Fitted MCMC object from NumPyro
    prior : dict, optional
        Prior samples from a NumPyro model
    posterior_predictive : dict, optional
        Posterior predictive samples for the posterior
    predictions : dict, optional
        Out of sample predictions
    constant_data : dict, optional
        Dictionary containing constant data variables mapped to their values.
    predictions_constant_data : dict, optional
        Constant data used for out-of-sample predictions.
    index_origin : int, optional
    coords : dict, optional
        Map of dimensions to coordinates
    dims : dict of {str : list of str}, optional
        Map variable names to their coordinates. Will be inferred if they are not provided.
    pred_dims : dict, optional
        Dims for predictions data. Map variable names to their coordinates.
        Default behavior is to infer dims if this is not provided.
    extra_event_dims : dict, optional
        Extra event dims for deterministic sites. Maps event dims that couldn't
        be inferred to their coordinates.
    num_chains : int, default 1
        Number of chains used for sampling. Ignored if posterior is present.

    Returns
    -------
    DataTree
    """
    converter_kwargs = {
        "posterior": posterior,
        "prior": prior,
        "posterior_predictive": posterior_predictive,
        "predictions": predictions,
        "constant_data": constant_data,
        "predictions_constant_data": predictions_constant_data,
        "log_likelihood": log_likelihood,
        "index_origin": index_origin,
        "coords": coords,
        "dims": dims,
        "pred_dims": pred_dims,
        "extra_event_dims": extra_event_dims,
        "num_chains": num_chains,
    }
    # MCMC output carries explicit chain/draw sample dimensions.
    with rc_context(rc={"data.sample_dims": ["chain", "draw"]}):
        converter = MCMCConverter(**converter_kwargs)
        return converter.to_datatree()
def from_numpyro_svi(
    svi,
    *,
    svi_result,
    model_args=None,
    model_kwargs=None,
    prior=None,
    posterior_predictive=None,
    predictions=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    index_origin=None,
    coords=None,
    dims=None,
    pred_dims=None,
    extra_event_dims=None,
    num_samples: int = 1000,
):
    """Convert NumPyro SVI results into a DataTree object.
    For a usage example read :ref:`numpyro_conversion`
    If no dims are provided, this will infer batch dim names from NumPyro model plates.
    For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
    can be provided in numpyro.sample, i.e.::
        # equivalent to dims entry, {"gamma": ["groups"]}
        gamma = numpyro.sample(
            "gamma",
            dist.ZeroSumNormal(1, event_shape=(n_groups,)),
            infer={"event_dims":["groups"]}
        )
    There is also an additional `extra_event_dims` input to cover any edge cases, for instance
    deterministic sites with event dims (which don't have an `infer` argument to provide metadata).
    Parameters
    ----------
    svi : numpyro.infer.svi.SVI
        Numpyro SVI instance used for fitting the model.
    svi_result : numpyro.infer.svi.SVIRunResult
        SVI results from a fitted model.
    model_args : tuple, optional
        Model arguments, should match those used for fitting the model.
    model_kwargs : dict, optional
        Model keyword arguments, should match those used for fitting the model.
    prior : dict, optional
        Prior samples from a NumPyro model
    posterior_predictive : dict, optional
        Posterior predictive samples for the posterior
    predictions : dict, optional
        Out of sample predictions
    constant_data : dict, optional
        Dictionary containing constant data variables mapped to their values.
    predictions_constant_data : dict, optional
        Constant data used for out-of-sample predictions.
    index_origin : int, optional
    coords : dict, optional
        Map of dimensions to coordinates
    dims : dict of {str : list of str}, optional
        Map variable names to their coordinates. Will be inferred if they are not provided.
    pred_dims : dict, optional
        Dims for predictions data. Map variable names to their coordinates. Default behavior is to
        infer dims if this is not provided
    extra_event_dims : dict, optional
        Extra event dims for deterministic sites. Maps event dims that couldn't be inferred to
        their coordinates.
    num_samples : int, default 1000
        The number of posterior samples to draw from the fitted guide.
    Returns
    -------
    DataTree
    """
    # SVI draws have a single flat "sample" dimension (no chains).
    with rc_context(rc={"data.sample_dims": ["sample"]}):
        return SVIConverter(
            svi,
            svi_result=svi_result,
            model_args=model_args,
            model_kwargs=model_kwargs,
            num_samples=num_samples,
            prior=prior,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data=constant_data,
            predictions_constant_data=predictions_constant_data,
            log_likelihood=log_likelihood,
            index_origin=index_origin,
            coords=coords,
            dims=dims,
            pred_dims=pred_dims,
            extra_event_dims=extra_event_dims,
        ).to_datatree()