Skip to content

Commit 6b6c00a

Browse files
clarkmiyamotoclarkmiyamoto
andauthored
Handles objects with __array__ (For jax.numpy.ndarray support) (#2481)
* handles objects with __array__ * format * added __array__ into docstring * black formatting --------- Co-authored-by: clarkmiyamoto <[email protected]>
1 parent f8b4bfc commit 6b6c00a

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

arviz/data/converters.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import xarray as xr
5+
import pandas as pd
56

67
try:
78
from tree import is_nested
@@ -44,6 +45,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
4445
| dict: creates an xarray dataset as the only group
4546
| numpy array: creates an xarray dataset as the only group, gives the
4647
array an arbitrary name
48+
| object with __array__: converts to numpy array, then creates an xarray dataset as
49+
the only group, gives the array an arbitrary name
4750
group : str
4851
If `obj` is a dict or numpy array, assigns the resulting xarray
4952
dataset to this group. Default: "posterior".
@@ -115,6 +118,13 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
115118
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
116119
elif isinstance(obj, np.ndarray):
117120
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
121+
elif (
122+
hasattr(obj, "__array__")
123+
and callable(getattr(obj, "__array__"))
124+
and (not isinstance(obj, pd.DataFrame))
125+
):
126+
obj = obj.__array__()
127+
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
118128
elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
119129
if group == "sample_stats":
120130
kwargs["posterior"] = kwargs.pop(group)
@@ -129,6 +139,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
129139
"pytree (if 'dm-tree' is installed)",
130140
"netcdf filename",
131141
"numpy array",
142+
"object with __array__",
132143
"pystan fit",
133144
"emcee fit",
134145
"pyro mcmc fit",

arviz/data/io_pyjags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _extract_arviz_dict_from_inference_data(
277277

278278

279279
def _convert_arviz_dict_to_pyjags_dict(
280-
samples: tp.Mapping[str, np.ndarray]
280+
samples: tp.Mapping[str, np.ndarray],
281281
) -> tp.Mapping[str, np.ndarray]:
282282
"""
283283
Convert and ArviZ dictionary to a PyJAGS dictionary.

0 commit comments

Comments
 (0)