Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
4aad70e
Replace internal_regime_params with flat kwargs via rename_arguments.
hmgaudecker Feb 10, 2026
1669e9d
Rename InternalRegimeParams to FlatRegimeParams and flat_params to fl…
hmgaudecker Feb 10, 2026
f0b5cbe
Rename params_template to regime_params_template where type is Regime…
hmgaudecker Feb 10, 2026
66b31c4
Rename internal_fixed_params to flat_regime_fixed_params; type func a…
hmgaudecker Feb 10, 2026
44a7627
Use QNAME_DELIMITER throughout since we need it in more general circu…
hmgaudecker Feb 10, 2026
4ff6b86
Use pyproject-fmt and update hooks.
hmgaudecker Feb 10, 2026
c90ee02
Simplify.
hmgaudecker Feb 10, 2026
60927f1
Fix MY model.
hmgaudecker Feb 10, 2026
b560602
Improve naming.
hmgaudecker Feb 10, 2026
aaeabe4
Remove unnecessary defaults.
hmgaudecker Feb 10, 2026
be2b079
Factor out time aggregation function H.
hmgaudecker Feb 10, 2026
0f235c5
Include utility among user functions (same as H).
hmgaudecker Feb 10, 2026
1ef6108
Run code-simplifier and some manual fixes in that vein.
hmgaudecker Feb 10, 2026
2eb5b92
Further simplifications.
hmgaudecker Feb 10, 2026
2250a85
Partial fixed_params into functions upon model creation. Stateless tr…
hmgaudecker Feb 10, 2026
aa4ce80
Fix fixed_params support in simulation and to_dataframe
hmgaudecker Feb 10, 2026
f41eae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
f200db2
Add the NestedMappingParams type.
hmgaudecker Feb 11, 2026
24f07b5
Back to dags main; tiny harmonisation efforts in pyproject.toml
hmgaudecker Feb 11, 2026
8f1e84d
Improve typing.
hmgaudecker Feb 11, 2026
0b2e5d3
Merge branch 'remove-internal_regime_params' into factor-out-time-agg…
hmgaudecker Feb 11, 2026
876b8af
Update lockfile, remove now-irrelevant ty:ignore
hmgaudecker Feb 11, 2026
187304f
Remove a couple of future imports
hmgaudecker Feb 13, 2026
7ce65c9
prek autoupdate and include [tool.pyproject-fmt] in pyproject.toml
hmgaudecker Feb 13, 2026
4183875
Merge branch 'remove-internal_regime_params' into factor-out-time-agg…
hmgaudecker Feb 13, 2026
c706b26
Merge branch 'factor-out-time-aggregation' into partial-fixed_params
hmgaudecker Feb 13, 2026
abdc30b
Address review comments from #233.
hmgaudecker Feb 13, 2026
607fd53
Merge branch 'main' into factor-out-time-aggregation
hmgaudecker Feb 13, 2026
e5340f8
Merge branch 'factor-out-time-aggregation' into partial-fixed_params
hmgaudecker Feb 13, 2026
ebffbaf
Merge branch 'main' into partial-fixed_params
hmgaudecker Feb 13, 2026
f7084ea
Simplify.
hmgaudecker Feb 13, 2026
844c87d
Improve naming of grids that may change at runtime.
hmgaudecker Feb 13, 2026
3642954
Simplify grid logic further.
hmgaudecker Feb 13, 2026
cbce3b4
Run the MY model three times just for the pleasure of looking at the …
hmgaudecker Feb 13, 2026
167f4b7
Improve naming.
hmgaudecker Feb 13, 2026
cb17244
Simplify fixed_params mechanism.
hmgaudecker Feb 13, 2026
63c52cb
Stateless setting up of model.
hmgaudecker Feb 13, 2026
00c9ed0
Split up test file in a now-sensible way.
hmgaudecker Feb 13, 2026
fabdae9
Fix params flow.
hmgaudecker Feb 13, 2026
5cb1f6c
Merge branch 'partial-fixed_params' into nested-mapping-params
hmgaudecker Feb 13, 2026
a9e9fa2
Add reproducer from #236 -- but it passes...
hmgaudecker Feb 13, 2026
b02d435
Refactor ShockGrids to be strongly typed.
hmgaudecker Feb 14, 2026
1c9fc09
Merge branch 'partial-fixed_params' into nested-mapping-params
hmgaudecker Feb 14, 2026
a52684d
Review comments: Use form y_{t+1} = μ + ρ y_{t} + ϵ_{t+1} for AR(1)-p…
hmgaudecker Feb 15, 2026
d5452bd
Stateless Rouwenhorst.
hmgaudecker Feb 15, 2026
a424870
More readable formatting.
hmgaudecker Feb 15, 2026
365754d
Normalise probabilities explicitly to get rid of failure on 32bit.
hmgaudecker Feb 15, 2026
bc9d0f3
Revert previous change; make rounding precision in tests to be archit…
hmgaudecker Feb 15, 2026
1c20045
First set of review comments: Test style, renamings, minor stuff.
hmgaudecker Feb 16, 2026
2c4ed99
Move shock_grids -> shocks.iid, shocks.ar1 and internal stuff.
hmgaudecker Feb 16, 2026
83376da
Make sure to return NaN for gridpoints and transition probs if parame…
hmgaudecker Feb 16, 2026
d6000f2
Simplify regime processing and next_state functions for shocks.
hmgaudecker Feb 16, 2026
f81bdd8
Fix IrregSpacedGrid.to_jax() to return a NaN-array if no points are a…
hmgaudecker Feb 16, 2026
2f32ae6
Add a couple of ty: ignore statements.
hmgaudecker Feb 16, 2026
0190c8c
Merge branch 'partial-fixed_params' into nested-mapping-params
hmgaudecker Feb 16, 2026
572c60a
Review comments.
hmgaudecker Feb 16, 2026
8f77a9c
Force import via lcm.params
hmgaudecker Feb 16, 2026
35ec15e
Differentiate between internal_regime._base_state_action_space and in…
hmgaudecker Feb 16, 2026
8d7476c
Merge branch 'partial-fixed_params' into nested-mapping-params
hmgaudecker Feb 16, 2026
7510b21
Add SequenceLeaf type and as_leaf function.
hmgaudecker Feb 16, 2026
e297ab5
Merge branch 'main' into nested-mapping-params
hmgaudecker Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ initial_regimes = ["working", "working", "retired"]
- Google-style docstrings
- All functions require type annotations
- Pre-commit hooks ensure code quality
- Never use `from __future__ import annotations` — this project requires Python 3.14+

