Replace internal_regime_params with flat kwargs via dags.signature.rename_arguments.#233
Replace internal_regime_params with flat kwargs via dags.signature.rename_arguments.#233hmgaudecker merged 12 commits intomainfrom
internal_regime_params with flat kwargs via dags.signature.rename_arguments.#233Conversation
Instead of wrapping each function to accept an internal_regime_params dict and extracting parameters at call time, use dags.signature.rename_arguments() to qualify parameter names with function prefixes (e.g., risk_aversion becomes utility__risk_aversion). This makes the parameter flow explicit and removes the internal_regime_params indirection from all function signatures. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…at_regime_params. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
8aef1ab to
1669e9d
Compare
…ParamsTemplate. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…s Callable in result.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…mstances than REGIME_SEPARATOR. Some small improvements.
883fa5f to
44a7627
Compare
|
|
||
| return cast("InternalUserFunction", processed_func) | ||
| """ | ||
| param_names = list(params_template[param_key]) # ty: ignore[invalid-argument-type] |
There was a problem hiding this comment.
ty: ignore will be gone once in PR that follows immediately.
| """Get all flat parameter names from a regime params template. | ||
|
|
||
| Converts nested template entries like {"utility": {"risk_aversion": type}} to | ||
| flat names like "utility__risk_aversion". Top-level params like |
There was a problem hiding this comment.
Will be simplified in subsequent PR.
|
|
||
| return cast("InternalUserFunction", processed_func) | ||
| """ | ||
| param_names = list(regime_params_template[fn_key]) # ty: ignore[invalid-argument-type] |
There was a problem hiding this comment.
Note: ty: ignore will disappear in subsequent PR.
timmens
left a comment
There was a problem hiding this comment.
While starting the review I was a bit skeptical about the change, but was quickly convinced that this is a big improvement. Very nice PR! Only minor comments, but already approved.
| assert set(internal_params.keys()) == set(params_template.keys()) | ||
| for regime in params_template: | ||
| assert set(internal_params[regime].keys()) == set( | ||
| params_template[regime].keys() | ||
| ) | ||
| expected_flat_keys = set() | ||
| for func, func_params in params_template[regime].items(): | ||
| for arg in func_params: | ||
| expected_flat_keys.add(f"{func}__{arg}") | ||
| assert set(internal_params[regime].keys()) == expected_flat_keys |
There was a problem hiding this comment.
Seeing this check now for the third time. Make a function out of it?
| }, | ||
| } | ||
| ) | ||
| flat_regime_params = { |
There was a problem hiding this comment.
Is there a reason why this dictionary does not need to be a MappingProxyType anymore?
| {"working": {"discount_factor": value}}. | ||
|
|
||
| """ | ||
| result: dict[str, dict[str, Any]] = {} |
There was a problem hiding this comment.
| result: dict[str, dict[str, Any]] = {} | |
| result: dict[RegimeName, dict[str, Any]] = {} |
| for param_name in value: | ||
| result.add(f"{key}{QNAME_DELIMITER}{param_name}") | ||
| else: | ||
| # Top-level param (e.g., "discount_factor": float) |
There was a problem hiding this comment.
This exact comment was renamed above to
Scalar param (currently unused - all params are under namespaces)
| params = tuple(inspect.signature(next_state).parameters) | ||
|
|
||
| next_state_vmapped = vmap_1d( | ||
| func=next_state, | ||
| variables=tuple( | ||
| parameter | ||
| for parameter in parameters | ||
| if parameter not in ("period", "age", "internal_regime_params") | ||
| p for p in params if p not in _get_non_vmap_params(regime_params_template) |
There was a problem hiding this comment.
I think params and/or the p in the loop should be replaced by something like argnames and var. I was slightly confused for which reason we would vmap over parameter, which we don't, we vmap over states and actions that are arguments of the next_state function.
| params = tuple(inspect.signature(next_regime_accepting_all).parameters) | ||
|
|
||
| next_regime_vmapped = vmap_1d( | ||
| func=next_regime_accepting_all, | ||
| variables=tuple( | ||
| parameter | ||
| for parameter in parameters | ||
| if parameter not in ("period", "age", "internal_regime_params") | ||
| p for p in params if p not in _get_non_vmap_params(regime_params_template) |
There was a problem hiding this comment.
Same comment as above. It confuses me a bit that we use params to denote function arguments but also to denote the model parameters. I think it should be clear from the naming that these are different concepts
| internal_regime_params=internal_regime_params, | ||
| next_V_arr=next_V_arr, | ||
| **states_and_actions, | ||
| **kwargs, |
There was a problem hiding this comment.
In the **states_and_actions case I liked that it was clear that the unpacked dictionary contains the states and actions. If I understand it correctly, kwargs now also contains the flat parameters, which is why we need the more general name. Could we still add a comment for the reader on what is expected to be found inside kwargs?
| internal_regime_params: InternalRegimeParams, | ||
| period: Period, | ||
| **states_and_actions: Array, | ||
| **kwargs: Array, |
There was a problem hiding this comment.
Same comment on kwargs as above.
|
These are great comments, thanks! Could you live with me applying them to #235 ? Would make life much easier... |
|
Yes, lets do it in #235 !! |
Instead of wrapping each function to accept an
internal_regime_paramsdict and extracting parameters at call time, usedags.signature.rename_arguments()to qualify parameter names with function prefixes (e.g.,risk_aversionbecomesutility__risk_aversion).This makes the parameter flow explicit and removes the
internal_regime_paramsindirection from all function signatures.It is a crucial prerequisite for solving #174 and the partialling of fixed_params described in #219, which will be done in subsequent PRs.