Skip to content

Commit 3a22e6e

Browse files
authored
Minor updates for v0.10.2 (#141)
* + update version * + improve pymcio + minor logging improvements * + return a fqn path from output _data_dict
1 parent e225de6 commit 3a22e6e

File tree

4 files changed

+23
-12
lines changed

4 files changed

+23
-12
lines changed

oreum_core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import logging
1818

19-
__version__ = "0.10.1"
19+
__version__ = "0.10.2"
2020

2121
# logger goes to null handler by default
2222
# packages that import oreum_core can override this and direct elsewhere

oreum_core/eda/eda_io.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def read(
117117

118118
def output_data_dict(
119119
df: pd.DataFrame, dd_notes: dict[str, str], fqp: Path, fn: str = ""
120-
) -> None:
120+
) -> Path:
121121
"""Helper fn to output a data dict with automatic eda.describe"""
122122

123123
# flag if is index
@@ -167,4 +167,5 @@ def output_data_dict(
167167
na_rep="NULL",
168168
)
169169

170-
excelio.writer_close()
170+
fqn = excelio.writer_close()
171+
return fqn

oreum_core/model_pymc/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,9 @@ def debug(self) -> str:
307307
_ = self.model.debug(fn="logp", verbose=False)
308308
msg.append("debug: logp")
309309
except (TypeError, ValueError):
310-
_log.exception(
311-
"Model contains Potentials, debug logp not compatible",
312-
exc_info=True,
313-
)
310+
_log.error("Model contains Potentials, debug logp not compatible")
311+
# _log.exception(
312+
# "Model contains Potentials, debug logp not compatible",
313+
# exc_info=True,
314+
# )
314315
return f"Ran {len(msg)} checks: [" + ", ".join(msg) + "]"

oreum_core/model_pymc/pymc_io.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self, *args, **kwargs):
4343
def read_idata(
4444
self, mdl: BasePYMCModel = None, fn: str = "", **kwargs
4545
) -> az.InferenceData:
46-
"""Read InferenceData using mdl.mdl_id_fn + txtadd, or from fn"""
46+
"""Read InferenceData appropriate to a built model using
47+
mdl.mdl_id_fn + txtadd, or from fn"""
4748
txtadd = kwargs.pop("txtadd", None)
4849
if mdl is not None:
4950
fn = "_".join(filter(None, ["idata", mdl.mdl_id_fn, txtadd]))
@@ -52,14 +53,22 @@ def read_idata(
5253
_log.info(f"Read model idata from {str(fqn.resolve())}")
5354
return idata
5455

55-
def write_idata(self, mdl: BasePYMCModel, fn: str = "", **kwargs) -> Path:
56-
"""Accept BasePYMCModel object write to InferenceData using
57-
mdl.mdl_id_fn + txtadd"""
56+
def write_idata(
57+
self, mdl: BasePYMCModel, idata: az.InferenceData = None, fn: str = "", **kwargs
58+
) -> Path:
59+
"""Accept BasePYMCModel object with attached in-sample idata, and write
60+
to netcdf file with name mdl.mdl_id_fn + txtadd. Optionally use this to
61+
write out-of-sample InferenceData passed as idata kwarg. Can implicitly
62+
use mdl.mdl_id_fn in either case"""
5863
txtadd = kwargs.pop("txtadd", None)
5964
if fn == "":
6065
fn = "_".join(filter(None, ["idata", mdl.mdl_id_fn, txtadd]))
6166
fqn = self.get_path_write(Path(self.snl.clean(fn)).with_suffix(".netcdf"))
62-
mdl.idata.to_netcdf(str(fqn.resolve()))
67+
68+
if idata is not None:
69+
idata.to_netcdf(str(fqn.resolve()))
70+
else:
71+
mdl.idata.to_netcdf(str(fqn.resolve()))
6372
_log.info(f"Written to {str(fqn.resolve())}")
6473
return fqn
6574

0 commit comments

Comments
 (0)