### Testing Style

- Use plain pytest functions, never test classes (`class TestFoo`)
- Use `@pytest.mark.parametrize` for test variations

### Testing Style

Expand Down
26 changes: 26 additions & 0 deletions src/lcm/params/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from collections.abc import Mapping, Sequence
from typing import Any, overload

from lcm.params.mapping_leaf import MappingLeaf
from lcm.params.sequence_leaf import SequenceLeaf


@overload
def as_leaf(data: Mapping[str, Any]) -> MappingLeaf: ...


@overload
def as_leaf(data: Sequence[Any]) -> SequenceLeaf: ...


def as_leaf(data: Mapping[str, Any] | Sequence[Any]) -> MappingLeaf | SequenceLeaf:
"""Wrap a Mapping or Sequence as a JAX-pytree leaf."""
if isinstance(data, Mapping):
return MappingLeaf(dict(data))
if isinstance(data, Sequence):
return SequenceLeaf(data)
msg = f"as_leaf() expects a Mapping or Sequence, got {type(data).__name__}"
raise TypeError(msg)


__all__ = ["MappingLeaf", "SequenceLeaf", "as_leaf"]
46 changes: 46 additions & 0 deletions src/lcm/params/mapping_leaf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""A Mapping wrapper that is a JAX pytree but not itself a Mapping."""

from collections.abc import Mapping
from typing import Any

import jax


class MappingLeaf:
"""A Mapping wrapper that is a JAX pytree but not itself a Mapping.

Prevents flatten_regime_namespace from recursing into contents while
allowing JAX to trace through array values.

