Partial fixed_params into functions upon model creation. Stateless treatment of dynamic IrregSpacedGrids, in particular ShockGrids.#235
Conversation
Instead of wrapping each function to accept an internal_regime_params dict and extracting parameters at call time, use dags.signature.rename_arguments() to qualify parameter names with function prefixes (e.g., risk_aversion becomes utility__risk_aversion). This makes the parameter flow explicit and removes the internal_regime_params indirection from all function signatures. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…at_regime_params. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ParamsTemplate. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…s Callable in result.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…mstances than REGIME_SEPARATOR. Some small improvements.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…eatment of dynamic IrregSpacedGrids, in particular ShockGrids.
Three bugs in _partial_fixed_params_into_regimes: 1. internal_functions.regime_transition_probs was not updated alongside the top-level attribute, causing simulation to use un-partialled functions for regime transitions. 2. regime_transition_probs was partialled with ALL regime fixed params, but it only accepts a subset — breaking inspect.signature used by dags.concatenate_functions in to_dataframe. 3. to_dataframe additional_targets built DAGs from raw functions that still expected fixed params no longer present in runtime params. Fixes: update internal_functions, filter kwargs per function signature, store resolved_fixed_params on InternalRegime for target computation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
for more information, see https://pre-commit.ci
…speed on runs 2 and especially 3 :-)
examples/mahler_yum_2024/model.py
Outdated
| initial_regimes=initial_regimes, | ||
| seed=8295, | ||
| ) | ||
| for _ in range(3): |
There was a problem hiding this comment.
Commit msg: Run the MY model three times just for the pleasure of looking at the speed on runs 2 and 3 :-)
There was a problem hiding this comment.
Actually, now that I added timings, that is not the only reason. Compile time reduces a lot when passing params directly:
prod_shock_grid = ShockGridAR1Rouwenhorst(n_points=5, ar1_coeff=rho, mean=0, std=1)
...
"adjustment_cost": ShockGridIIDUniform(n_points=5, start=0, stop=1)INFO:lcm:--- Timing summary ---
INFO:lcm: Run 1: 84.696s
INFO:lcm: Run 2: 7.181s
INFO:lcm: Run 3: 7.132s
vs. via fixed_params:
fixed_params={
"alive": {
"productivity_shock": {"ar1_coeff": rho, "mean": 0, "std": 1},
"adjustment_cost": {"start": 0, "stop": 1},
}
},INFO:lcm:--- Timing summary ---
INFO:lcm: Run 1: 117.166s
INFO:lcm: Run 2: 7.301s
INFO:lcm: Run 3: 7.331s
whether the runtime is truly affected or my machine was doing more other stuff, I cannot say.
src/lcm/shocks.py
Outdated
|
|
||
| def _discretized_uniform_distribution_gridpoints( | ||
| n_points: int, start: float = 0, stop: float = 1 | ||
| n_points: int, start: float, stop: float |
There was a problem hiding this comment.
Note: I prefer require setting all parameters.
|
@mj023: Not asking for a full review, but it would be great if you could take a look at the shocks code! |
src/lcm/shock_grids.py
Outdated
| r"""AR(1) shock discretized via Tauchen (1986). | ||
|
|
||
| The process is | ||
| :math:`y_t = \mu_\varepsilon + \rho \, (y_{t-1} - \mu_\varepsilon) + \varepsilon_t`, |
There was a problem hiding this comment.
This AR1 Process is a bit odd. The function that draws the shock is
There was a problem hiding this comment.
Oops, will have to look into that again. Sorry, this was a very late addition (since it is now ready to support that) and I did not have time to check it well. Should be very standard...
There was a problem hiding this comment.
Sorry for the noise. We are now using the form
for AR(1)-processes throughout, which is the standard in econ. It matches what, e.g., QuantEcon does. Sorry again, I had thought this would be so standard that I did not check at all -- but there are different representations around!
…rocesses throughout. Stick to rho - mu - sigma for all shock parameters.
65b0c67 to
a52684d
Compare
timmens
left a comment
There was a problem hiding this comment.
Very nice PR! 🚀
I have a few comments before full approval, but nothing serious. In addition, I have two comments here:
- I think this PR is missing an economically sensible test for the shocks. A test where we predict (some part of) the result of the solution or simulation given the chosen shock. Currently there is no end-to-end test like this. Only the regression test, but that does not test a behavior.
- At some point I believe we must think about a clearer separation of code/objects that are used for the solution and code/objects used for the simulation. Definitely not in this PR -- was just reminded of that while reading the changes.
| from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical | ||
| from lcm.dispatchers import _base_productmap | ||
| from lcm.grids import ShockGrid | ||
| from lcm.shock_grids import ShockGridAR1Rouwenhorst, ShockGridIIDUniform |
There was a problem hiding this comment.
For a more pleasant user-interface, I would leave out the "ShockGrid" prefix. Potentially, I would even rename the shock_grids module to shocks.
I can also see the argument that you want to make clear to the user that they are describing grids. Personally, I still prefer the minimalistic UI (without prefix), which focuses on the actual object they are trying to model (the shock) and not the implementation (the grid).
There was a problem hiding this comment.
We could even go one step further and export a shocks object that behaves as follows:
uniform_shock = lcm.shocks.IID.Uniform(start=0, stop=1)
rouwenhorst_shock = lcm.shocks.AR1.Rouwenhorst(...)| def get_coordinate(self, value: ScalarFloat) -> ScalarFloat: | ||
| """Return the generalized coordinate of a value in the grid.""" | ||
| if self.points is None: | ||
| raise GridInitializationError( | ||
| "Cannot compute coordinate without points. Pass points at " | ||
| "initialization or use IrregSpacedGrid(n_points=...) and " | ||
| "supply points at runtime via params." | ||
| ) | ||
| return grid_helpers.get_irreg_coordinate(value, self.to_jax()) |
There was a problem hiding this comment.
Is this tested? If the error is raised in jitted code the user might see a strange error.
| - state_action_spaces: StateActionSpace object | ||
| - max_Q_over_a_functions: dict mapping period to max_Q_over_a function | ||
| - active: list of periods the regime is active | ||
| - gridspecs: grid specifications (needed by _replace_dynamic_states) |
There was a problem hiding this comment.
Read somewhere that we do not want to use the terminology "dynamic". Is the comment outdated or the function name?
There was a problem hiding this comment.
Nice tests in general. But the style is slightly off, probably because of Claude. The model functions don't adhere to our standard terminology (_simple_utility instead of just utility). We also don't usually use Test-classes with pytest. Maybe have a quick look at it again to align the style? On top, I would expect _simple_model to be named _create_model, or similar.
|
|
||
| ### Testing Style | ||
|
|
||
| - Use plain pytest functions, never test classes (`class TestFoo`) |
There was a problem hiding this comment.
Ah, very nice! Must have been added in a later commit then 😅
| The class *is* the distribution type — no ``distribution_type`` string needed. | ||
|
|
There was a problem hiding this comment.
| The class *is* the distribution type — no ``distribution_type`` string needed. |
| def get_transition_probs(self) -> FloatND: | ||
| """Get the transition probabilities at the gridpoints. | ||
|
|
||
| Returns uniform probabilities when required params are missing. |
There was a problem hiding this comment.
Do we actually want that, or do we want to throw an error if its not fully specified and we request the transition probs?
There was a problem hiding this comment.
For the mathematical formulas / implementation of Tauchen and Rouwenhorst, is there another reference than the original papers, or did we choose the specific implementation ourselves given the original papers?
| @pytest.mark.parametrize( | ||
| "distribution_type", ["uniform", "normal", "tauchen", "rouwenhorst"] | ||
| ) | ||
| def test_model_with_shock(distribution_type): |
There was a problem hiding this comment.
This is just a regression test, right? We don't test any simulation result changes based on the different shock types?
| return MappingProxyType(solution) | ||
|
|
||
|
|
||
| def _replace_runtime_states( |
There was a problem hiding this comment.
Shouldn't this be a method of InternalRegime? Something like:
state_action_space = internal_regime.build_state_action_space(internal_params[name])Potentially with a different prefix than build_.
There was a problem hiding this comment.
I think we could even make the internal_regime.state_action_space attribute private then.
Ended up a bit of a monster, sorry...
fixed_params-> will be partialled into functions.IrregSpacedGrids/ShockGridsso they do the same thing: For bothIrregSpacedGrids andShockGrids, we can pass params directly to the instantiation, viafixed_params, or at runtime.classDiagram ContinuousGrid <|-- ShockGrid ShockGrid <|-- ShockGridIID ShockGrid <|-- ShockGridAR1 ShockGridIID <|-- ShockGridIIDUniform ShockGridIID <|-- ShockGridIIDNormal ShockGridAR1 <|-- ShockGridAR1Tauchen ShockGridAR1 <|-- ShockGridAR1Rouwenhorst class ShockGrid { +n_points: int +params +params_to_pass_at_runtime +is_fully_specified +compute_gridpoints() +compute_transition_probs() +get_gridpoints() +get_transition_probs() +to_jax() +get_coordinate() } class ShockGridIID { +draw_shock(params, key) } class ShockGridAR1 { +draw_shock(params, key, current_value) } class ShockGridIIDUniform { +start: float +stop: float } class ShockGridIIDNormal { +mean: float +std: float +n_std: float } class ShockGridAR1Tauchen { +ar1_coeff: float +std: float +mean: float +n_std: float } class ShockGridAR1Rouwenhorst { +ar1_coeff: float +std: float +mean: float }