Data is frozen to immutable containers on construction.
"""

__slots__ = ("data",)

def __init__(self, data: Mapping[str, Any]) -> None:
from lcm.utils import ensure_containers_are_immutable # noqa: PLC0415

self.data = ensure_containers_are_immutable(data)

def __repr__(self) -> str:
return f"MappingLeaf({dict(self.data)!r})"

__hash__ = None # MappingProxyType is not hashable

def __eq__(self, other: object) -> bool:
if not isinstance(other, MappingLeaf):
return NotImplemented
return self.data == other.data


def _flatten(nmp: MappingLeaf) -> tuple[list[Any], tuple[str, ...]]:
keys = tuple(sorted(nmp.data.keys()))
values = [nmp.data[k] for k in keys]
return values, keys


def _unflatten(keys: tuple[str, ...], values: list[Any]) -> MappingLeaf:
return MappingLeaf(dict(zip(keys, values, strict=True)))


jax.tree_util.register_pytree_node(MappingLeaf, _flatten, _unflatten)
45 changes: 45 additions & 0 deletions src/lcm/params/sequence_leaf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""A Sequence wrapper that is a JAX pytree but not itself a Sequence."""

from collections.abc import Sequence
from typing import Any

import jax


class SequenceLeaf:
"""A Sequence wrapper that is a JAX pytree but not itself a Sequence.

Prevents flatten_regime_namespace from recursing into contents while
allowing JAX to trace through array values.

Data is frozen to immutable containers on construction.
"""

__slots__ = ("data",)

def __init__(self, data: Sequence[Any]) -> None:
from lcm.utils import _make_immutable # noqa: PLC0415

self.data = tuple(_make_immutable(v) for v in data)

def __repr__(self) -> str:
return f"SequenceLeaf({list(self.data)!r})"

def __hash__(self) -> int:
return hash(self.data)

def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceLeaf):
return NotImplemented
return self.data == other.data


def _flatten(sl: SequenceLeaf) -> tuple[list[Any], None]:
return list(sl.data), None


def _unflatten(_aux: None, values: list[Any]) -> SequenceLeaf:
return SequenceLeaf(values)


jax.tree_util.register_pytree_node(SequenceLeaf, _flatten, _unflatten)
15 changes: 14 additions & 1 deletion src/lcm/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from jax import Array
from jaxtyping import Bool, Float, Int, Scalar

from lcm.params import MappingLeaf
from lcm.params.sequence_leaf import SequenceLeaf

type ContinuousState = Float[Array, "..."]
type ContinuousAction = Float[Array, "..."]
type DiscreteState = Int[Array, "..."]
Expand Down Expand Up @@ -40,7 +43,17 @@
bool
| float
| Array
| Mapping[str, bool | float | Array | Mapping[str, bool | float | Array]],
| MappingLeaf
| SequenceLeaf
| Mapping[
str,
bool
| float
| Array
| MappingLeaf
| SequenceLeaf
| Mapping[str, bool | float | Array | MappingLeaf | SequenceLeaf],
],
]

# Internal regime parameters: A flat mapping with function-qualified names.
Expand Down
10 changes: 9 additions & 1 deletion src/lcm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
from dags.tree import flatten_to_qnames, unflatten_from_qnames
from jax import Array

from lcm.params import MappingLeaf
from lcm.params.sequence_leaf import SequenceLeaf
from lcm.typing import RegimeName

T = TypeVar("T")


def _make_immutable(value: Any) -> Any: # noqa: ANN401
"""Recursively convert a value to its immutable equivalent."""
if isinstance(value, (MappingLeaf, SequenceLeaf)):
return value # already immutable by construction
if isinstance(value, (MappingProxyType, tuple, frozenset)):
return value
if isinstance(value, Mapping):
Expand Down Expand Up @@ -51,8 +55,12 @@ def ensure_containers_are_immutable[K, V](
return cast("MappingProxyType[K, V]", _make_immutable(value))


def _make_mutable(value: Any) -> Any: # noqa: ANN401
def _make_mutable(value: Any) -> Any: # noqa: ANN401, PLR0911
"""Recursively convert a value to its mutable equivalent."""
if isinstance(value, MappingLeaf):
return {k: _make_mutable(v) for k, v in value.data.items()}
if isinstance(value, SequenceLeaf):
return [_make_mutable(v) for v in value.data]
if isinstance(value, (set, list)):
return value
if isinstance(value, (MappingProxyType, Mapping)):
Expand Down
Loading
Loading