From 98d199f8faf4b24e7f5414ec1d8861c800b64564 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 10 Feb 2025 09:24:32 +0100 Subject: [PATCH 01/16] Create solution and simulation sub-packages --- src/lcm/entry_point.py | 6 +++--- src/lcm/simulation/__init__.py | 0 src/lcm/{ => simulation}/simulate.py | 0 src/lcm/solution/__init__.py | 0 src/lcm/{ => solution}/solve_brute.py | 0 src/lcm/{ => solution}/state_space.py | 0 tests/simulation/__init__.py | 0 tests/{ => simulation}/test_simulate.py | 4 ++-- tests/solution/__init__.py | 0 tests/{ => solution}/test_solve_brute.py | 2 +- tests/{ => solution}/test_state_space.py | 2 +- tests/test_entry_point.py | 2 +- tests/test_model_functions.py | 2 +- 13 files changed, 9 insertions(+), 9 deletions(-) create mode 100644 src/lcm/simulation/__init__.py rename src/lcm/{ => simulation}/simulate.py (100%) create mode 100644 src/lcm/solution/__init__.py rename src/lcm/{ => solution}/solve_brute.py (100%) rename src/lcm/{ => solution}/state_space.py (100%) create mode 100644 tests/simulation/__init__.py rename tests/{ => simulation}/test_simulate.py (99%) create mode 100644 tests/solution/__init__.py rename tests/{ => solution}/test_solve_brute.py (98%) rename tests/{ => solution}/test_state_space.py (90%) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 30175c01..0f99f3d8 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -15,9 +15,9 @@ get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate -from lcm.solve_brute import solve -from lcm.state_space import create_state_choice_space +from lcm.simulation.simulate import simulate +from lcm.solution.solve_brute import solve +from lcm.solution.state_space import create_state_choice_space from lcm.typing import ParamsDict from lcm.user_model import Model diff --git a/src/lcm/simulation/__init__.py b/src/lcm/simulation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lcm/simulate.py b/src/lcm/simulation/simulate.py similarity index 100% rename from src/lcm/simulate.py rename to src/lcm/simulation/simulate.py diff --git a/src/lcm/solution/__init__.py b/src/lcm/solution/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lcm/solve_brute.py b/src/lcm/solution/solve_brute.py similarity index 100% rename from src/lcm/solve_brute.py rename to src/lcm/solution/solve_brute.py diff --git a/src/lcm/state_space.py b/src/lcm/solution/state_space.py similarity index 100% rename from src/lcm/state_space.py rename to src/lcm/solution/state_space.py diff --git a/tests/simulation/__init__.py b/tests/simulation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_simulate.py b/tests/simulation/test_simulate.py similarity index 99% rename from tests/test_simulate.py rename to tests/simulation/test_simulate.py index d37addcc..f0304943 100644 --- a/tests/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -12,7 +12,7 @@ from lcm.logging import get_logger from lcm.model_functions import get_utility_and_feasibility_function from lcm.next_state import _get_next_state_function_simulation -from lcm.simulate import ( +from lcm.simulation.simulate import ( _as_data_frame, _compute_targets, _generate_simulation_keys, @@ -24,7 +24,7 @@ retrieve_non_sparse_choices, simulate, ) -from lcm.state_space import create_state_choice_space +from lcm.solution.state_space import create_state_choice_space from tests.test_models import ( get_model_config, get_params, diff --git a/tests/solution/__init__.py b/tests/solution/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_solve_brute.py b/tests/solution/test_solve_brute.py similarity index 98% rename from tests/test_solve_brute.py rename to tests/solution/test_solve_brute.py index 44395cca..7b56f965 100644 --- a/tests/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -6,7 +6,7 @@ from lcm.interfaces import Space from lcm.logging import get_logger from lcm.ndimage import map_coordinates -from lcm.solve_brute import solve, solve_continuous_problem +from lcm.solution.solve_brute import solve, solve_continuous_problem def test_solve_brute(): diff --git a/tests/test_state_space.py b/tests/solution/test_state_space.py similarity index 90% rename from tests/test_state_space.py rename to tests/solution/test_state_space.py index cc0de209..a9d0981f 100644 --- a/tests/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -1,5 +1,5 @@ from lcm.input_processing import process_model -from lcm.state_space import ( +from lcm.solution.state_space import ( create_state_choice_space, ) from tests.test_models import get_model_config diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index e8d3a567..7c478b26 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -9,7 +9,7 @@ ) from lcm.input_processing import process_model from lcm.model_functions import get_utility_and_feasibility_function -from lcm.state_space import create_state_choice_space +from lcm.solution.state_space import create_state_choice_space from tests.test_models import get_model_config from tests.test_models.deterministic import RetirementStatus from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index fffd25f3..a20d74d2 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -10,7 +10,7 @@ get_multiply_weights, get_utility_and_feasibility_function, ) -from lcm.state_space import create_state_choice_space +from lcm.solution.state_space import create_state_choice_space from tests.test_models import get_model_config from tests.test_models.deterministic import utility From 7c9e08c0dcf9326127a02c7556fd1634858869ce Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 10 Feb 2025 09:25:50 +0100 Subject: [PATCH 02/16] Replace double backticks with single --- src/lcm/dispatchers.py | 28 ++++++++++++++-------------- src/lcm/function_representation.py | 4 ++-- src/lcm/interfaces.py | 6 +++--- src/lcm/simulation/simulate.py | 4 ++-- src/lcm/solution/solve_brute.py | 2 +- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index c91d99d0..a1e04db2 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -34,12 +34,12 @@ def spacemap( Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func`` + dimension) that returns a jax.numpy.ndarray or pytree of arrays. If `func` returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1 - jax.numpy.ndarray with k + 1 dimensions, where k is the length of ``dense_vars`` - and the additional dimension corresponds to the ``sparse_vars``. The order of - the dimensions is determined by the order of ``dense_vars`` as well as the - ``put_dense_first`` argument. If the output of ``func`` is a jax pytree, the + jax.numpy.ndarray with k + 1 dimensions, where k is the length of `dense_vars` + and the additional dimension corresponds to the `sparse_vars`. The order of + the dimensions is determined by the order of `dense_vars` as well as the + `put_dense_first` argument. If the output of `func` is a jax pytree, the usual jax behavior applies, i.e. the leading dimensions of all arrays in the pytree are as described above but there might be additional dimensions. @@ -106,12 +106,12 @@ def vmap_1d( Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func`` + dimension) that returns a jax.numpy.ndarray or pytree of arrays. If `func` returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1 jax.numpy.ndarray with 1 dimension and length k, where k is the length of one of - the mapped inputs in ``variables``. The order of the dimensions is determined by - the order of ``variables`` which can be different to the order of ``funcs`` - arguments. If the output of ``func`` is a jax pytree, the usual jax behavior + the mapped inputs in `variables`. The order of the dimensions is determined by + the order of `variables` which can be different to the order of `funcs` + arguments. If the output of `func` is a jax pytree, the usual jax behavior applies, i.e. the leading dimensions of all arrays in the pytree are as described above but there might be additional dimensions. @@ -169,11 +169,11 @@ def productmap(func: F, variables: list[str]) -> F: Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func`` + dimension) that returns a jax.numpy.ndarray or pytree of arrays. If `func` returns a scalar, the dispatched function returns a jax.numpy.ndarray with k - dimensions, where k is the length of ``variables``. The order of the dimensions - is determined by the order of ``variables`` which can be different to the order - of ``funcs`` arguments. If the output of ``func`` is a jax pytree, the usual jax + dimensions, where k is the length of `variables`. The order of the dimensions + is determined by the order of `variables` which can be different to the order + of `funcs` arguments. If the output of `func` is a jax pytree, the usual jax behavior applies, i.e. the leading dimensions of all arrays in the pytree are as described above but there might be additional dimensions. @@ -207,7 +207,7 @@ def _base_productmap(func: F, product_axes: list[str]) -> F: product_axes: List with names of arguments over which we apply vmap. Returns: - A callable with the same arguments as func. See ``product_map`` for details. + A callable with the same arguments as func. See `product_map` for details. """ signature = inspect.signature(func) diff --git a/src/lcm/function_representation.py b/src/lcm/function_representation.py index 7b1fe9fd..d7b9d89d 100644 --- a/src/lcm/function_representation.py +++ b/src/lcm/function_representation.py @@ -147,7 +147,7 @@ def _get_label_translator( resulting function. Returns: - callable: A callable with the keyword only argument ``in_name`` that converts a + callable: A callable with the keyword only argument `in_name` that converts a label into a position in a list of labels. """ @@ -172,7 +172,7 @@ def _get_lookup_function( Returns: callable: A callable with the keyword-only arguments [axis_names] + [array_name] - that looks up values in an indexer array called ``array_name``. + that looks up values in an indexer array called `array_name`. """ arg_names = [*axis_names, array_name] diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 8c44d6de..3648d0a1 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -82,12 +82,12 @@ class InternalModel: are True if the variable has the corresponding property. The columns are: is_state, is_choice, is_continuous, is_discrete. functions: Dictionary that maps names of functions to functions. The functions - differ from the user functions in that they take ``params`` as a keyword + differ from the user functions in that they take `params` as a keyword argument. Two cases: - If the original function depended on model parameters, those are - automatically extracted from ``params`` and passed to the original + automatically extracted from `params` and passed to the original function. - - Otherwise, the ``params`` argument is simply ignored. + - Otherwise, the `params` argument is simply ignored. function_info: A table with information about all functions in the model. The index contains the name of a function. The columns are booleans that are True if the function has the corresponding property. The columns are: diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index b585bfd5..e6e73a26 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -243,10 +243,10 @@ def solve_continuous_problem( Returns: - jnp.ndarray: Jax array with policies for each combination of a state and a discrete choice. The number and order of dimensions is defined by the - ``gridmap`` function. + `gridmap` function. - jnp.ndarray: Jax array with continuation values for each combination of a state and a discrete choice. The number and order of dimensions is defined - by the ``gridmap`` function. + by the `gridmap` function. """ _gridmapped = spacemap( diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 52e1de6b..1d8f1162 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -107,7 +107,7 @@ def solve_continuous_problem( Returns: jnp.ndarray: Jax array with continuation values for each combination of a state and a discrete choice. The number and order of dimensions is defined - by the ``gridmap`` function. + by the `gridmap` function. """ _gridmapped = spacemap( From bc56df25d56e751bd681ea8533d679e3011b0220 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 10 Feb 2025 09:54:16 +0100 Subject: [PATCH 03/16] Remove indexers --- src/lcm/entry_point.py | 14 --- src/lcm/function_representation.py | 19 ++-- src/lcm/interfaces.py | 24 ------ src/lcm/simulation/simulate.py | 9 -- src/lcm/solution/solve_brute.py | 10 --- src/lcm/solution/state_space.py | 4 - tests/simulation/test_simulate.py | 1 - tests/solution/test_solve_brute.py | 7 -- tests/test_function_representation.py | 120 +++----------------------- 9 files changed, 16 insertions(+), 192 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 0f99f3d8..b5286f6c 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -38,11 +38,6 @@ def get_lcm_function( source code of this function to see how the lower level components are meant to be used. - Notes: - ----- - - There is a hack to make the state_indexers empty in the last period which needs - to be replaced by a better solution, when we want to allow for bequest motives. - Args: model: User model specification. targets: The requested function types. Currently only "solve", "simulate" and @@ -79,7 +74,6 @@ def get_lcm_function( # Initialize other argument lists # ================================================================================== state_choice_spaces = [] - state_indexers = [] # type:ignore[var-annotated] space_infos = [] compute_ccv_functions = [] compute_ccv_policy_functions = [] @@ -101,12 +95,6 @@ def get_lcm_function( state_choice_spaces.append(sc_space) choice_segments.append(None) - - if is_last_period: - state_indexers.append({}) - else: - state_indexers.append({}) - space_infos.append(space_info) # ================================================================================== @@ -157,7 +145,6 @@ def get_lcm_function( _solve_model = partial( solve, state_choice_spaces=state_choice_spaces, - state_indexers=state_indexers, continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=compute_ccv_functions, emax_calculators=emax_calculators, @@ -169,7 +156,6 @@ def get_lcm_function( _next_state_simulate = get_next_state_function(model=_mod, target="simulate") simulate_model = partial( simulate, - state_indexers=state_indexers, continuous_choice_grids=continuous_choice_grids, compute_ccv_policy_functions=compute_ccv_policy_functions, model=_mod, diff --git a/src/lcm/function_representation.py b/src/lcm/function_representation.py index d7b9d89d..ad5cc75d 100644 --- a/src/lcm/function_representation.py +++ b/src/lcm/function_representation.py @@ -25,8 +25,8 @@ def get_function_representation( This function dynamically generates a function that looks up and interpolates values of the pre-calculated function. The arguments of the resulting function can be split in two categories: - 1. Helper arguments such as information about the grid, indexer arrays and the - pre-calculated values of the function. + 1. Helper arguments such as information about the grid and the pre-calculated + values of the function. 2. The original arguments of the function that was pre-calculated on the grid. After partialling in all helper arguments, the resulting function behaves like an @@ -57,7 +57,7 @@ def get_function_representation( 'vf_arr', in which case, one would partial in 'vf_arr' into the representation. input_prefix: Prefix that will be added to all argument names of the resulting - function, except for the helpers arguments such as indexers or value arrays. + function, except for the helpers arguments such as the value arrays. Default is the empty string. The prefix needs to contain the separator. E.g. `next_` if an undescore should be used as separator. @@ -82,15 +82,6 @@ def get_function_representation( in_name=input_prefix + var, ) - # ================================================================================== - # wrap the indexers and put them it into funcs - # ================================================================================== - for indexer_info in space_info.indexer_infos: - funcs[f"__{indexer_info.out_name}_pos__"] = _get_lookup_function( - array_name=indexer_info.name, - axis_names=[f"__{var}_pos__" for var in indexer_info.axis_names], - ) - # ================================================================================== # create a function for the discrete lookup # ================================================================================== @@ -171,8 +162,8 @@ def _get_lookup_function( axis_names (list): List of strings with names for each axis in the array. Returns: - callable: A callable with the keyword-only arguments [axis_names] + [array_name] - that looks up values in an indexer array called `array_name`. + callable: A callable with the keyword-only arguments `[*axis_names]` that looks + up values from an array called `array_name`. """ arg_names = [*axis_names, array_name] diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 3648d0a1..9b32c5dd 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -8,28 +8,6 @@ from lcm.typing import ParamsDict, ShockType -@dataclass(frozen=True) -class IndexerInfo: - """Information needed to work with an indexer array. - - In particular, this contains enough information to wrap an indexer array into a - function that can be understood by dags. - - Attributes: - axis_names (list): List of strings containing the names of the axes of the - indexer array. - name (str): The name of the indexer array. This will become an argument name - of the function we need for dags. - out_name (str): The name of the result of indexing into the indexer. This will - become the name of the function we need for dags. - - """ - - axis_names: list[str] - name: str - out_name: str - - @dataclass(frozen=True) class Space: """Everything needed to evaluate a function on a space (e.g. state space). @@ -58,14 +36,12 @@ class SpaceInfo: their order. interpolation_info: Dict that defines information on the grids of all continuous variables. - indexer_infos: List of IndexerInfo objects. """ axis_names: list[str] lookup_info: dict[str, DiscreteGrid] interpolation_info: dict[str, ContinuousGrid] - indexer_infos: list[IndexerInfo] @dataclass(frozen=True) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index e6e73a26..f20c7cb8 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -15,7 +15,6 @@ def simulate( params, initial_states, - state_indexers, continuous_choice_grids, compute_ccv_policy_functions, model: InternalModel, @@ -32,8 +31,6 @@ def simulate( params (dict): Dict of model parameters. initial_states (list): List of initial states to start from. Typically from the observed dataset. - state_indexers (list): List of dicts of length n_periods. Each dict contains one - or several state indexers. continuous_choice_grids (list): List of dicts of length n_periods. Each dict contains 1d grids for continuous choice variables. compute_ccv_policy_functions (list): List of functions of length n_periods. Each @@ -130,7 +127,6 @@ def simulate( compute_ccv=compute_ccv_policy_functions[period], continuous_choice_grids=continuous_choice_grids[period], vf_arr=vf_arr_list[period], - state_indexers=state_indexers[period], params=params, ) @@ -218,7 +214,6 @@ def solve_continuous_problem( compute_ccv, continuous_choice_grids, vf_arr, - state_indexers, params, ): """Solve the agent's continuous choices problem problem. @@ -231,13 +226,10 @@ def solve_continuous_problem( - discrete and continuous state variables - discrete and continuous choice variables - vf_arr - - one or several state_indexers - params continuous_choice_grids (list): List of dicts with 1d grids for continuous choice variables. vf_arr (jax.numpy.ndarray): Value function array. - state_indexers (list): List of dicts with length n_periods. Each dict contains - one or several state indexers. params (dict): Dict of model parameters. Returns: @@ -260,7 +252,6 @@ def solve_continuous_problem( **data_scs.dense_vars, **continuous_choice_grids, **data_scs.sparse_vars, - **state_indexers, vf_arr=vf_arr, params=params, ) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 1d8f1162..115501e8 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -6,7 +6,6 @@ def solve( params, state_choice_spaces, - state_indexers, continuous_choice_grids, compute_ccv_functions, emax_calculators, @@ -28,8 +27,6 @@ def solve( state_choice_spaces (list): List with one state_choice_space per period. value_function_evaluators (list): List with one value_function_evaluator per period. - state_indexers (list): List of dicts with length n_periods. Each dict contains - one or several state indexers. continuous_choice_grids (list): List of dicts with 1d grids for continuous choice variables. compute_ccv_functions (list): List of functions needed to solve the agent's @@ -37,7 +34,6 @@ def solve( - discrete and continuous state variables - discrete and continuous choice variables - vf_arr - - one or several state_indexers - params emax_calculators (list): List of functions that take continuation values for combinations of states and discrete choices and calculate the @@ -63,7 +59,6 @@ def solve( compute_ccv=compute_ccv_functions[period], continuous_choice_grids=continuous_choice_grids[period], vf_arr=vf_arr, - state_indexers=state_indexers[period], params=params, ) @@ -82,7 +77,6 @@ def solve_continuous_problem( compute_ccv, continuous_choice_grids, vf_arr, - state_indexers, params, ): """Solve the agent's continuous choices problem problem. @@ -95,13 +89,10 @@ def solve_continuous_problem( - discrete and continuous state variables - discrete and continuous choice variables - vf_arr - - one or several state_indexers - params continuous_choice_grids (list): List of dicts with 1d grids for continuous choice variables. vf_arr (jax.numpy.ndarray): Value function array. - state_indexers (list): List of dicts with length n_periods. Each dict contains - one or several state indexers. params (dict): Dict of model parameters. Returns: @@ -119,7 +110,6 @@ def solve_continuous_problem( return gridmapped( **state_choice_space.dense_vars, **continuous_choice_grids, - **state_indexers, vf_arr=vf_arr, params=params, ) diff --git a/src/lcm/solution/state_space.py b/src/lcm/solution/state_space.py index 73c1f838..8caf6a05 100644 --- a/src/lcm/solution/state_space.py +++ b/src/lcm/solution/state_space.py @@ -56,14 +56,10 @@ def create_state_choice_space( _cont_states = set(vi.query("is_continuous & is_state").index.tolist()) interpolation_info = {k: v for k, v in model.gridspecs.items() if k in _cont_states} - # indexer infos - indexer_infos = [] # type: ignore[var-annotated] - space_info = SpaceInfo( axis_names=axis_names, lookup_info=lookup_info, # type: ignore[arg-type] interpolation_info=interpolation_info, # type: ignore[arg-type] - indexer_infos=indexer_infos, ) return state_choice_space, space_info diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index f0304943..0d0ed900 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -63,7 +63,6 @@ def simulate_inputs(): n_grid_points = model_config.choices["consumption"].n_points return { - "state_indexers": [{}], "continuous_choice_grids": [ {"consumption": jnp.linspace(1, 100, num=n_grid_points)}, ], diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 7b56f965..783b735b 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -38,11 +38,6 @@ def test_solve_brute(): ) state_choice_spaces = [_scs] * 2 - # ================================================================================== - # create the state_indexers (trivial because we do not have sparsity) - # ================================================================================== - state_indexers = [{}, {}] - # ================================================================================== # create continuous choice grids # ================================================================================== @@ -106,7 +101,6 @@ def calculate_emax(values, params): # noqa: ARG001 solution = solve( params=params, state_choice_spaces=state_choice_spaces, - state_indexers=state_indexers, continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=utility_and_feasibility_functions, emax_calculators=emax_calculators, @@ -145,7 +139,6 @@ def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 compute_ccv, continuous_choice_grids, vf_arr=None, - state_indexers={}, params={}, ) aaae(got, expected) diff --git a/tests/test_function_representation.py b/tests/test_function_representation.py index 66ff7b28..da525d48 100644 --- a/tests/test_function_representation.py +++ b/tests/test_function_representation.py @@ -15,7 +15,6 @@ get_function_representation, ) from lcm.interfaces import ( - IndexerInfo, SpaceInfo, ) @@ -29,7 +28,6 @@ def test_function_evaluator_with_one_continuous_variable(): interpolation_info={ "wealth": wealth_grid, }, - indexer_infos=[], ) vf_arr = jnp.pi * wealth_grid.to_jax() + 2 @@ -57,7 +55,6 @@ def test_function_evaluator_with_one_discrete_variable(): axis_names=["working"], lookup_info={"working": [0, 1]}, interpolation_info={}, - indexer_infos=[], ) # create the evaluator @@ -78,10 +75,9 @@ def test_function_evaluator_with_one_discrete_variable(): def test_function_evaluator(): """Test get_precalculated_function_evaluator in simple example. - - One sparse discrete state variable: retired (True, False) - - One sparse discrete choice variable: working (True, False) - - One dense discrete choice variable: insured ("yes", "no") - - Two dense continuous state variables: + - One discrete state variable: retired (True, False) + - One discrete choice variable: insured ("yes", "no") + - Two continuous state variables: - wealth (linspace(100, 1100, 6)) - human_capital (linspace(-3, 3, 7)) @@ -93,29 +89,19 @@ def test_function_evaluator(): """ # create a value function array - discrete_part = jnp.arange(6).repeat(6 * 7).reshape((3, 2, 6, 7)) * 100 + discrete_part = jnp.arange(4).repeat(6 * 7).reshape((2, 2, 6, 7)) * 100 + cont_func = productmap(lambda x, y: x + y, ["x", "y"]) cont_part = cont_func(x=jnp.linspace(100, 1100, 6), y=jnp.linspace(-3, 3, 7)) + vf_arr = discrete_part + cont_part # create info on discrete variables lookup_info = { "retired": [0, 1], - "working": [0, 1], "insured": [0, 1], } - # create an indexer for the sparse discrete part - indexer_infos = [ - IndexerInfo( - axis_names=["retired", "working"], - name="state_indexer", - out_name="state_index", - ), - ] - - indexer_array = jnp.array([[-1, 0], [1, 2]]) - # create info on continuous grids interpolation_info = { "wealth": LinspaceGrid(start=100, stop=1100, n_points=6), @@ -123,13 +109,12 @@ def test_function_evaluator(): } # create info on axis of value function array - axis_names = ["state_index", "insured", "wealth", "human_capital"] + axis_names = ["retired", "insured", "wealth", "human_capital"] space_info = SpaceInfo( axis_names=axis_names, lookup_info=lookup_info, interpolation_info=interpolation_info, - indexer_infos=indexer_infos, ) # create the evaluator @@ -141,92 +126,13 @@ def test_function_evaluator(): # test the evaluator out = evaluator( retired=1, - working=1, - insured=0, - wealth=600, - human_capital=1.5, - state_indexer=indexer_array, - vf_arr=vf_arr, - ) - - assert jnp.allclose(out, 1001.5) - - -def test_function_evaluator_longer_indexer(): - """Test get_precalculated_function_evaluator in an extended example. - - - One sparse discrete state variable: retired ('working', 'part retired', retired) - - One sparse discrete choice variable: working (0, 1, 2) - - One dense discrete choice variable: insured ("yes", "no") - - Two dense continuous state variables: - - wealth (linspace(100, 1100, 6)) - - human_capital (linspace(-3, 3, 7)) - - The utility function is wealth + human_capital + c. c takes a different - value for each discrete state choice combination. - - The setup of space_info here is quite long. Usually these inputs will be generated - from a model specification. - - """ - # create a value function array - discrete_part = jnp.arange(10).repeat(6 * 7).reshape((5, 2, 6, 7)) * 100 - cont_func = productmap(lambda x, y: x + y, ["x", "y"]) - cont_part = cont_func(x=jnp.linspace(100, 1100, 6), y=jnp.linspace(-3, 3, 7)) - vf_arr = discrete_part + cont_part - - # create info on discrete variables - lookup_info = { - "retired": [0, 1, 2], - "working": [0, 1, 2], - "insured": [0, 1], - } - - # create an indexer for the sparse discrete part - indexer_infos = [ - IndexerInfo( - axis_names=["retired", "working"], - name="state_indexer", - out_name="state_index", - ), - ] - - indexer_array = jnp.array([[-1, 0, 1], [2, 3, -1], [4, -1, -1]]) - - # create info on continuous grids - interpolation_info = { - "wealth": LinspaceGrid(start=100, stop=1100, n_points=6), - "human_capital": LinspaceGrid(start=-3, stop=3, n_points=7), - } - - # create info on axis of value function array - axis_names = ["state_index", "insured", "wealth", "human_capital"] - - space_info = SpaceInfo( - axis_names=axis_names, - lookup_info=lookup_info, - interpolation_info=interpolation_info, - indexer_infos=indexer_infos, - ) - - # create the evaluator - evaluator = get_function_representation( - space_info=space_info, - name_of_values_on_grid="vf_arr", - ) - - # test the evaluator - out = evaluator( - retired=0, - working=1, insured=0, wealth=600, human_capital=1.5, - state_indexer=indexer_array, vf_arr=vf_arr, ) - assert jnp.allclose(out, 601.5) + assert jnp.allclose(out, 801.5) def test_get_label_translator_with_args(): @@ -255,10 +161,10 @@ def test_get_label_translator_wrong_kwarg(): def test_get_lookup_function(): - indexer = jnp.arange(6).reshape(3, 2) - func = _get_lookup_function(array_name="my_indexer", axis_names=["a", "b"]) + array = jnp.arange(6).reshape(3, 2) + func = _get_lookup_function(array_name="my_array", axis_names=["a", "b"]) - pure_lookup_func = partial(func, my_indexer=indexer) + pure_lookup_func = partial(func, my_array=array) calculated = pure_lookup_func(a=2, b=0) assert calculated == 4 @@ -309,7 +215,6 @@ def test_get_function_evaluator_illustrative(): interpolation_info={ "a": a_grid, }, - indexer_infos=[], ) values = jnp.pi * a_grid.to_jax() + 2 @@ -383,7 +288,6 @@ def test_fail_if_interpolation_axes_are_not_last_illustrative(): "c": None, }, lookup_info=None, - indexer_infos=None, ) _fail_if_interpolation_axes_are_not_last(space_info) # does not fail @@ -399,7 +303,6 @@ def test_fail_if_interpolation_axes_are_not_last_illustrative(): "d": None, }, lookup_info=None, - indexer_infos=None, ) _fail_if_interpolation_axes_are_not_last(space_info) # does not fail @@ -415,7 +318,6 @@ def test_fail_if_interpolation_axes_are_not_last_illustrative(): "d": None, }, lookup_info=None, - indexer_infos=None, ) with pytest.raises(ValueError, match="Interpolation axes need to be the last"): From 1f76cb36f52fb2a68b9f7e02f0b1e52dae26ef00 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 10 Feb 2025 11:14:51 +0100 Subject: [PATCH 04/16] Use SolutionSpace as state-choice-space during the model solution --- src/lcm/entry_point.py | 4 +- src/lcm/interfaces.py | 19 ++++++++ src/lcm/model_functions.py | 11 +++-- src/lcm/solution/solve_brute.py | 7 +-- src/lcm/solution/state_space.py | 74 +++++++++++------------------- tests/simulation/test_simulate.py | 4 +- tests/solution/test_solve_brute.py | 14 +++--- tests/solution/test_state_space.py | 29 ++++++++++-- tests/test_entry_point.py | 16 +++---- tests/test_model_functions.py | 4 +- 10 files changed, 102 insertions(+), 80 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index b5286f6c..e2506157 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -88,14 +88,14 @@ def get_lcm_function( # call state space creation function, append trivial items to their lists # ============================================================================== - sc_space, space_info = create_state_choice_space( + sc_space = create_state_choice_space( model=_mod, is_last_period=is_last_period, ) state_choice_spaces.append(sc_space) choice_segments.append(None) - space_infos.append(space_info) + space_infos.append(sc_space.state_space_info) # ================================================================================== # Shift space info (in period t we require the space info of period t+1) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 9b32c5dd..72937e92 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -8,6 +8,25 @@ from lcm.typing import ParamsDict, ShockType +@dataclass(frozen=True) +class SolutionSpace: + """The state-choice space of a model used during the solution process. + + The state-choice space is the Cartesian product of the state variables and the + choice variables, stored here as a dictionary of one-dimensional arrays. The + continuous choice variables are handled outside of this class. + + Attributes: + vars: Dictionary containing one dimensional grids of all variables, except for + continuous choice variables. + state_space_info: Information on the state variables. + + """ + + vars: dict[str, Array] + state_space_info: "SpaceInfo" + + @dataclass(frozen=True) class Space: """Everything needed to evaluate a function on a space (e.g. state space). diff --git a/src/lcm/model_functions.py b/src/lcm/model_functions.py index 2f9809ce..f400a1c1 100644 --- a/src/lcm/model_functions.py +++ b/src/lcm/model_functions.py @@ -11,16 +11,17 @@ all_as_kwargs, get_union_of_arguments, ) -from lcm.interfaces import InternalModel +from lcm.interfaces import InternalModel, SpaceInfo from lcm.next_state import get_next_state_function def get_utility_and_feasibility_function( model: InternalModel, - space_info, - name_of_values_on_grid, - period, - is_last_period, + space_info: SpaceInfo, + name_of_values_on_grid: str, + period: int, + *, + is_last_period: bool, ): # ================================================================================== # Gather information on the model variables diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 115501e8..f239750e 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -1,6 +1,7 @@ import jax from lcm.dispatchers import spacemap +from lcm.interfaces import SolutionSpace def solve( @@ -73,7 +74,7 @@ def solve( def solve_continuous_problem( - state_choice_space, + state_choice_space: SolutionSpace, compute_ccv, continuous_choice_grids, vf_arr, @@ -103,12 +104,12 @@ def solve_continuous_problem( """ _gridmapped = spacemap( func=compute_ccv, - dense_vars=list(state_choice_space.dense_vars), + dense_vars=list(state_choice_space.vars), ) gridmapped = jax.jit(_gridmapped) return gridmapped( - **state_choice_space.dense_vars, + **state_choice_space.vars, **continuous_choice_grids, vf_arr=vf_arr, params=params, diff --git a/src/lcm/solution/state_space.py b/src/lcm/solution/state_space.py index 8caf6a05..d94d2833 100644 --- a/src/lcm/solution/state_space.py +++ b/src/lcm/solution/state_space.py @@ -1,69 +1,51 @@ """Create a state space for a given model.""" -from lcm.interfaces import InternalModel, Space, SpaceInfo +from lcm.interfaces import InternalModel, SolutionSpace, SpaceInfo def create_state_choice_space( - model: InternalModel, *, is_last_period: bool -) -> tuple[Space, SpaceInfo]: - """Create a state-choice-space for the model. + model: InternalModel, + *, + is_last_period: bool, +) -> SolutionSpace: + """Create a state-choice-space for the model solution. A state-choice-space is a compressed representation of all feasible states and the feasible discrete choices within that state. Args: - model (Model): A processed model. - is_last_period (bool): Whether the function is created for the last period. + model: A processed model. + is_last_period: Whether the function is created for the last period. Returns: - Space: Space object containing the sparse and dense variables. This can be used - to execute a function on an entire space. - SpaceInfo: A SpaceInfo object that contains all information needed to work with - the output of a function evaluated on the space. + SolutionSpace: An object containing the variable values of all variables in the + state-choice-space, the grid specifications for the state variables, and the + names of the state variables. Continuous choice variables are not included. """ - # ================================================================================== - # preparations - # ================================================================================== vi = model.variable_info if is_last_period: vi = vi.query("~is_auxiliary") - # ================================================================================== - # create state choice space - # ================================================================================== - _value_grid = _create_value_grid( - grids=model.grids, - subset=vi.query("~(is_choice & is_continuous)").index.tolist(), - ) - - state_choice_space = Space( - sparse_vars={}, - dense_vars=_value_grid, - ) - - # ================================================================================== - # create state space info - # ================================================================================== - # axis_names - axis_names = vi.query("is_state").index.tolist() + discrete_states_names = vi.query("is_discrete & is_state").index.tolist() + continuous_states_names = vi.query("is_continuous & is_state").index.tolist() - # lookup_info - _discrete_states = set(vi.query("is_discrete & is_state").index.tolist()) - lookup_info = {k: v for k, v in model.gridspecs.items() if k in _discrete_states} + discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} + continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} - # interpolation info - _cont_states = set(vi.query("is_continuous & is_state").index.tolist()) - interpolation_info = {k: v for k, v in model.gridspecs.items() if k in _cont_states} + # Create a dictionary with all state and choice variables and their feasible values, + # except for continuous choice variables, since they are treated differently. + space_grids = { + sn: model.grids[sn] for sn in vi.query("is_state | is_discrete").index.tolist() + } - space_info = SpaceInfo( - axis_names=axis_names, - lookup_info=lookup_info, # type: ignore[arg-type] - interpolation_info=interpolation_info, # type: ignore[arg-type] + state_space_info = SpaceInfo( + axis_names=discrete_states_names + continuous_states_names, + lookup_info=discrete_states, # type: ignore[arg-type] + interpolation_info=continuous_states, # type: ignore[arg-type] ) - return state_choice_space, space_info - - -def _create_value_grid(grids, subset): - return {name: grid for name, grid in grids.items() if name in subset} + return SolutionSpace( + vars=space_grids, + state_space_info=state_space_info, + ) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 0d0ed900..fb362f94 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -40,7 +40,7 @@ def simulate_inputs(): model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=1) model = process_model(model_config) - _, space_info = create_state_choice_space( + sc_space = create_state_choice_space( model=model, is_last_period=False, ) @@ -49,7 +49,7 @@ def simulate_inputs(): for period in range(model.n_periods): u_and_f = get_utility_and_feasibility_function( model=model, - space_info=space_info, + space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=period, is_last_period=True, diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 783b735b..f04ada8d 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -3,7 +3,7 @@ from numpy.testing import assert_array_almost_equal as aaae from lcm.entry_point import create_compute_conditional_continuation_value -from lcm.interfaces import Space +from lcm.interfaces import SolutionSpace from lcm.logging import get_logger from lcm.ndimage import map_coordinates from lcm.solution.solve_brute import solve, solve_continuous_problem @@ -25,9 +25,8 @@ def test_solve_brute(): # ================================================================================== # create the list of state_choice_spaces # ================================================================================== - _scs = Space( - sparse_vars={}, - dense_vars={ + _scs = SolutionSpace( + vars={ # pick [0, 1] such that no label translation is needed # lazy is like a type, it influences utility but is not affected by choices "lazy": jnp.array([0, 1]), @@ -35,6 +34,7 @@ def test_solve_brute(): # pick [0, 1, 2] such that no coordinate mapping is needed "wealth": jnp.array([0.0, 1.0, 2.0]), }, + state_space_info=None, ) state_choice_spaces = [_scs] * 2 @@ -111,13 +111,13 @@ def calculate_emax(values, params): # noqa: ARG001 def test_solve_continuous_problem_no_vf_arr(): - state_choice_space = Space( - dense_vars={ + state_choice_space = SolutionSpace( + vars={ "a": jnp.array([0, 1.0]), "b": jnp.array([2, 3.0]), "c": jnp.array([4, 5, 6]), }, - sparse_vars={}, + state_space_info=None, ) def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py index a9d0981f..9209e1f8 100644 --- a/tests/solution/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -1,4 +1,7 @@ +import jax.numpy as jnp + from lcm.input_processing import process_model +from lcm.interfaces import SolutionSpace, SpaceInfo from lcm.solution.state_space import ( create_state_choice_space, ) @@ -6,10 +9,26 @@ def test_create_state_choice_space(): - _model = process_model( - get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3), - ) - create_state_choice_space( - model=_model, + model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) + internal_model = process_model(model) + + state_choice_space = create_state_choice_space( + model=internal_model, is_last_period=False, ) + + assert isinstance(state_choice_space, SolutionSpace) + assert isinstance(state_choice_space.state_space_info, SpaceInfo) + + assert jnp.array_equal( + state_choice_space.vars["retirement"], model.choices["retirement"].to_jax() + ) + assert jnp.array_equal( + state_choice_space.vars["wealth"], model.states["wealth"].to_jax() + ) + + state_space_info = state_choice_space.state_space_info + + assert state_space_info.axis_names == ["wealth"] + assert state_space_info.lookup_info == {} + assert state_space_info.interpolation_info == model.states diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 7c478b26..cded434e 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -182,14 +182,14 @@ def test_create_compute_conditional_continuation_value(): }, } - _, space_info = create_state_choice_space( + sc_space = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - space_info=space_info, + space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -228,14 +228,14 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): }, } - _, space_info = create_state_choice_space( + sc_space = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - space_info=space_info, + space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -279,14 +279,14 @@ def test_create_compute_conditional_continuation_policy(): }, } - _, space_info = create_state_choice_space( + sc_space = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - space_info=space_info, + space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -326,14 +326,14 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): }, } - _, space_info = create_state_choice_space( + sc_space = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - space_info=space_info, + space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index a20d74d2..e834319d 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -30,14 +30,14 @@ def test_get_utility_and_feasibility_function(): }, } - _, space_info = create_state_choice_space( + sc_space = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - space_info=space_info, + space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, From 0515af87b735f1416cd94e001bfe5f7c5ed0c520 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 10 Feb 2025 11:20:02 +0100 Subject: [PATCH 05/16] Update return-type of determine_dense_discrete_choice_axes --- src/lcm/discrete_problem.py | 22 ++++++---------------- src/lcm/simulation/simulate.py | 11 +++-------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index 019a7a47..30c2cf35 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -71,7 +71,7 @@ def get_solve_discrete_problem( def _solve_discrete_problem_no_shocks( cc_values: Array, - choice_axes: tuple[int, ...] | None, + choice_axes: tuple[int, ...], params: ParamsDict, # noqa: ARG001 ) -> Array: """Reduce conditional continuation values over discrete choices. @@ -90,11 +90,7 @@ def _solve_discrete_problem_no_shocks( if choice_segments is not None. """ - out = cc_values - if choice_axes is not None: - out = out.max(axis=choice_axes) - - return out + return cc_values.max(axis=choice_axes) # ====================================================================================== @@ -139,24 +135,18 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params): def _determine_dense_discrete_choice_axes( variable_info: pd.DataFrame, -) -> tuple[int, ...] | None: +) -> tuple[int, ...]: """Get axes of a state-choice-space that correspond to discrete choices. Args: variable_info: DataFrame with information about the variables. Returns: - tuple[int, ...] | None: A tuple of indices representing the axes' positions in - the value function that correspond to discrete choices. Returns None if - there are no discrete choice axes. + tuple[int, ...]: A tuple of indices representing the axes' positions in + the value function that correspond to discrete choices. """ # List of dense variables excluding continuous choice variables. axes = variable_info.query("is_state | is_discrete").index.tolist() - choice_vars = set(variable_info.query("is_choice").index.tolist()) - - choice_indices = tuple(i for i, ax in enumerate(axes) if ax in choice_vars) - - # Return None if there are no discrete choice axes, otherwise return the indices. - return choice_indices if choice_indices else None + return tuple(i for i, ax in enumerate(axes) if ax in choice_vars) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index f20c7cb8..0454df04 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -530,10 +530,7 @@ def _calculate_discrete_argmax(values, choice_axes, choice_segments): # noqa: A # Determine argmax and max over dense choices # ============================================================================== - if choice_axes is not None: - dense_argmax, _max = argmax(_max, axis=choice_axes) - else: - dense_argmax = None + dense_argmax, _max = argmax(_max, axis=choice_axes) # Determine argmax and max over sparse choices # ============================================================================== @@ -584,8 +581,6 @@ def determine_discrete_dense_choice_axes(variable_info): choice_vars = set(variable_info.query("is_choice").index.tolist()) # We add 1 because the first dimension corresponds to the sparse state variables - choice_indices = [ + return tuple( i + 1 for i, ax in enumerate(discrete_dense_choice_vars) if ax in choice_vars - ] - - return None if not choice_indices else tuple(choice_indices) + ) From 11defbbeae11e0d36d3da0b20d9a2dd167c2ef08 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 16:04:17 +0100 Subject: [PATCH 06/16] Rename dense -> product; sparse -> combination --- src/lcm/argmax.py | 4 +- src/lcm/discrete_problem.py | 12 ++--- src/lcm/dispatchers.py | 78 +++++++++++++++++-------------- src/lcm/interfaces.py | 18 +++---- src/lcm/simulation/simulate.py | 34 +++++++------- src/lcm/solution/solve_brute.py | 4 +- tests/simulation/test_simulate.py | 6 +-- tests/test_discrete_problem.py | 6 +-- tests/test_dispatchers.py | 16 +++---- 9 files changed, 91 insertions(+), 87 deletions(-) diff --git a/src/lcm/argmax.py b/src/lcm/argmax.py index 60f20aea..601e34e5 100644 --- a/src/lcm/argmax.py +++ b/src/lcm/argmax.py @@ -73,7 +73,7 @@ def _move_axes_to_back(a: Array, axes: tuple[int, ...]) -> Array: axes (tuple): Axes to move to the back. Returns: - jax.numpy.ndarray: Array a with shifted axes. + jax.Array: Array a with shifted axes. """ front_axes = sorted(set(range(a.ndim)) - set(axes)) @@ -88,7 +88,7 @@ def _flatten_last_n_axes(a: Array, n: int) -> Array: n (int): Number of axes to flatten. Returns: - jax.numpy.ndarray: Array a with flattened last n axes. + jax.Array: Array a with flattened last n axes. """ return a.reshape(*a.shape[:-n], -1) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index 30c2cf35..f9408475 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -52,7 +52,7 @@ def get_solve_discrete_problem( if is_last_period: variable_info = variable_info.query("~is_auxiliary") - choice_axes = _determine_dense_discrete_choice_axes(variable_info) + choice_axes = _determine_discrete_choice_axes(variable_info) if random_utility_shock_type == ShockType.NONE: func = _solve_discrete_problem_no_shocks @@ -104,10 +104,10 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params): """Aggregate conditional continuation values over discrete choices. Args: - values (jax.numpy.ndarray): Multidimensional jax array with conditional + values (jax.Array): Multidimensional jax array with conditional continuation values. choice_axes (int or tuple): Int or tuple of int, specifying which axes in - values correspond to dense choice variables. + values correspond to the discrete choice variables. choice_segments (dict): Dictionary with the entries "segment_ids" and "num_segments". segment_ids are a 1d integer array that partitions the first dimension of values into choice sets over which we need to aggregate. @@ -115,7 +115,7 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params): params (dict): Params dict that contains the schock_scale if necessary. Returns: - jax.numpy.ndarray: Multidimensional jax array with aggregated continuation + jax.Array: Multidimensional jax array with aggregated continuation values. Has less dimensions than values if choice_axes is not None and is shorter in the first dimension if choice_segments is not None. @@ -133,7 +133,7 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params): # ====================================================================================== -def _determine_dense_discrete_choice_axes( +def _determine_discrete_choice_axes( variable_info: pd.DataFrame, ) -> tuple[int, ...]: """Get axes of a state-choice-space that correspond to discrete choices. @@ -146,7 +146,7 @@ def _determine_dense_discrete_choice_axes( the value function that correspond to discrete choices. """ - # List of dense variables excluding continuous choice variables. + # List of all model variables excluding the continuous choice variables. axes = variable_info.query("is_state | is_discrete").index.tolist() choice_vars = set(variable_info.query("is_choice").index.tolist()) return tuple(i for i, ax in enumerate(axes) if ax in choice_vars) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index a1e04db2..b0fc82bf 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -11,55 +11,59 @@ def spacemap( func: F, - dense_vars: list[str], - sparse_vars: list[str] | None = None, + product_vars: list[str], + combination_vars: list[str] | None = None, ) -> F: - """Apply vmap such that func is evaluated on a space of dense and sparse variables. - - This is achieved by applying _base_productmap for all dense variables and vmap_1d - for the sparse variables. + """Apply vmap such that func can be evaluated on product and combination variables. + + Product variables are used to create a Cartesian product of possible values. I.e., + for each product variable, we create a new leading dimension in the output object, + with the size of the dimension being the number of possible values in the grid. The + i-th entries of the combination variables, correspond to one valid combination. For + the combination variables, a single dimension is thus added to the output object, + with the size of the dimension being the number of possible combinations. This means + that all combination variables must have the same size (e.g., in the simulation the + states act as combination variables, and their size equals the number of + simulations). spacemap preserves the function signature and allows the function to be called with keyword arguments. Args: func: The function to be dispatched. - dense_vars: Names of the dense variables, i.e. those that are stored as arrays - of possible values in the grid. - sparse_vars: Names of the sparse variables, i.e. those that are stored as arrays - of possible combinations of variables in the grid. - put_dense_first: Whether the dense or sparse dimensions should come first in the - output of the dispatched function. - + product_vars: Names of the product variables, i.e. those that are stored as + arrays of possible values in the grid, over which we create a Cartesian + product. + combination_vars: Names of the combination variables, i.e. those that are + stored as arrays of possible combinations. Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.numpy.ndarray or pytree of arrays. If `func` - returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1 - jax.numpy.ndarray with k + 1 dimensions, where k is the length of `dense_vars` - and the additional dimension corresponds to the `sparse_vars`. The order of - the dimensions is determined by the order of `dense_vars` as well as the - `put_dense_first` argument. If the output of `func` is a jax pytree, the - usual jax behavior applies, i.e. the leading dimensions of all arrays in the - pytree are as described above but there might be additional dimensions. + dimension) that returns a jax.Array or pytree of arrays. If `func` returns a + scalar, the dispatched function returns a jax.Array with k + 1 dimensions, where + k is the length of `product_vars` and the additional dimension corresponds to + the `combination_vars`. The order of the dimensions is determined by the order + of `product_vars`. If the output of `func` is a jax pytree, the usual jax + behavior applies, i.e. the leading dimensions of all arrays in the pytree are as + described above but there might be additional dimensions. """ # Check inputs and prepare function # ================================================================================== - duplicates = {v for v in dense_vars if dense_vars.count(v) > 1} + duplicates = {v for v in product_vars if product_vars.count(v) > 1} if duplicates: raise ValueError( - f"Same argument provided more than once in dense variables: {duplicates}", + f"Same argument provided more than once in product variables: {duplicates}", ) - if sparse_vars: - overlap = set(dense_vars).intersection(sparse_vars) + if combination_vars: + overlap = set(product_vars).intersection(combination_vars) if overlap: raise ValueError( f"Dense and sparse variables must be disjoint. Overlap: {overlap}", ) - duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1} + duplicates = {v for v in combination_vars if combination_vars.count(v) > 1} if duplicates: raise ValueError( "Same argument provided more than once in sparse variables: " @@ -69,13 +73,15 @@ def spacemap( # jax.vmap cannot deal with keyword-only arguments func = allow_args(func) - # Apply vmap_1d for sparse and _base_productmap for dense variables + # Apply vmap_1d for sparse and _base_productmap for product variables # ================================================================================== - if not sparse_vars: - vmapped = _base_productmap(func, dense_vars) + if not combination_vars: + vmapped = _base_productmap(func, product_vars) else: - vmapped = _base_productmap(func, dense_vars) - vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args") + vmapped = _base_productmap(func, product_vars) + vmapped = vmap_1d( + vmapped, variables=combination_vars, callable_with="only_args" + ) # This raises a mypy error but is perfectly fine to do. See # https://github.com/python/mypy/issues/12472 @@ -106,9 +112,9 @@ def vmap_1d( Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.numpy.ndarray or pytree of arrays. If `func` - returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1 - jax.numpy.ndarray with 1 dimension and length k, where k is the length of one of + dimension) that returns a jax.Array or pytree of arrays. If `func` + returns a scalar, the dispatched function returns a jax.Array with 1 + jax.Array with 1 dimension and length k, where k is the length of one of the mapped inputs in `variables`. The order of the dimensions is determined by the order of `variables` which can be different to the order of `funcs` arguments. If the output of `func` is a jax pytree, the usual jax behavior @@ -169,8 +175,8 @@ def productmap(func: F, variables: list[str]) -> F: Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.numpy.ndarray or pytree of arrays. If `func` - returns a scalar, the dispatched function returns a jax.numpy.ndarray with k + dimension) that returns a jax.Array or pytree of arrays. If `func` + returns a scalar, the dispatched function returns a jax.Array with k dimensions, where k is the length of `variables`. The order of the dimensions is determined by the order of `variables` which can be different to the order of `funcs` arguments. If the output of `func` is a jax pytree, the usual jax diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 72937e92..f63c9c12 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -28,20 +28,20 @@ class SolutionSpace: @dataclass(frozen=True) -class Space: - """Everything needed to evaluate a function on a space (e.g. state space). +class SimulationSpace: + """The state-choice space of a model used during the simulation process. + + The state-choice space is the product of the state variables with the Cartesian + product of the choice variables. Attributes: - sparse_vars (dict): Dictionary containing the names of sparse variables as keys - and arrays with values of those variables as values. Together, the arrays - define all feasible combinations of sparse variables. - dense_vars (dict): Dictionary containing one dimensional grids of - dense variables. + states: Dictionary containing the values of the state variables. + choices: Dictionary containing the values of the choice variables. """ - sparse_vars: dict[str, Array] - dense_vars: dict[str, Array] + states: dict[str, Array] + choices: dict[str, Array] @dataclass(frozen=True) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 0454df04..f738650e 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -9,7 +9,7 @@ from lcm.argmax import argmax from lcm.dispatchers import spacemap, vmap_1d -from lcm.interfaces import InternalModel, Space +from lcm.interfaces import InternalModel, SimulationSpace def simulate( @@ -108,9 +108,7 @@ def simulate( # Compute objects dependent on data-state-choice-space # ============================================================================== - dense_vars_grid_shape = tuple( - len(grid) for grid in data_scs.dense_vars.values() - ) + dense_vars_grid_shape = tuple(len(grid) for grid in data_scs.choices.values()) cont_choice_grid_shape = tuple( len(grid) for grid in continuous_choice_grids[period].values() ) @@ -152,7 +150,7 @@ def simulate( # ============================================================================== dense_choices = retrieve_non_sparse_choices( indices=dense_argmax, - grids=data_scs.dense_vars, + grids=data_scs.choices, grid_shape=dense_vars_grid_shape, ) @@ -210,7 +208,7 @@ def simulate( def solve_continuous_problem( - data_scs, + data_scs: SimulationSpace, compute_ccv, continuous_choice_grids, vf_arr, @@ -229,7 +227,7 @@ def solve_continuous_problem( - params continuous_choice_grids (list): List of dicts with 1d grids for continuous choice variables. - vf_arr (jax.numpy.ndarray): Value function array. + vf_arr (jax.Array): Value function array. params (dict): Dict of model parameters. Returns: @@ -243,15 +241,15 @@ def solve_continuous_problem( """ _gridmapped = spacemap( func=compute_ccv, - dense_vars=list(data_scs.dense_vars), - sparse_vars=list(data_scs.sparse_vars), + product_vars=list(data_scs.choices), + combination_vars=list(data_scs.states), ) gridmapped = jax.jit(_gridmapped) return gridmapped( - **data_scs.dense_vars, + **data_scs.choices, + **data_scs.states, **continuous_choice_grids, - **data_scs.sparse_vars, vf_arr=vf_arr, params=params, ) @@ -390,13 +388,13 @@ def filter_ccv_policy( """Select optimal continuous choice index given optimal discrete choice. Args: - ccv_policy (jax.numpy.ndarray): Index array of optimal continous choices + ccv_policy (jax.Array): Index array of optimal continous choices conditional on discrete choices. dense_argmax (jax.numpy.array): Index array of optimal dense choices. dense_vars_grid_shape (tuple): Shape of the dense variables grid. Returns: - jax.numpy.ndarray: Index array of optimal continuous choices. + jax.Array: Index array of optimal continuous choices. """ if dense_argmax is None: @@ -487,9 +485,9 @@ def create_data_scs( if name in vi.query("is_choice & is_discrete").index.tolist() } - data_scs = Space( - sparse_vars=states, - dense_vars=dense_choices, + data_scs = SimulationSpace( + states=states, + choices=dense_choices, ) # create choice segments @@ -516,9 +514,9 @@ def get_discrete_policy_calculator(variable_info): Returns: callable: Function that calculates the argmax of the conditional continuation values. The function depends on: - - values (jax.numpy.ndarray): Multidimensional jax array with conditional + - values (jax.Array): Multidimensional jax array with conditional continuation values. - - choice_segments (jax.numpy.ndarray): Jax array with the indices of the + - choice_segments (jax.Array): Jax array with the indices of the choice segments that indicate which sparse choice variables belong to one state. diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index f239750e..a148e344 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -93,7 +93,7 @@ def solve_continuous_problem( - params continuous_choice_grids (list): List of dicts with 1d grids for continuous choice variables. - vf_arr (jax.numpy.ndarray): Value function array. + vf_arr (jax.Array): Value function array. params (dict): Dict of model parameters. Returns: @@ -104,7 +104,7 @@ def solve_continuous_problem( """ _gridmapped = spacemap( func=compute_ccv, - dense_vars=list(state_choice_space.vars), + product_vars=list(state_choice_space.vars), ) gridmapped = jax.jit(_gridmapped) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index fb362f94..e07e09b3 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -414,9 +414,9 @@ def test_create_data_state_choice_space(): }, model=model, ) - assert_array_equal(got_space.dense_vars["retirement"], jnp.array([0, 1])) - assert_array_equal(got_space.sparse_vars["wealth"], jnp.array([10.0, 20.0])) - assert_array_equal(got_space.sparse_vars["lagged_retirement"], jnp.array([0, 1])) + assert_array_equal(got_space.choices["retirement"], jnp.array([0, 1])) + assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) + assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) assert got_segment_info is None diff --git a/tests/test_discrete_problem.py b/tests/test_discrete_problem.py index b942dba9..3657fa34 100644 --- a/tests/test_discrete_problem.py +++ b/tests/test_discrete_problem.py @@ -5,7 +5,7 @@ from lcm.discrete_problem import ( _calculate_emax_extreme_value_shocks, - _determine_dense_discrete_choice_axes, + _determine_discrete_choice_axes, _solve_discrete_problem_no_shocks, get_solve_discrete_problem, ) @@ -119,7 +119,7 @@ def test_determine_discrete_choice_axes_illustrative_one_var(): }, ) - assert _determine_dense_discrete_choice_axes(variable_info) == (1,) + assert _determine_discrete_choice_axes(variable_info) == (1,) @pytest.mark.illustrative @@ -133,4 +133,4 @@ def test_determine_discrete_choice_axes_illustrative_three_var(): }, ) - assert _determine_dense_discrete_choice_axes(variable_info) == (1, 2, 3) + assert _determine_discrete_choice_axes(variable_info) == (1, 2, 3) diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 291f9601..3cea774f 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -229,20 +229,20 @@ def test_spacemap_all_arguments_mapped( setup_spacemap, expected_spacemap, ): - dense_vars, sparse_vars = setup_spacemap + product_vars, combination_vars = setup_spacemap decorated = spacemap( g, - list(dense_vars), - list(sparse_vars), + list(product_vars), + list(combination_vars), ) - calculated = decorated(**dense_vars, **sparse_vars) + calculated = decorated(**product_vars, **combination_vars) aaae(calculated, jnp.transpose(expected_spacemap, axes=(2, 0, 1))) @pytest.mark.parametrize( - ("error_msg", "dense_vars", "sparse_vars"), + ("error_msg", "product_vars", "combination_vars"), [ ( "Dense and sparse variables must be disjoint. Overlap: {'a'}", @@ -250,15 +250,15 @@ def test_spacemap_all_arguments_mapped( ["a", "c", "d"], ), ( - "Same argument provided more than once in dense variables: {'a'}", + "Same argument provided more than once in product variables: {'a'}", ["a", "a", "b"], ["c", "d"], ), ], ) -def test_spacemap_arguments_overlap(error_msg, dense_vars, sparse_vars): +def test_spacemap_arguments_overlap(error_msg, product_vars, combination_vars): with pytest.raises(ValueError, match=error_msg): - spacemap(g, dense_vars=dense_vars, sparse_vars=sparse_vars) + spacemap(g, product_vars=product_vars, combination_vars=combination_vars) # ====================================================================================== From eca7a59b5de6d573483b334d739908d55791f2d5 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 16:16:11 +0100 Subject: [PATCH 07/16] Remove mentions of dense variables from codebase --- src/lcm/simulation/simulate.py | 64 +++++++++++++++---------------- tests/simulation/test_simulate.py | 20 +++++----- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index f738650e..c5b56043 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -108,7 +108,7 @@ def simulate( # Compute objects dependent on data-state-choice-space # ============================================================================== - dense_vars_grid_shape = tuple(len(grid) for grid in data_scs.choices.values()) + vars_grid_shape = tuple(len(grid) for grid in data_scs.choices.values()) cont_choice_grid_shape = tuple( len(grid) for grid in continuous_choice_grids[period].values() ) @@ -130,31 +130,31 @@ def simulate( # Get optimal discrete choice given the optimal conditional continuous choices # ============================================================================== - dense_argmax, sparse_argmax, value = discrete_policy_calculator(ccv) + discrete_argmax, sparse_argmax, value = discrete_policy_calculator(ccv) # Select optimal continuous choice corresponding to optimal discrete choice # ------------------------------------------------------------------------------ # The conditional continuous choice argmax is computed for each discrete choice # in the data-state-choice-space. Here we select the the optimal continuous - # choice corresponding to the optimal discrete choice (dense and sparse). + # choice corresponding to the optimal discrete choice. # ============================================================================== cont_choice_argmax = filter_ccv_policy( ccv_policy=ccv_policy, - dense_argmax=dense_argmax, - dense_vars_grid_shape=dense_vars_grid_shape, + discrete_argmax=discrete_argmax, + vars_grid_shape=vars_grid_shape, ) if sparse_argmax is not None: cont_choice_argmax = cont_choice_argmax[sparse_argmax] # Convert optimal choice indices to actual choice values # ============================================================================== - dense_choices = retrieve_non_sparse_choices( - indices=dense_argmax, + choices = retrieve_choices( + indices=discrete_argmax, grids=data_scs.choices, - grid_shape=dense_vars_grid_shape, + grid_shape=vars_grid_shape, ) - cont_choices = retrieve_non_sparse_choices( + cont_choices = retrieve_choices( indices=cont_choice_argmax, grids=continuous_choice_grids[period], grid_shape=cont_choice_grid_shape, @@ -162,7 +162,7 @@ def simulate( # Store results # ============================================================================== - choices = {**dense_choices, **cont_choices} + choices = {**choices, **cont_choices} _simulation_results.append( { @@ -217,7 +217,7 @@ def solve_continuous_problem( """Solve the agent's continuous choices problem problem. Args: - data_scs (Space): Class with entries dense_vars and sparse_vars. + data_scs: Class with entries choices and states. compute_ccv (callable): Function that returns the conditional continuation values for a given combination of states and discrete choices. The function depends on: @@ -379,28 +379,28 @@ def _generate_simulation_keys(key, ids): # ====================================================================================== -@partial(vmap_1d, variables=["ccv_policy", "dense_argmax"]) +@partial(vmap_1d, variables=["ccv_policy", "discrete_argmax"]) def filter_ccv_policy( ccv_policy, - dense_argmax, - dense_vars_grid_shape, + discrete_argmax, + vars_grid_shape, ): """Select optimal continuous choice index given optimal discrete choice. Args: ccv_policy (jax.Array): Index array of optimal continous choices conditional on discrete choices. - dense_argmax (jax.numpy.array): Index array of optimal dense choices. - dense_vars_grid_shape (tuple): Shape of the dense variables grid. + discrete_argmax (jax.numpy.array): Index array of optimal discrete choices. + vars_grid_shape (tuple): Shape of the variables grid. Returns: jax.Array: Index array of optimal continuous choices. """ - if dense_argmax is None: + if discrete_argmax is None: out = ccv_policy else: - indices = jnp.unravel_index(dense_argmax, shape=dense_vars_grid_shape) + indices = jnp.unravel_index(discrete_argmax, shape=vars_grid_shape) out = ccv_policy[indices] return out @@ -410,8 +410,8 @@ def filter_ccv_policy( # ====================================================================================== -def retrieve_non_sparse_choices(indices, grids, grid_shape): - """Retrieve dense or continuous choices given indices. +def retrieve_choices(indices, grids, grid_shape): + """Retrieve choices given indices. Args: indices (jnp.numpy.ndarray or None): General indices. Represents the index of @@ -477,9 +477,9 @@ def create_data_scs( f"Provided variables that are not states: {too_many}", ) - # get sparse and dense choices + # get choices # ================================================================================== - dense_choices = { + choices = { name: grid for name, grid in model.grids.items() if name in vi.query("is_choice & is_discrete").index.tolist() @@ -487,7 +487,7 @@ def create_data_scs( data_scs = SimulationSpace( states=states, - choices=dense_choices, + choices=choices, ) # create choice segments @@ -521,20 +521,20 @@ def get_discrete_policy_calculator(variable_info): one state. """ - choice_axes = determine_discrete_dense_choice_axes(variable_info) + choice_axes = determine_discrete_choice_axes(variable_info) def _calculate_discrete_argmax(values, choice_axes, choice_segments): # noqa: ARG001 _max = values - # Determine argmax and max over dense choices + # Determine argmax and max over choices # ============================================================================== - dense_argmax, _max = argmax(_max, axis=choice_axes) + discrete_argmax, _max = argmax(_max, axis=choice_axes) # Determine argmax and max over sparse choices # ============================================================================== sparse_argmax = None - return dense_argmax, sparse_argmax, _max + return discrete_argmax, sparse_argmax, _max return partial(_calculate_discrete_argmax, choice_axes=choice_axes) @@ -561,18 +561,18 @@ def dict_product(d): return dict(zip(d.keys(), list(stacked.T), strict=True)), len(stacked) -def determine_discrete_dense_choice_axes(variable_info): - """Determine which axes correspond to discrete and dense choices. +def determine_discrete_choice_axes(variable_info): + """Determine which axes correspond to discrete choices. Args: variable_info (pd.DataFrame): DataFrame with information about the variables. Returns: tuple: Tuple of ints, specifying which axes in a value function correspond to - discrete and dense choices. + discrete choices. """ - discrete_dense_choice_vars = variable_info.query( + discrete_choice_vars = variable_info.query( "is_choice & is_discrete", ).index.tolist() @@ -580,5 +580,5 @@ def determine_discrete_dense_choice_axes(variable_info): # We add 1 because the first dimension corresponds to the sparse state variables return tuple( - i + 1 for i, ax in enumerate(discrete_dense_choice_vars) if ax in choice_vars + i + 1 for i, ax in enumerate(discrete_choice_vars) if ax in choice_vars ) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index e07e09b3..3ca3eeca 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -18,10 +18,10 @@ _generate_simulation_keys, _process_simulated_data, create_data_scs, - determine_discrete_dense_choice_axes, + determine_discrete_choice_axes, dict_product, filter_ccv_policy, - retrieve_non_sparse_choices, + retrieve_choices, simulate, ) from lcm.solution.state_space import create_state_choice_space @@ -369,7 +369,7 @@ def test_process_simulated_data(): def test_retrieve_non_sparse_choices(): - got = retrieve_non_sparse_choices( + got = retrieve_choices( indices=jnp.array([0, 3, 7]), grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, grid_shape=(5, 6), @@ -379,7 +379,7 @@ def test_retrieve_non_sparse_choices(): def test_retrieve_non_sparse_choices_no_indices(): - got = retrieve_non_sparse_choices( + got = retrieve_choices( indices=None, grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, grid_shape=(5, 6), @@ -394,12 +394,12 @@ def test_filter_ccv_policy(): [1, 0], ], ) - dense_argmax = jnp.array([0, 1]) - dense_vars_grid_shape = (2,) + argmax = jnp.array([0, 1]) + vars_grid_shape = (2,) got = filter_ccv_policy( ccv_policy=ccc_policy, - dense_argmax=dense_argmax, - dense_vars_grid_shape=dense_vars_grid_shape, + discrete_argmax=argmax, + vars_grid_shape=vars_grid_shape, ) assert jnp.all(got == jnp.array([0, 0])) @@ -429,7 +429,7 @@ def test_dict_product(): assert_array_equal(got_dict[key], val) -def test_determine_discrete_dense_choice_axes(): +def test_determine_discrete_choice_axes(): variable_info = pd.DataFrame( { "is_state": [True, True, False, True, False, False], @@ -438,5 +438,5 @@ def test_determine_discrete_dense_choice_axes(): "is_continuous": [False, True, False, False, False, True], }, ) - got = determine_discrete_dense_choice_axes(variable_info) + got = determine_discrete_choice_axes(variable_info) assert got == (1, 2, 3) From df1288280059b15d1d66fe6115f4c0a93459e025 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 16:24:27 +0100 Subject: [PATCH 08/16] Remove mentions of sparse variables from codebase --- src/lcm/dispatchers.py | 7 +++-- src/lcm/simulation/simulate.py | 46 ++++++++----------------------- tests/simulation/test_simulate.py | 4 +-- tests/test_dispatchers.py | 6 ++-- 4 files changed, 20 insertions(+), 43 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index b0fc82bf..063543d0 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -60,20 +60,21 @@ def spacemap( overlap = set(product_vars).intersection(combination_vars) if overlap: raise ValueError( - f"Dense and sparse variables must be disjoint. Overlap: {overlap}", + "Product and combination variables must be disjoint. Overlap: " + f"{overlap}", ) duplicates = {v for v in combination_vars if combination_vars.count(v) > 1} if duplicates: raise ValueError( - "Same argument provided more than once in sparse variables: " + "Same argument provided more than once in combination variables: " f"{duplicates}", ) # jax.vmap cannot deal with keyword-only arguments func = allow_args(func) - # Apply vmap_1d for sparse and _base_productmap for product variables + # Apply vmap_1d for combination variables and _base_productmap for product variables # ================================================================================== if not combination_vars: vmapped = _base_productmap(func, product_vars) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index c5b56043..928e99fd 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -81,7 +81,7 @@ def simulate( n_periods = len(vf_arr_list) n_initial_states = len(next(iter(initial_states.values()))) - _discrete_policy_calculator = get_discrete_policy_calculator( + discrete_policy_calculator = get_discrete_policy_calculator( variable_info=model.variable_info, ) @@ -96,12 +96,12 @@ def simulate( for period in range(n_periods): # Create data state choice space # ------------------------------------------------------------------------------ - # Initial states are treated as sparse variables, so that the sparse variables - # in the data-state-choice-space correspond to the feasible product of sparse - # choice variables and initial states. The space has to be created in each - # iteration because the states change over time. + # Initial states are treated as combination variables, so that the combination + # variables in the data-state-choice-space correspond to the feasible product + # of combination variables and initial states. The space has to be created in + # each iteration because the states change over time. # ============================================================================== - data_scs, data_choice_segments = create_data_scs( + data_scs, _ = create_data_scs( states=states, model=model, ) @@ -113,11 +113,6 @@ def simulate( len(grid) for grid in continuous_choice_grids[period].values() ) - discrete_policy_calculator = partial( - _discrete_policy_calculator, - choice_segments=data_choice_segments, - ) - # Compute optimal continuous choice conditional on discrete choices # ============================================================================== ccv_policy, ccv = solve_continuous_problem( @@ -130,7 +125,7 @@ def simulate( # Get optimal discrete choice given the optimal conditional continuous choices # ============================================================================== - discrete_argmax, sparse_argmax, value = discrete_policy_calculator(ccv) + discrete_argmax, value = discrete_policy_calculator(ccv) # Select optimal continuous choice corresponding to optimal discrete choice # ------------------------------------------------------------------------------ @@ -143,8 +138,6 @@ def simulate( discrete_argmax=discrete_argmax, vars_grid_shape=vars_grid_shape, ) - if sparse_argmax is not None: - cont_choice_argmax = cont_choice_argmax[sparse_argmax] # Convert optimal choice indices to actual choice values # ============================================================================== @@ -405,11 +398,6 @@ def filter_ccv_policy( return out -# ====================================================================================== -# Non-sparse choices -# ====================================================================================== - - def retrieve_choices(indices, grids, grid_shape): """Retrieve choices given indices. @@ -516,25 +504,12 @@ def get_discrete_policy_calculator(variable_info): values. The function depends on: - values (jax.Array): Multidimensional jax array with conditional continuation values. - - choice_segments (jax.Array): Jax array with the indices of the - choice segments that indicate which sparse choice variables belong to - one state. """ choice_axes = determine_discrete_choice_axes(variable_info) - def _calculate_discrete_argmax(values, choice_axes, choice_segments): # noqa: ARG001 - _max = values - - # Determine argmax and max over choices - # ============================================================================== - discrete_argmax, _max = argmax(_max, axis=choice_axes) - - # Determine argmax and max over sparse choices - # ============================================================================== - sparse_argmax = None - - return discrete_argmax, sparse_argmax, _max + def _calculate_discrete_argmax(values, choice_axes): + return argmax(values, axis=choice_axes) return partial(_calculate_discrete_argmax, choice_axes=choice_axes) @@ -578,7 +553,8 @@ def determine_discrete_choice_axes(variable_info): choice_vars = set(variable_info.query("is_choice").index.tolist()) - # We add 1 because the first dimension corresponds to the sparse state variables + # We must add 1 because the first dimension corresponds to the state variables, + # which are treated as combination variables during the simulation. return tuple( i + 1 for i, ax in enumerate(discrete_choice_vars) if ax in choice_vars ) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 3ca3eeca..d4524f64 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -368,7 +368,7 @@ def test_process_simulated_data(): assert tree_equal(expected, got) -def test_retrieve_non_sparse_choices(): +def test_retrieve_choices(): got = retrieve_choices( indices=jnp.array([0, 3, 7]), grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, @@ -378,7 +378,7 @@ def test_retrieve_non_sparse_choices(): assert_array_equal(got["b"], jnp.array([10, 16, 12])) -def test_retrieve_non_sparse_choices_no_indices(): +def test_retrieve_choices_no_indices(): got = retrieve_choices( indices=None, grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 3cea774f..219f4983 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -193,12 +193,12 @@ def setup_spacemap(): "b": jnp.array([3.0, 4]), } - sparse_values = { + combination_values = { "c": jnp.array([7.0, 8, 9, 10]), "d": jnp.array([9.0, 10, 11, 12, 13]), } - helper = jnp.array(list(itertools.product(*sparse_values.values()))).T + helper = jnp.array(list(itertools.product(*combination_values.values()))).T combination_grid = { "c": helper[0], @@ -245,7 +245,7 @@ def test_spacemap_all_arguments_mapped( ("error_msg", "product_vars", "combination_vars"), [ ( - "Dense and sparse variables must be disjoint. Overlap: {'a'}", + "Product and combination variables must be disjoint. Overlap: {'a'}", ["a", "b"], ["a", "c", "d"], ), From aa2db3076d3226bc1c5ea49e316b8b8db013ba92 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 17:04:35 +0100 Subject: [PATCH 09/16] Improve interpretibility of SpaceInfo attributes --- src/lcm/function_representation.py | 33 +++++++++-------- src/lcm/interfaces.py | 40 ++++++++++----------- src/lcm/solution/state_space.py | 6 ++-- tests/solution/test_state_space.py | 6 ++-- tests/test_function_representation.py | 52 +++++++++++++-------------- 5 files changed, 68 insertions(+), 69 deletions(-) diff --git a/src/lcm/function_representation.py b/src/lcm/function_representation.py index ad5cc75d..3fce4b97 100644 --- a/src/lcm/function_representation.py +++ b/src/lcm/function_representation.py @@ -70,14 +70,14 @@ def get_function_representation( # check inputs # ================================================================================== _fail_if_interpolation_axes_are_not_last(space_info) - _need_interpolation = bool(space_info.interpolation_info) + _need_interpolation = bool(space_info.continuous_vars) # ================================================================================== # create functions to look up position of discrete variables from labels # ================================================================================== funcs = {} - for var in space_info.lookup_info: + for var in space_info.discrete_vars: funcs[f"__{var}_pos__"] = _get_label_translator( in_name=input_prefix + var, ) @@ -87,20 +87,20 @@ def get_function_representation( # ================================================================================== # lookup is positional, so the inputs of the wrapper functions need to be the # outcomes of tranlating labels into positions - _internal_axes = [f"__{var}_pos__" for var in space_info.axis_names] - _lookup_axes = [var for var in _internal_axes if var in funcs] + _internal_axes = [f"__{var}_pos__" for var in space_info.var_names] + _discrete_axes = [ax for ax in _internal_axes if ax in funcs] _out_name = "__interpolation_data__" if _need_interpolation else "__fval__" funcs[_out_name] = _get_lookup_function( array_name=name_of_values_on_grid, - axis_names=_lookup_axes, + axis_names=_discrete_axes, ) if _need_interpolation: # ============================================================================== # create functions to find coordinates for the interpolation # ============================================================================== - for var, grid_spec in space_info.interpolation_info.items(): + for var, grid_spec in space_info.continuous_vars.items(): funcs[f"__{var}_coord__"] = _get_coordinate_finder( in_name=input_prefix + var, grid=grid_spec, # type: ignore[arg-type] @@ -109,14 +109,14 @@ def get_function_representation( # ============================================================================== # create interpolation function # ============================================================================== - _interpolation_axes = [ + _continuous_axes = [ f"__{var}_coord__" - for var in space_info.axis_names - if var in space_info.interpolation_info + for var in space_info.var_names + if var in space_info.continuous_vars ] funcs["__fval__"] = _get_interpolator( name_of_values_on_grid="__interpolation_data__", - axis_names=_interpolation_axes, + axis_names=_continuous_axes, ) return concatenate_functions( @@ -238,21 +238,20 @@ def interpolate(*args, **kwargs): def _fail_if_interpolation_axes_are_not_last(space_info: SpaceInfo) -> None: - """Fail if the interpolation axes are not the last elements in axis_names. + """Fail if the continuous variables are not the last elements in var_names. Args: space_info: Class containing all information needed to interpret the precalculated values of a function. Raises: - ValueError: If the interpolation axes are not the last elements in axis_names. + ValueError: If the continuous variables are not the last elements in var_names. """ - common = set(space_info.interpolation_info) & set(space_info.axis_names) + common = set(space_info.continuous_vars) & set(space_info.var_names) if common: n_common = len(common) - if sorted(common) != sorted(space_info.axis_names[-n_common:]): - raise ValueError( - "Interpolation axes need to be the last entries in axis_order.", - ) + if sorted(common) != sorted(space_info.var_names[-n_common:]): + msg = "Continuous variables need to be the last entries in var_names." + raise ValueError(msg) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index f63c9c12..f38efac5 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -8,6 +8,25 @@ from lcm.typing import ParamsDict, ShockType +@dataclass(frozen=True) +class SpaceInfo: + """Information to work with the output of a function evaluated on a space. + + An example is the value function array, which is the output of the value function + evaluated on the state space. + + Attributes: + var_names: List with names of state variables. + discrete_vars: Dictionary with grids of discrete state variables. + continuous_vars: Dictionary with grids of continuous state variables. + + """ + + var_names: list[str] + discrete_vars: dict[str, DiscreteGrid] + continuous_vars: dict[str, ContinuousGrid] + + @dataclass(frozen=True) class SolutionSpace: """The state-choice space of a model used during the solution process. @@ -24,7 +43,7 @@ class SolutionSpace: """ vars: dict[str, Array] - state_space_info: "SpaceInfo" + state_space_info: SpaceInfo @dataclass(frozen=True) @@ -44,25 +63,6 @@ class SimulationSpace: choices: dict[str, Array] -@dataclass(frozen=True) -class SpaceInfo: - """Everything needed to work with the output of a function evaluated on a space. - - Attributes: - axis_names: List with axis names of an array that contains function values for - all elements in a space. - lookup_info: Dict that defines the possible labels of all discrete variables and - their order. - interpolation_info: Dict that defines information on the grids of all continuous - variables. - - """ - - axis_names: list[str] - lookup_info: dict[str, DiscreteGrid] - interpolation_info: dict[str, ContinuousGrid] - - @dataclass(frozen=True) class InternalModel: """Internal representation of a user model. diff --git a/src/lcm/solution/state_space.py b/src/lcm/solution/state_space.py index d94d2833..86fc6751 100644 --- a/src/lcm/solution/state_space.py +++ b/src/lcm/solution/state_space.py @@ -40,9 +40,9 @@ def create_state_choice_space( } state_space_info = SpaceInfo( - axis_names=discrete_states_names + continuous_states_names, - lookup_info=discrete_states, # type: ignore[arg-type] - interpolation_info=continuous_states, # type: ignore[arg-type] + var_names=discrete_states_names + continuous_states_names, + discrete_vars=discrete_states, # type: ignore[arg-type] + continuous_vars=continuous_states, # type: ignore[arg-type] ) return SolutionSpace( diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py index 9209e1f8..aee2eb5a 100644 --- a/tests/solution/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -29,6 +29,6 @@ def test_create_state_choice_space(): state_space_info = state_choice_space.state_space_info - assert state_space_info.axis_names == ["wealth"] - assert state_space_info.lookup_info == {} - assert state_space_info.interpolation_info == model.states + assert state_space_info.var_names == ["wealth"] + assert state_space_info.discrete_vars == {} + assert state_space_info.continuous_vars == model.states diff --git a/tests/test_function_representation.py b/tests/test_function_representation.py index da525d48..9734e721 100644 --- a/tests/test_function_representation.py +++ b/tests/test_function_representation.py @@ -23,9 +23,9 @@ def test_function_evaluator_with_one_continuous_variable(): wealth_grid = LinspaceGrid(start=-3, stop=3, n_points=7) space_info = SpaceInfo( - axis_names=["wealth"], - lookup_info={}, - interpolation_info={ + var_names=["wealth"], + discrete_vars={}, + continuous_vars={ "wealth": wealth_grid, }, ) @@ -52,9 +52,9 @@ def test_function_evaluator_with_one_discrete_variable(): vf_arr = jnp.array([1, 2]) space_info = SpaceInfo( - axis_names=["working"], - lookup_info={"working": [0, 1]}, - interpolation_info={}, + var_names=["working"], + discrete_vars={"working": [0, 1]}, + continuous_vars={}, ) # create the evaluator @@ -97,24 +97,24 @@ def test_function_evaluator(): vf_arr = discrete_part + cont_part # create info on discrete variables - lookup_info = { + discrete_vars = { "retired": [0, 1], "insured": [0, 1], } # create info on continuous grids - interpolation_info = { + continuous_vars = { "wealth": LinspaceGrid(start=100, stop=1100, n_points=6), "human_capital": LinspaceGrid(start=-3, stop=3, n_points=7), } # create info on axis of value function array - axis_names = ["retired", "insured", "wealth", "human_capital"] + var_names = ["retired", "insured", "wealth", "human_capital"] space_info = SpaceInfo( - axis_names=axis_names, - lookup_info=lookup_info, - interpolation_info=interpolation_info, + var_names=var_names, + discrete_vars=discrete_vars, + continuous_vars=continuous_vars, ) # create the evaluator @@ -210,9 +210,9 @@ def test_get_function_evaluator_illustrative(): a_grid = LinspaceGrid(start=0, stop=1, n_points=3) space_info = SpaceInfo( - axis_names=["a"], - lookup_info={}, - interpolation_info={ + var_names=["a"], + discrete_vars={}, + continuous_vars={ "a": a_grid, }, ) @@ -279,15 +279,15 @@ def f(a, b): @pytest.mark.illustrative def test_fail_if_interpolation_axes_are_not_last_illustrative(): - # Empty intersection of axis_names and interpolation_info + # Empty intersection of var_names and continuous_vars # ================================================================================== space_info = SpaceInfo( - axis_names=["a", "b"], - interpolation_info={ + var_names=["a", "b"], + continuous_vars={ "c": None, }, - lookup_info=None, + discrete_vars={}, ) _fail_if_interpolation_axes_are_not_last(space_info) # does not fail @@ -296,13 +296,13 @@ def test_fail_if_interpolation_axes_are_not_last_illustrative(): # ================================================================================== space_info = SpaceInfo( - axis_names=["a", "b", "c"], - interpolation_info={ + var_names=["a", "b", "c"], + continuous_vars={ "b": None, "c": None, "d": None, }, - lookup_info=None, + discrete_vars={}, ) _fail_if_interpolation_axes_are_not_last(space_info) # does not fail @@ -311,14 +311,14 @@ def test_fail_if_interpolation_axes_are_not_last_illustrative(): # ================================================================================== space_info = SpaceInfo( - axis_names=["b", "c", "a"], # "b", "c" are not last anymore - interpolation_info={ + var_names=["b", "c", "a"], # "b", "c" are not last anymore + continuous_vars={ "b": None, "c": None, "d": None, }, - lookup_info=None, + discrete_vars={}, ) - with pytest.raises(ValueError, match="Interpolation axes need to be the last"): + with pytest.raises(ValueError, match="Continuous variables need to be the last"): _fail_if_interpolation_axes_are_not_last(space_info) From e8f7a145b0bd3a39fb4073664cdedfe3f68bcd79 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 17:18:35 +0100 Subject: [PATCH 10/16] Rename space_info -> state_space_info when appropriate --- src/lcm/entry_point.py | 8 ++++---- src/lcm/model_functions.py | 4 ++-- tests/simulation/test_simulate.py | 2 +- tests/test_entry_point.py | 8 ++++---- tests/test_model_functions.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index e2506157..7ee20386 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -74,7 +74,7 @@ def get_lcm_function( # Initialize other argument lists # ================================================================================== state_choice_spaces = [] - space_infos = [] + state_space_infos = [] compute_ccv_functions = [] compute_ccv_policy_functions = [] choice_segments = [] # type: ignore[var-annotated] @@ -95,12 +95,12 @@ def get_lcm_function( state_choice_spaces.append(sc_space) choice_segments.append(None) - space_infos.append(sc_space.state_space_info) + state_space_infos.append(sc_space.state_space_info) # ================================================================================== # Shift space info (in period t we require the space info of period t+1) # ================================================================================== - space_infos = space_infos[1:] + [{}] # type: ignore[list-item] + state_space_infos = state_space_infos[1:] + [{}] # type: ignore[list-item] # ================================================================================== # Create model functions @@ -112,7 +112,7 @@ def get_lcm_function( # ============================================================================== u_and_f = get_utility_and_feasibility_function( model=_mod, - space_info=space_infos[period], + state_space_info=state_space_infos[period], name_of_values_on_grid="vf_arr", period=period, is_last_period=is_last_period, diff --git a/src/lcm/model_functions.py b/src/lcm/model_functions.py index f400a1c1..ab08fd61 100644 --- a/src/lcm/model_functions.py +++ b/src/lcm/model_functions.py @@ -17,7 +17,7 @@ def get_utility_and_feasibility_function( model: InternalModel, - space_info: SpaceInfo, + state_space_info: SpaceInfo, name_of_values_on_grid: str, period: int, *, @@ -43,7 +43,7 @@ def get_utility_and_feasibility_function( next_weights = get_next_weights_function(model) scalar_value_function = get_function_representation( - space_info=space_info, + space_info=state_space_info, name_of_values_on_grid=name_of_values_on_grid, input_prefix="next_", ) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index d4524f64..8755f632 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -49,7 +49,7 @@ def simulate_inputs(): for period in range(model.n_periods): u_and_f = get_utility_and_feasibility_function( model=model, - space_info=sc_space.state_space_info, + state_space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=period, is_last_period=True, diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index cded434e..168f2ce0 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -189,7 +189,7 @@ def test_create_compute_conditional_continuation_value(): u_and_f = get_utility_and_feasibility_function( model=model, - space_info=sc_space.state_space_info, + state_space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -235,7 +235,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): u_and_f = get_utility_and_feasibility_function( model=model, - space_info=sc_space.state_space_info, + state_space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -286,7 +286,7 @@ def test_create_compute_conditional_continuation_policy(): u_and_f = get_utility_and_feasibility_function( model=model, - space_info=sc_space.state_space_info, + state_space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -333,7 +333,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): u_and_f = get_utility_and_feasibility_function( model=model, - space_info=sc_space.state_space_info, + state_space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index e834319d..d6cff927 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -37,7 +37,7 @@ def test_get_utility_and_feasibility_function(): u_and_f = get_utility_and_feasibility_function( model=model, - space_info=sc_space.state_space_info, + state_space_info=sc_space.state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, From 4e9448636aee5fdb035e73174bf402364115e9c3 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 17:19:09 +0100 Subject: [PATCH 11/16] Do not run explanation notebooks on GHA --- .github/workflows/main.yml | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e1d33d5c..ae58d730 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -57,18 +57,18 @@ jobs: - name: Run mypy shell: bash {0} run: pixi run mypy - run-explanation-notebooks: - name: Run explanation notebooks on Python 3.12 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: prefix-dev/setup-pixi@v0.8.1 - with: - pixi-version: v0.40.3 - cache: true - cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} - environments: test-cpu - frozen: true - - name: Run explanation notebooks - shell: bash {0} - run: pixi run -e test-cpu explanation-notebooks + # run-explanation-notebooks: + # name: Run explanation notebooks on Python 3.12 + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v4 + # - uses: prefix-dev/setup-pixi@v0.8.1 + # with: + # pixi-version: v0.40.3 + # cache: true + # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} + # environments: test-cpu + # frozen: true + # - name: Run explanation notebooks + # shell: bash {0} + # run: pixi run -e test-cpu explanation-notebooks From cc67e36f42a319f6ced76f94233dc7ff8bb9e953 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 18:12:20 +0100 Subject: [PATCH 12/16] Combine simulationa and solution space --- src/lcm/entry_point.py | 4 +- src/lcm/interfaces.py | 61 +++++++++++++----------------- src/lcm/simulation/simulate.py | 7 ++-- src/lcm/solution/solve_brute.py | 9 +++-- src/lcm/solution/state_space.py | 25 +++++++----- tests/simulation/test_simulate.py | 4 +- tests/solution/test_solve_brute.py | 17 +++++---- tests/solution/test_state_space.py | 14 +++---- tests/test_entry_point.py | 16 ++++---- tests/test_model_functions.py | 4 +- 10 files changed, 81 insertions(+), 80 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 7ee20386..f1dea574 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -88,14 +88,14 @@ def get_lcm_function( # call state space creation function, append trivial items to their lists # ============================================================================== - sc_space = create_state_choice_space( + sc_space, sc_space_info = create_state_choice_space( model=_mod, is_last_period=is_last_period, ) state_choice_spaces.append(sc_space) choice_segments.append(None) - state_space_infos.append(sc_space.state_space_info) + state_space_infos.append(sc_space_info) # ================================================================================== # Shift space info (in period t we require the space info of period t+1) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index f38efac5..8d0eb18c 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -9,58 +9,51 @@ @dataclass(frozen=True) -class SpaceInfo: - """Information to work with the output of a function evaluated on a space. +class StateChoiceSpace: + """The state-choice space. - An example is the value function array, which is the output of the value function - evaluated on the state space. + When used for the model solution: + --------------------------------- - Attributes: - var_names: List with names of state variables. - discrete_vars: Dictionary with grids of discrete state variables. - continuous_vars: Dictionary with grids of continuous state variables. + The state-choice space becomes the full Cartesian product of the state variables and + the choice variables. - """ - - var_names: list[str] - discrete_vars: dict[str, DiscreteGrid] - continuous_vars: dict[str, ContinuousGrid] + When used for the simulation: + ---------------------------- - -@dataclass(frozen=True) -class SolutionSpace: - """The state-choice space of a model used during the solution process. - - The state-choice space is the Cartesian product of the state variables and the - choice variables, stored here as a dictionary of one-dimensional arrays. The - continuous choice variables are handled outside of this class. + The state-choice space becomes the product of state-combinations with the full + Cartesian product of the choice variables. Attributes: - vars: Dictionary containing one dimensional grids of all variables, except for - continuous choice variables. - state_space_info: Information on the state variables. + states: Dictionary containing the values of the state variables. + choices: Dictionary containing the values of the choice variables. + ordered_var_names: List with names of state and choice variables in the order + they appear in the variable info table. """ - vars: dict[str, Array] - state_space_info: SpaceInfo + states: dict[str, Array] + choices: dict[str, Array] + ordered_var_names: list[str] @dataclass(frozen=True) -class SimulationSpace: - """The state-choice space of a model used during the simulation process. +class SpaceInfo: + """Information to work with the output of a function evaluated on a space. - The state-choice space is the product of the state variables with the Cartesian - product of the choice variables. + An example is the value function array, which is the output of the value function + evaluated on the state space. Attributes: - states: Dictionary containing the values of the state variables. - choices: Dictionary containing the values of the choice variables. + var_names: List with names of state variables. + discrete_vars: Dictionary with grids of discrete state variables. + continuous_vars: Dictionary with grids of continuous state variables. """ - states: dict[str, Array] - choices: dict[str, Array] + var_names: list[str] + discrete_vars: dict[str, DiscreteGrid] + continuous_vars: dict[str, ContinuousGrid] @dataclass(frozen=True) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 928e99fd..659b06dc 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -9,7 +9,7 @@ from lcm.argmax import argmax from lcm.dispatchers import spacemap, vmap_1d -from lcm.interfaces import InternalModel, SimulationSpace +from lcm.interfaces import InternalModel, StateChoiceSpace def simulate( @@ -201,7 +201,7 @@ def simulate( def solve_continuous_problem( - data_scs: SimulationSpace, + data_scs: StateChoiceSpace, compute_ccv, continuous_choice_grids, vf_arr, @@ -473,9 +473,10 @@ def create_data_scs( if name in vi.query("is_choice & is_discrete").index.tolist() } - data_scs = SimulationSpace( + data_scs = StateChoiceSpace( states=states, choices=choices, + ordered_var_names=vi.query("is_state | is_discrete").index.tolist(), ) # create choice segments diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index a148e344..87d05562 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -1,7 +1,7 @@ import jax from lcm.dispatchers import spacemap -from lcm.interfaces import SolutionSpace +from lcm.interfaces import StateChoiceSpace def solve( @@ -74,7 +74,7 @@ def solve( def solve_continuous_problem( - state_choice_space: SolutionSpace, + state_choice_space: StateChoiceSpace, compute_ccv, continuous_choice_grids, vf_arr, @@ -104,12 +104,13 @@ def solve_continuous_problem( """ _gridmapped = spacemap( func=compute_ccv, - product_vars=list(state_choice_space.vars), + product_vars=state_choice_space.ordered_var_names, ) gridmapped = jax.jit(_gridmapped) return gridmapped( - **state_choice_space.vars, + **state_choice_space.states, + **state_choice_space.choices, **continuous_choice_grids, vf_arr=vf_arr, params=params, diff --git a/src/lcm/solution/state_space.py b/src/lcm/solution/state_space.py index 86fc6751..0f1f56d9 100644 --- a/src/lcm/solution/state_space.py +++ b/src/lcm/solution/state_space.py @@ -1,13 +1,13 @@ """Create a state space for a given model.""" -from lcm.interfaces import InternalModel, SolutionSpace, SpaceInfo +from lcm.interfaces import InternalModel, SpaceInfo, StateChoiceSpace def create_state_choice_space( model: InternalModel, *, is_last_period: bool, -) -> SolutionSpace: +) -> tuple[StateChoiceSpace, SpaceInfo]: """Create a state-choice-space for the model solution. A state-choice-space is a compressed representation of all feasible states and the @@ -18,9 +18,11 @@ def create_state_choice_space( is_last_period: Whether the function is created for the last period. Returns: - SolutionSpace: An object containing the variable values of all variables in the + tuple[StateChoiceSpace, SpaceInfo]: + - An object containing the variable values of all variables in the state-choice-space, the grid specifications for the state variables, and the names of the state variables. Continuous choice variables are not included. + - The state-space information. """ vi = model.variable_info @@ -33,11 +35,11 @@ def create_state_choice_space( discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} - # Create a dictionary with all state and choice variables and their feasible values, - # except for continuous choice variables, since they are treated differently. - space_grids = { - sn: model.grids[sn] for sn in vi.query("is_state | is_discrete").index.tolist() + state_grids = {sn: model.grids[sn] for sn in vi.query("is_state").index.tolist()} + choice_grids = { + sn: model.grids[sn] for sn in vi.query("is_choice & is_discrete").index.tolist() } + ordered_var_names = vi.query("is_state | is_discrete").index.tolist() state_space_info = SpaceInfo( var_names=discrete_states_names + continuous_states_names, @@ -45,7 +47,10 @@ def create_state_choice_space( continuous_vars=continuous_states, # type: ignore[arg-type] ) - return SolutionSpace( - vars=space_grids, - state_space_info=state_space_info, + state_choice_space = StateChoiceSpace( + states=state_grids, + choices=choice_grids, + ordered_var_names=ordered_var_names, ) + + return state_choice_space, state_space_info diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 8755f632..091622db 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -40,7 +40,7 @@ def simulate_inputs(): model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=1) model = process_model(model_config) - sc_space = create_state_choice_space( + _, sc_space_info = create_state_choice_space( model=model, is_last_period=False, ) @@ -49,7 +49,7 @@ def simulate_inputs(): for period in range(model.n_periods): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space.state_space_info, + state_space_info=sc_space_info, name_of_values_on_grid="vf_arr", period=period, is_last_period=True, diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index f04ada8d..4b95fe3d 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -3,7 +3,7 @@ from numpy.testing import assert_array_almost_equal as aaae from lcm.entry_point import create_compute_conditional_continuation_value -from lcm.interfaces import SolutionSpace +from lcm.interfaces import StateChoiceSpace from lcm.logging import get_logger from lcm.ndimage import map_coordinates from lcm.solution.solve_brute import solve, solve_continuous_problem @@ -25,16 +25,18 @@ def test_solve_brute(): # ================================================================================== # create the list of state_choice_spaces # ================================================================================== - _scs = SolutionSpace( - vars={ + _scs = StateChoiceSpace( + choices={ # pick [0, 1] such that no label translation is needed # lazy is like a type, it influences utility but is not affected by choices "lazy": jnp.array([0, 1]), "working": jnp.array([0, 1]), + }, + states={ # pick [0, 1, 2] such that no coordinate mapping is needed "wealth": jnp.array([0.0, 1.0, 2.0]), }, - state_space_info=None, + ordered_var_names=["lazy", "working", "wealth"], ) state_choice_spaces = [_scs] * 2 @@ -111,13 +113,14 @@ def calculate_emax(values, params): # noqa: ARG001 def test_solve_continuous_problem_no_vf_arr(): - state_choice_space = SolutionSpace( - vars={ + state_choice_space = StateChoiceSpace( + choices={ "a": jnp.array([0, 1.0]), "b": jnp.array([2, 3.0]), "c": jnp.array([4, 5, 6]), }, - state_space_info=None, + states={}, + ordered_var_names=["a", "b", "c"], ) def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py index aee2eb5a..eba02565 100644 --- a/tests/solution/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from lcm.input_processing import process_model -from lcm.interfaces import SolutionSpace, SpaceInfo +from lcm.interfaces import SpaceInfo, StateChoiceSpace from lcm.solution.state_space import ( create_state_choice_space, ) @@ -12,23 +12,21 @@ def test_create_state_choice_space(): model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) internal_model = process_model(model) - state_choice_space = create_state_choice_space( + state_choice_space, state_space_info = create_state_choice_space( model=internal_model, is_last_period=False, ) - assert isinstance(state_choice_space, SolutionSpace) - assert isinstance(state_choice_space.state_space_info, SpaceInfo) + assert isinstance(state_choice_space, StateChoiceSpace) + assert isinstance(state_space_info, SpaceInfo) assert jnp.array_equal( - state_choice_space.vars["retirement"], model.choices["retirement"].to_jax() + state_choice_space.choices["retirement"], model.choices["retirement"].to_jax() ) assert jnp.array_equal( - state_choice_space.vars["wealth"], model.states["wealth"].to_jax() + state_choice_space.states["wealth"], model.states["wealth"].to_jax() ) - state_space_info = state_choice_space.state_space_info - assert state_space_info.var_names == ["wealth"] assert state_space_info.discrete_vars == {} assert state_space_info.continuous_vars == model.states diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 168f2ce0..41971f8a 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -182,14 +182,14 @@ def test_create_compute_conditional_continuation_value(): }, } - sc_space = create_state_choice_space( + _, sc_space_info = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space.state_space_info, + state_space_info=sc_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -228,14 +228,14 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): }, } - sc_space = create_state_choice_space( + _, sc_space_info = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space.state_space_info, + state_space_info=sc_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -279,14 +279,14 @@ def test_create_compute_conditional_continuation_policy(): }, } - sc_space = create_state_choice_space( + _, sc_space_info = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space.state_space_info, + state_space_info=sc_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -326,14 +326,14 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): }, } - sc_space = create_state_choice_space( + _, sc_space_info = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space.state_space_info, + state_space_info=sc_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index d6cff927..eb92e3df 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -30,14 +30,14 @@ def test_get_utility_and_feasibility_function(): }, } - sc_space = create_state_choice_space( + _, sc_space_info = create_state_choice_space( model=model, is_last_period=False, ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space.state_space_info, + state_space_info=sc_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, From b82c026a5f745a211c0300c34d171c3ebf1f64a3 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 12 Feb 2025 18:17:38 +0100 Subject: [PATCH 13/16] Minor fix --- src/lcm/entry_point.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index f1dea574..f02c0936 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -88,14 +88,14 @@ def get_lcm_function( # call state space creation function, append trivial items to their lists # ============================================================================== - sc_space, sc_space_info = create_state_choice_space( + sc_space, state_space_info = create_state_choice_space( model=_mod, is_last_period=is_last_period, ) state_choice_spaces.append(sc_space) choice_segments.append(None) - state_space_infos.append(sc_space_info) + state_space_infos.append(state_space_info) # ================================================================================== # Shift space info (in period t we require the space info of period t+1) From 3bd6c444bdce5b57a63b926b6b5fe396fbe839e4 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 13 Feb 2025 11:13:37 +0100 Subject: [PATCH 14/16] Integrate comments from review --- src/lcm/discrete_problem.py | 14 +++-- src/lcm/entry_point.py | 6 +- src/lcm/interfaces.py | 8 +-- src/lcm/simulation/simulate.py | 58 +++++++++---------- src/lcm/solution/solve_brute.py | 2 +- .../{state_space.py => state_choice_space.py} | 4 +- tests/simulation/test_simulate.py | 21 ++----- tests/solution/test_solve_brute.py | 4 +- tests/solution/test_state_space.py | 4 +- tests/test_entry_point.py | 26 ++++----- tests/test_model_functions.py | 8 +-- 11 files changed, 73 insertions(+), 82 deletions(-) rename src/lcm/solution/{state_space.py => state_choice_space.py} (92%) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index f9408475..346875b6 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -142,11 +142,13 @@ def _determine_discrete_choice_axes( variable_info: DataFrame with information about the variables. Returns: - tuple[int, ...]: A tuple of indices representing the axes' positions in - the value function that correspond to discrete choices. + A tuple of indices representing the axes' positions in the value function that + correspond to discrete choices. """ - # List of all model variables excluding the continuous choice variables. - axes = variable_info.query("is_state | is_discrete").index.tolist() - choice_vars = set(variable_info.query("is_choice").index.tolist()) - return tuple(i for i, ax in enumerate(axes) if ax in choice_vars) + discrete_choice_vars = set( + variable_info.query("is_choice & is_discrete").index.tolist() + ) + return tuple( + i for i, ax in enumerate(variable_info.index) if ax in discrete_choice_vars + ) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index f02c0936..545f1590 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -17,7 +17,7 @@ from lcm.next_state import get_next_state_function from lcm.simulation.simulate import simulate from lcm.solution.solve_brute import solve -from lcm.solution.state_space import create_state_choice_space +from lcm.solution.state_choice_space import create_state_choice_space from lcm.typing import ParamsDict from lcm.user_model import Model @@ -88,12 +88,12 @@ def get_lcm_function( # call state space creation function, append trivial items to their lists # ============================================================================== - sc_space, state_space_info = create_state_choice_space( + state_choice_space, state_space_info = create_state_choice_space( model=_mod, is_last_period=is_last_period, ) - state_choice_spaces.append(sc_space) + state_choice_spaces.append(state_choice_space) choice_segments.append(None) state_space_infos.append(state_space_info) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 8d0eb18c..7379ccfc 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -27,14 +27,14 @@ class StateChoiceSpace: Attributes: states: Dictionary containing the values of the state variables. choices: Dictionary containing the values of the choice variables. - ordered_var_names: List with names of state and choice variables in the order + ordered_var_names: Tuple with names of state and choice variables in the order they appear in the variable info table. """ states: dict[str, Array] choices: dict[str, Array] - ordered_var_names: list[str] + ordered_var_names: tuple[str, ...] @dataclass(frozen=True) @@ -45,13 +45,13 @@ class SpaceInfo: evaluated on the state space. Attributes: - var_names: List with names of state variables. + var_names: Tuple with names of state variables. discrete_vars: Dictionary with grids of discrete state variables. continuous_vars: Dictionary with grids of continuous state variables. """ - var_names: list[str] + var_names: tuple[str, ...] discrete_vars: dict[str, DiscreteGrid] continuous_vars: dict[str, ContinuousGrid] diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 659b06dc..d15074b9 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import pandas as pd from dags import concatenate_functions -from jax import vmap +from jax import Array, vmap from lcm.argmax import argmax from lcm.dispatchers import spacemap, vmap_1d @@ -101,15 +101,15 @@ def simulate( # of combination variables and initial states. The space has to be created in # each iteration because the states change over time. # ============================================================================== - data_scs, _ = create_data_scs( + data_scs = create_data_scs( states=states, model=model, - ) + )[0] # Compute objects dependent on data-state-choice-space # ============================================================================== - vars_grid_shape = tuple(len(grid) for grid in data_scs.choices.values()) - cont_choice_grid_shape = tuple( + choices_grid_shape = tuple(len(grid) for grid in data_scs.choices.values()) + cont_choices_grid_shape = tuple( len(grid) for grid in continuous_choice_grids[period].values() ) @@ -136,21 +136,21 @@ def simulate( cont_choice_argmax = filter_ccv_policy( ccv_policy=ccv_policy, discrete_argmax=discrete_argmax, - vars_grid_shape=vars_grid_shape, + vars_grid_shape=choices_grid_shape, ) # Convert optimal choice indices to actual choice values # ============================================================================== choices = retrieve_choices( - indices=discrete_argmax, + flat_indices=discrete_argmax, grids=data_scs.choices, - grid_shape=vars_grid_shape, + grids_shapes=choices_grid_shape, ) cont_choices = retrieve_choices( - indices=cont_choice_argmax, + flat_indices=cont_choice_argmax, grids=continuous_choice_grids[period], - grid_shape=cont_choice_grid_shape, + grids_shapes=cont_choices_grid_shape, ) # Store results @@ -398,28 +398,27 @@ def filter_ccv_policy( return out -def retrieve_choices(indices, grids, grid_shape): - """Retrieve choices given indices. +def retrieve_choices( + flat_indices: Array, + grids: dict[str, Array], + grids_shapes: tuple[int, ...], +) -> dict[str, Array]: + """Retrieve choices given flat indices. Args: - indices (jnp.numpy.ndarray or None): General indices. Represents the index of - the flattened grid. - grids (dict): Dictionary of grids. - grid_shape (tuple): Shape of the grids. Is used to unravel the index. + flat_indices: General indices. Represents the index of the flattened grid. + grids: Dictionary of grid values. + grids_shapes: Shape of the grids. Is used to unravel the index. Returns: - dict: Dictionary of choices. + Dictionary of choices. """ - if indices is None: - out = {} - else: - indices = vmapped_unravel_index(indices, grid_shape) - out = { - name: grid[index] - for (name, grid), index in zip(grids.items(), indices, strict=True) - } - return out + nd_indices = vmapped_unravel_index(flat_indices, grids_shapes) + return { + name: grid[index] + for (name, grid), index in zip(grids.items(), nd_indices, strict=True) + } # vmap jnp.unravel_index over the first axis of the `indices` argument, while holding @@ -476,7 +475,7 @@ def create_data_scs( data_scs = StateChoiceSpace( states=states, choices=choices, - ordered_var_names=vi.query("is_state | is_discrete").index.tolist(), + ordered_var_names=tuple(vi.query("is_state | is_discrete").index.tolist()), ) # create choice segments @@ -554,8 +553,7 @@ def determine_discrete_choice_axes(variable_info): choice_vars = set(variable_info.query("is_choice").index.tolist()) - # We must add 1 because the first dimension corresponds to the state variables, - # which are treated as combination variables during the simulation. + # The first dimension corresponds to the simulated states, so add 1. return tuple( - i + 1 for i, ax in enumerate(discrete_choice_vars) if ax in choice_vars + 1 + i for i, ax in enumerate(discrete_choice_vars) if ax in choice_vars ) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 87d05562..b5a3f05b 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -104,7 +104,7 @@ def solve_continuous_problem( """ _gridmapped = spacemap( func=compute_ccv, - product_vars=state_choice_space.ordered_var_names, + product_vars=list(state_choice_space.ordered_var_names), ) gridmapped = jax.jit(_gridmapped) diff --git a/src/lcm/solution/state_space.py b/src/lcm/solution/state_choice_space.py similarity index 92% rename from src/lcm/solution/state_space.py rename to src/lcm/solution/state_choice_space.py index 0f1f56d9..b7c3f8ac 100644 --- a/src/lcm/solution/state_space.py +++ b/src/lcm/solution/state_choice_space.py @@ -39,10 +39,10 @@ def create_state_choice_space( choice_grids = { sn: model.grids[sn] for sn in vi.query("is_choice & is_discrete").index.tolist() } - ordered_var_names = vi.query("is_state | is_discrete").index.tolist() + ordered_var_names = tuple(vi.query("is_state | is_discrete").index.tolist()) state_space_info = SpaceInfo( - var_names=discrete_states_names + continuous_states_names, + var_names=tuple(discrete_states_names + continuous_states_names), discrete_vars=discrete_states, # type: ignore[arg-type] continuous_vars=continuous_states, # type: ignore[arg-type] ) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 091622db..f7e37f20 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -24,7 +24,7 @@ retrieve_choices, simulate, ) -from lcm.solution.state_space import create_state_choice_space +from lcm.solution.state_choice_space import create_state_choice_space from tests.test_models import ( get_model_config, get_params, @@ -40,16 +40,16 @@ def simulate_inputs(): model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=1) model = process_model(model_config) - _, sc_space_info = create_state_choice_space( + state_space_info = create_state_choice_space( model=model, is_last_period=False, - ) + )[1] compute_ccv_policy_functions = [] for period in range(model.n_periods): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space_info, + state_space_info=state_space_info, name_of_values_on_grid="vf_arr", period=period, is_last_period=True, @@ -370,23 +370,14 @@ def test_process_simulated_data(): def test_retrieve_choices(): got = retrieve_choices( - indices=jnp.array([0, 3, 7]), + flat_indices=jnp.array([0, 3, 7]), grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, - grid_shape=(5, 6), + grids_shapes=(5, 6), ) assert_array_equal(got["a"], jnp.array([0, 0, 0.25])) assert_array_equal(got["b"], jnp.array([10, 16, 12])) -def test_retrieve_choices_no_indices(): - got = retrieve_choices( - indices=None, - grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, - grid_shape=(5, 6), - ) - assert got == {} - - def test_filter_ccv_policy(): ccc_policy = jnp.array( [ diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 4b95fe3d..5b0450b6 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -36,7 +36,7 @@ def test_solve_brute(): # pick [0, 1, 2] such that no coordinate mapping is needed "wealth": jnp.array([0.0, 1.0, 2.0]), }, - ordered_var_names=["lazy", "working", "wealth"], + ordered_var_names=("lazy", "working", "wealth"), ) state_choice_spaces = [_scs] * 2 @@ -120,7 +120,7 @@ def test_solve_continuous_problem_no_vf_arr(): "c": jnp.array([4, 5, 6]), }, states={}, - ordered_var_names=["a", "b", "c"], + ordered_var_names=("a", "b", "c"), ) def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py index eba02565..962ae898 100644 --- a/tests/solution/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -2,7 +2,7 @@ from lcm.input_processing import process_model from lcm.interfaces import SpaceInfo, StateChoiceSpace -from lcm.solution.state_space import ( +from lcm.solution.state_choice_space import ( create_state_choice_space, ) from tests.test_models import get_model_config @@ -27,6 +27,6 @@ def test_create_state_choice_space(): state_choice_space.states["wealth"], model.states["wealth"].to_jax() ) - assert state_space_info.var_names == ["wealth"] + assert state_space_info.var_names == ("wealth",) assert state_space_info.discrete_vars == {} assert state_space_info.continuous_vars == model.states diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 41971f8a..b183a095 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -9,7 +9,7 @@ ) from lcm.input_processing import process_model from lcm.model_functions import get_utility_and_feasibility_function -from lcm.solution.state_space import create_state_choice_space +from lcm.solution.state_choice_space import create_state_choice_space from tests.test_models import get_model_config from tests.test_models.deterministic import RetirementStatus from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility @@ -182,14 +182,14 @@ def test_create_compute_conditional_continuation_value(): }, } - _, sc_space_info = create_state_choice_space( + state_space_info = create_state_choice_space( model=model, is_last_period=False, - ) + )[1] u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space_info, + state_space_info=state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -228,14 +228,14 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): }, } - _, sc_space_info = create_state_choice_space( + state_space_info = create_state_choice_space( model=model, is_last_period=False, - ) + )[1] u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space_info, + state_space_info=state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -279,14 +279,14 @@ def test_create_compute_conditional_continuation_policy(): }, } - _, sc_space_info = create_state_choice_space( + state_space_info = create_state_choice_space( model=model, is_last_period=False, - ) + )[1] u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space_info, + state_space_info=state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, @@ -326,14 +326,14 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): }, } - _, sc_space_info = create_state_choice_space( + state_space_info = create_state_choice_space( model=model, is_last_period=False, - ) + )[1] u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space_info, + state_space_info=state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index eb92e3df..633fc45b 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -10,7 +10,7 @@ get_multiply_weights, get_utility_and_feasibility_function, ) -from lcm.solution.state_space import create_state_choice_space +from lcm.solution.state_choice_space import create_state_choice_space from tests.test_models import get_model_config from tests.test_models.deterministic import utility @@ -30,14 +30,14 @@ def test_get_utility_and_feasibility_function(): }, } - _, sc_space_info = create_state_choice_space( + state_space_info = create_state_choice_space( model=model, is_last_period=False, - ) + )[1] u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=sc_space_info, + state_space_info=state_space_info, name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, From a5b900a48c6adc087fe9f9a1f6590e8ad3bd498e Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 13 Feb 2025 11:56:10 +0100 Subject: [PATCH 15/16] Make clear the function representation is used for the value function array --- src/lcm/entry_point.py | 1 - src/lcm/function_representation.py | 90 ++++++++++----------- src/lcm/interfaces.py | 10 +-- src/lcm/model_functions.py | 13 +--- src/lcm/solution/state_choice_space.py | 12 +-- tests/simulation/test_simulate.py | 1 - tests/solution/test_state_space.py | 10 +-- tests/test_entry_point.py | 4 - tests/test_function_representation.py | 103 ++++++++++++------------- tests/test_model_functions.py | 1 - 10 files changed, 115 insertions(+), 130 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 545f1590..4606f514 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -113,7 +113,6 @@ def get_lcm_function( u_and_f = get_utility_and_feasibility_function( model=_mod, state_space_info=state_space_infos[period], - name_of_values_on_grid="vf_arr", period=period, is_last_period=is_last_period, ) diff --git a/src/lcm/function_representation.py b/src/lcm/function_representation.py index 3fce4b97..d17ede0d 100644 --- a/src/lcm/function_representation.py +++ b/src/lcm/function_representation.py @@ -7,40 +7,45 @@ from lcm.functools import all_as_kwargs from lcm.grids import ContinuousGrid -from lcm.interfaces import SpaceInfo +from lcm.interfaces import StateSpaceInfo from lcm.ndimage import map_coordinates -def get_function_representation( - space_info: SpaceInfo, - name_of_values_on_grid: str, +def get_value_function_representation( + state_space_info: StateSpaceInfo, *, - input_prefix: str = "", + input_prefix: str = "next_", + name_of_values_on_grid: str = "vf_arr", ) -> Callable[..., Array]: - """Create a function representation of pre-calculated values on a grid. + """Create a function representation of the value function array. + + The returned function + --------------------- - An example of a pre-calculated function is a value or policy function. These are - evaluated on the space of all discrete and continuous state variables. + This function generates a function that looks up discrete values and interpolates + values for continuous variables on the value function array. The arguments of the + resulting function can be split in two categories: - This function dynamically generates a function that looks up and interpolates values - of the pre-calculated function. The arguments of the resulting function can be split - in two categories: - 1. Helper arguments such as information about the grid and the pre-calculated - values of the function. - 2. The original arguments of the function that was pre-calculated on the grid. + 1. The original arguments of the function that was used to pre-calculate the + value function on the state space grid. + + 2. Auxiliary arguments, such as information about the grids, which are needed for + example, for the interpolation. After partialling in all helper arguments, the resulting function behaves like an - analytical function. In particular, it can be jitted, differentiated and vmapped - with jax. + analytical function, i.e. it can be evaluated on points that do not lie on the grid + points of the state variables. In particular, it can also be jitted, differentiated + and vmapped with jax. + + How does it work? + ----------------- The resulting function roughly does the following steps: - - Translate values of discrete variables into positions - - Index into the array with the pre-calculated function values to extract only the - part on which interpolation is needed. - - Translate values of continuous variables into coordinates needed for interpolation - via jax.scipy.ndimage.map_coordinates. - - Do the actual interpolation. + - It translates values of discrete variables into positions. + - It translates values of continuous variables into coordinates needed for + interpolation via jax.scipy.ndimage.map_coordinates. + - It performs the interpolation. Depending on the grid, only a subset of these steps is relevant. The chosen implementation of each step is also adjusted to the type of grid. In particular we @@ -49,35 +54,32 @@ def get_function_representation( functions are called is determined by a DAG. Args: - space_info: Class containing all information needed to interpret the + state_space_info: Class containing all information needed to interpret the pre-calculated values of a function. + input_prefix: Prefix that will be added to all argument names of the resulting + function, except for the helpers arguments. Default is "next_"; since the + value function is typically evaluated on the next period's state space. name_of_values_on_grid: The name of the argument via which the pre-calculated values, that have been evaluated on the state-space grid, will be passed - into the resulting function. In the value function case, this could be - 'vf_arr', in which case, one would partial in 'vf_arr' into the - representation. - input_prefix: Prefix that will be added to all argument names of the resulting - function, except for the helpers arguments such as the value arrays. - Default is the empty string. The prefix needs to contain the separator. E.g. - `next_` if an undescore should be used as separator. + into the resulting function. Defaults to "vf_arr". Returns: - callable: A callable that lets you evaluate a function defined be precalculated - values on space formed by discrete and continuous grids. + A callable that lets you treat the result of pre-calculating a function on the + state space as an analytical function. """ # ================================================================================== # check inputs # ================================================================================== - _fail_if_interpolation_axes_are_not_last(space_info) - _need_interpolation = bool(space_info.continuous_vars) + _fail_if_interpolation_axes_are_not_last(state_space_info) + _need_interpolation = bool(state_space_info.continuous_states) # ================================================================================== # create functions to look up position of discrete variables from labels # ================================================================================== funcs = {} - for var in space_info.discrete_vars: + for var in state_space_info.discrete_states: funcs[f"__{var}_pos__"] = _get_label_translator( in_name=input_prefix + var, ) @@ -87,7 +89,7 @@ def get_function_representation( # ================================================================================== # lookup is positional, so the inputs of the wrapper functions need to be the # outcomes of tranlating labels into positions - _internal_axes = [f"__{var}_pos__" for var in space_info.var_names] + _internal_axes = [f"__{var}_pos__" for var in state_space_info.states_names] _discrete_axes = [ax for ax in _internal_axes if ax in funcs] _out_name = "__interpolation_data__" if _need_interpolation else "__fval__" @@ -100,7 +102,7 @@ def get_function_representation( # ============================================================================== # create functions to find coordinates for the interpolation # ============================================================================== - for var, grid_spec in space_info.continuous_vars.items(): + for var, grid_spec in state_space_info.continuous_states.items(): funcs[f"__{var}_coord__"] = _get_coordinate_finder( in_name=input_prefix + var, grid=grid_spec, # type: ignore[arg-type] @@ -111,8 +113,8 @@ def get_function_representation( # ============================================================================== _continuous_axes = [ f"__{var}_coord__" - for var in space_info.var_names - if var in space_info.continuous_vars + for var in state_space_info.states_names + if var in state_space_info.continuous_states ] funcs["__fval__"] = _get_interpolator( name_of_values_on_grid="__interpolation_data__", @@ -237,21 +239,23 @@ def interpolate(*args, **kwargs): return interpolate -def _fail_if_interpolation_axes_are_not_last(space_info: SpaceInfo) -> None: +def _fail_if_interpolation_axes_are_not_last(state_space_info: StateSpaceInfo) -> None: """Fail if the continuous variables are not the last elements in var_names. Args: - space_info: Class containing all information needed to interpret the + state_space_info: Class containing all information needed to interpret the precalculated values of a function. Raises: ValueError: If the continuous variables are not the last elements in var_names. """ - common = set(space_info.continuous_vars) & set(space_info.var_names) + common = set(state_space_info.continuous_states) & set( + state_space_info.states_names + ) if common: n_common = len(common) - if sorted(common) != sorted(space_info.var_names[-n_common:]): + if sorted(common) != sorted(state_space_info.states_names[-n_common:]): msg = "Continuous variables need to be the last entries in var_names." raise ValueError(msg) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 7379ccfc..09c4f109 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -38,8 +38,8 @@ class StateChoiceSpace: @dataclass(frozen=True) -class SpaceInfo: - """Information to work with the output of a function evaluated on a space. +class StateSpaceInfo: + """Information to work with the output of a function evaluated on a state space. An example is the value function array, which is the output of the value function evaluated on the state space. @@ -51,9 +51,9 @@ class SpaceInfo: """ - var_names: tuple[str, ...] - discrete_vars: dict[str, DiscreteGrid] - continuous_vars: dict[str, ContinuousGrid] + states_names: tuple[str, ...] + discrete_states: dict[str, DiscreteGrid] + continuous_states: dict[str, ContinuousGrid] @dataclass(frozen=True) diff --git a/src/lcm/model_functions.py b/src/lcm/model_functions.py index ab08fd61..cd107f94 100644 --- a/src/lcm/model_functions.py +++ b/src/lcm/model_functions.py @@ -5,20 +5,19 @@ from dags.signature import with_signature from lcm.dispatchers import productmap -from lcm.function_representation import get_function_representation +from lcm.function_representation import get_value_function_representation from lcm.functools import ( all_as_args, all_as_kwargs, get_union_of_arguments, ) -from lcm.interfaces import InternalModel, SpaceInfo +from lcm.interfaces import InternalModel, StateSpaceInfo from lcm.next_state import get_next_state_function def get_utility_and_feasibility_function( model: InternalModel, - state_space_info: SpaceInfo, - name_of_values_on_grid: str, + state_space_info: StateSpaceInfo, period: int, *, is_last_period: bool, @@ -42,11 +41,7 @@ def get_utility_and_feasibility_function( next_state = get_next_state_function(model, target="solve") next_weights = get_next_weights_function(model) - scalar_value_function = get_function_representation( - space_info=state_space_info, - name_of_values_on_grid=name_of_values_on_grid, - input_prefix="next_", - ) + scalar_value_function = get_value_function_representation(state_space_info) multiply_weights = get_multiply_weights(stochastic_variables) diff --git a/src/lcm/solution/state_choice_space.py b/src/lcm/solution/state_choice_space.py index b7c3f8ac..21657811 100644 --- a/src/lcm/solution/state_choice_space.py +++ b/src/lcm/solution/state_choice_space.py @@ -1,13 +1,13 @@ """Create a state space for a given model.""" -from lcm.interfaces import InternalModel, SpaceInfo, StateChoiceSpace +from lcm.interfaces import InternalModel, StateChoiceSpace, StateSpaceInfo def create_state_choice_space( model: InternalModel, *, is_last_period: bool, -) -> tuple[StateChoiceSpace, SpaceInfo]: +) -> tuple[StateChoiceSpace, StateSpaceInfo]: """Create a state-choice-space for the model solution. A state-choice-space is a compressed representation of all feasible states and the @@ -41,10 +41,10 @@ def create_state_choice_space( } ordered_var_names = tuple(vi.query("is_state | is_discrete").index.tolist()) - state_space_info = SpaceInfo( - var_names=tuple(discrete_states_names + continuous_states_names), - discrete_vars=discrete_states, # type: ignore[arg-type] - continuous_vars=continuous_states, # type: ignore[arg-type] + state_space_info = StateSpaceInfo( + states_names=tuple(discrete_states_names + continuous_states_names), + discrete_states=discrete_states, # type: ignore[arg-type] + continuous_states=continuous_states, # type: ignore[arg-type] ) state_choice_space = StateChoiceSpace( diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index f7e37f20..a84e307e 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -50,7 +50,6 @@ def simulate_inputs(): u_and_f = get_utility_and_feasibility_function( model=model, state_space_info=state_space_info, - name_of_values_on_grid="vf_arr", period=period, is_last_period=True, ) diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py index 962ae898..f4fe1341 100644 --- a/tests/solution/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from lcm.input_processing import process_model -from lcm.interfaces import SpaceInfo, StateChoiceSpace +from lcm.interfaces import StateChoiceSpace, StateSpaceInfo from lcm.solution.state_choice_space import ( create_state_choice_space, ) @@ -18,7 +18,7 @@ def test_create_state_choice_space(): ) assert isinstance(state_choice_space, StateChoiceSpace) - assert isinstance(state_space_info, SpaceInfo) + assert isinstance(state_space_info, StateSpaceInfo) assert jnp.array_equal( state_choice_space.choices["retirement"], model.choices["retirement"].to_jax() @@ -27,6 +27,6 @@ def test_create_state_choice_space(): state_choice_space.states["wealth"], model.states["wealth"].to_jax() ) - assert state_space_info.var_names == ("wealth",) - assert state_space_info.discrete_vars == {} - assert state_space_info.continuous_vars == model.states + assert state_space_info.states_names == ("wealth",) + assert state_space_info.discrete_states == {} + assert state_space_info.continuous_states == model.states diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index b183a095..f1e2c99f 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -190,7 +190,6 @@ def test_create_compute_conditional_continuation_value(): u_and_f = get_utility_and_feasibility_function( model=model, state_space_info=state_space_info, - name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, ) @@ -236,7 +235,6 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): u_and_f = get_utility_and_feasibility_function( model=model, state_space_info=state_space_info, - name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, ) @@ -287,7 +285,6 @@ def test_create_compute_conditional_continuation_policy(): u_and_f = get_utility_and_feasibility_function( model=model, state_space_info=state_space_info, - name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, ) @@ -334,7 +331,6 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): u_and_f = get_utility_and_feasibility_function( model=model, state_space_info=state_space_info, - name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, ) diff --git a/tests/test_function_representation.py b/tests/test_function_representation.py index 9734e721..5d44104b 100644 --- a/tests/test_function_representation.py +++ b/tests/test_function_representation.py @@ -12,20 +12,20 @@ _get_interpolator, _get_label_translator, _get_lookup_function, - get_function_representation, + get_value_function_representation, ) from lcm.interfaces import ( - SpaceInfo, + StateSpaceInfo, ) def test_function_evaluator_with_one_continuous_variable(): wealth_grid = LinspaceGrid(start=-3, stop=3, n_points=7) - space_info = SpaceInfo( - var_names=["wealth"], - discrete_vars={}, - continuous_vars={ + state_space_info = StateSpaceInfo( + states_names=["wealth"], + discrete_states={}, + continuous_states={ "wealth": wealth_grid, }, ) @@ -33,11 +33,7 @@ def test_function_evaluator_with_one_continuous_variable(): vf_arr = jnp.pi * wealth_grid.to_jax() + 2 # create the evaluator - evaluator = get_function_representation( - space_info=space_info, - name_of_values_on_grid="vf_arr", - input_prefix="next_", - ) + evaluator = get_value_function_representation(state_space_info) # partial the function values into the evaluator func = partial(evaluator, vf_arr=vf_arr) @@ -51,18 +47,14 @@ def test_function_evaluator_with_one_continuous_variable(): def test_function_evaluator_with_one_discrete_variable(): vf_arr = jnp.array([1, 2]) - space_info = SpaceInfo( - var_names=["working"], - discrete_vars={"working": [0, 1]}, - continuous_vars={}, + state_space_info = StateSpaceInfo( + states_names=["working"], + discrete_states={"working": [0, 1]}, + continuous_states={}, ) # create the evaluator - evaluator = get_function_representation( - space_info=space_info, - name_of_values_on_grid="vf_arr", - input_prefix="next_", - ) + evaluator = get_value_function_representation(state_space_info) # partial the function values into the evaluator func = partial(evaluator, vf_arr=vf_arr) @@ -84,8 +76,8 @@ def test_function_evaluator(): The utility function is wealth + human_capital + c. c takes a different value for each discrete state choice combination. - The setup of space_info here is quite long. Usually these inputs will be generated - from a model specification. + The setup of state_space_info here is quite long. Usually these inputs will be + generated from a model specification. """ # create a value function array @@ -111,24 +103,25 @@ def test_function_evaluator(): # create info on axis of value function array var_names = ["retired", "insured", "wealth", "human_capital"] - space_info = SpaceInfo( - var_names=var_names, - discrete_vars=discrete_vars, - continuous_vars=continuous_vars, + state_space_info = StateSpaceInfo( + states_names=var_names, + discrete_states=discrete_vars, + continuous_states=continuous_vars, ) # create the evaluator - evaluator = get_function_representation( - space_info=space_info, - name_of_values_on_grid="vf_arr", + evaluator = get_value_function_representation( + state_space_info=state_space_info, ) - # test the evaluator + # test the evaluator; note that the prefix 'next_' is added to the variable names + # by default, and that the argument name of the value function array is 'vf_arr' by + # default; these can be changed when calling get_value_function_representation out = evaluator( - retired=1, - insured=0, - wealth=600, - human_capital=1.5, + next_retired=1, + next_insured=0, + next_wealth=600, + next_human_capital=1.5, vf_arr=vf_arr, ) @@ -209,10 +202,10 @@ def _utility(wealth, working): def test_get_function_evaluator_illustrative(): a_grid = LinspaceGrid(start=0, stop=1, n_points=3) - space_info = SpaceInfo( - var_names=["a"], - discrete_vars={}, - continuous_vars={ + state_space_info = StateSpaceInfo( + states_names=["a"], + discrete_states={}, + continuous_states={ "a": a_grid, }, ) @@ -220,8 +213,8 @@ def test_get_function_evaluator_illustrative(): values = jnp.pi * a_grid.to_jax() + 2 # create the evaluator - evaluator = get_function_representation( - space_info=space_info, + evaluator = get_value_function_representation( + state_space_info=state_space_info, name_of_values_on_grid="values_name", input_prefix="prefix_", ) @@ -282,43 +275,43 @@ def test_fail_if_interpolation_axes_are_not_last_illustrative(): # Empty intersection of var_names and continuous_vars # ================================================================================== - space_info = SpaceInfo( - var_names=["a", "b"], - continuous_vars={ + state_space_info = StateSpaceInfo( + states_names=["a", "b"], + continuous_states={ "c": None, }, - discrete_vars={}, + discrete_states={}, ) - _fail_if_interpolation_axes_are_not_last(space_info) # does not fail + _fail_if_interpolation_axes_are_not_last(state_space_info) # does not fail # Non-empty intersection but correct order # ================================================================================== - space_info = SpaceInfo( - var_names=["a", "b", "c"], - continuous_vars={ + state_space_info = StateSpaceInfo( + states_names=["a", "b", "c"], + continuous_states={ "b": None, "c": None, "d": None, }, - discrete_vars={}, + discrete_states={}, ) - _fail_if_interpolation_axes_are_not_last(space_info) # does not fail + _fail_if_interpolation_axes_are_not_last(state_space_info) # does not fail # Non-empty intersection and in-correct order # ================================================================================== - space_info = SpaceInfo( - var_names=["b", "c", "a"], # "b", "c" are not last anymore - continuous_vars={ + state_space_info = StateSpaceInfo( + states_names=["b", "c", "a"], # "b", "c" are not last anymore + continuous_states={ "b": None, "c": None, "d": None, }, - discrete_vars={}, + discrete_states={}, ) with pytest.raises(ValueError, match="Continuous variables need to be the last"): - _fail_if_interpolation_axes_are_not_last(space_info) + _fail_if_interpolation_axes_are_not_last(state_space_info) diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 633fc45b..2df77096 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -38,7 +38,6 @@ def test_get_utility_and_feasibility_function(): u_and_f = get_utility_and_feasibility_function( model=model, state_space_info=state_space_info, - name_of_values_on_grid="vf_arr", period=model.n_periods - 1, is_last_period=True, ) From 4199a06bc6f1a60b8ade04d442bf30c8140b2e3c Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 13 Feb 2025 16:30:47 +0100 Subject: [PATCH 16/16] Improve readability of dispatchers --- src/lcm/dispatchers.py | 72 ++++++++++----------------- src/lcm/entry_point.py | 4 +- src/lcm/grids.py | 5 +- src/lcm/model_functions.py | 2 +- src/lcm/simulation/simulate.py | 6 +-- src/lcm/solution/solve_brute.py | 3 +- src/lcm/utils.py | 12 +++++ tests/test_dispatchers.py | 24 ++++----- tests/test_function_representation.py | 6 +-- tests/test_utils.py | 17 +++++++ 10 files changed, 80 insertions(+), 71 deletions(-) create mode 100644 src/lcm/utils.py create mode 100644 tests/test_utils.py diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index 063543d0..0ce7f59b 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -5,14 +5,15 @@ from jax import Array, vmap from lcm.functools import allow_args, allow_only_kwargs +from lcm.utils import find_duplicates F = TypeVar("F", bound=Callable[..., Array]) def spacemap( func: F, - product_vars: list[str], - combination_vars: list[str] | None = None, + product_vars: tuple[str, ...], + combination_vars: tuple[str, ...], ) -> F: """Apply vmap such that func can be evaluated on product and combination variables. @@ -48,52 +49,32 @@ def spacemap( described above but there might be additional dimensions. """ - # Check inputs and prepare function - # ================================================================================== - duplicates = {v for v in product_vars if product_vars.count(v) > 1} - if duplicates: - raise ValueError( - f"Same argument provided more than once in product variables: {duplicates}", + if duplicates := find_duplicates(product_vars, combination_vars): + msg = ( + "Same argument provided more than once in product variables or combination " + f"variables, or is present in both: {duplicates}" ) + raise ValueError(msg) + + func_callable_with_args = allow_args(func) + + vmapped = _base_productmap(func_callable_with_args, product_vars) if combination_vars: - overlap = set(product_vars).intersection(combination_vars) - if overlap: - raise ValueError( - "Product and combination variables must be disjoint. Overlap: " - f"{overlap}", - ) - - duplicates = {v for v in combination_vars if combination_vars.count(v) > 1} - if duplicates: - raise ValueError( - "Same argument provided more than once in combination variables: " - f"{duplicates}", - ) - - # jax.vmap cannot deal with keyword-only arguments - func = allow_args(func) - - # Apply vmap_1d for combination variables and _base_productmap for product variables - # ================================================================================== - if not combination_vars: - vmapped = _base_productmap(func, product_vars) - else: - vmapped = _base_productmap(func, product_vars) vmapped = vmap_1d( vmapped, variables=combination_vars, callable_with="only_args" ) # This raises a mypy error but is perfectly fine to do. See # https://github.com/python/mypy/issues/12472 - vmapped.__signature__ = inspect.signature(func) # type: ignore[attr-defined] + vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined] return allow_only_kwargs(vmapped) def vmap_1d( func: F, - variables: list[str], + variables: tuple[str, ...], *, callable_with: Literal["only_args", "only_kwargs"] = "only_kwargs", ) -> F: @@ -105,7 +86,7 @@ def vmap_1d( Args: func: The function to be dispatched. - variables: List with names of arguments that over which we map. + variables: Tuple with names of arguments that over which we map. callable_with: Whether to apply the allow_kwargs decorator to the dispatched function. If "only_args", the returned function can only be called with positional arguments. If "only_kwargs", the returned function can only be @@ -123,8 +104,7 @@ def vmap_1d( described above but there might be additional dimensions. """ - duplicates = {v for v in variables if variables.count(v) > 1} - if duplicates: + if duplicates := find_duplicates(variables): raise ValueError( f"Same argument provided more than once in variables: {duplicates}", ) @@ -161,7 +141,7 @@ def vmap_1d( return out -def productmap(func: F, variables: list[str]) -> F: +def productmap(func: F, variables: tuple[str, ...]) -> F: """Apply vmap such that func is evaluated on the Cartesian product of variables. This is achieved by an iterative application of vmap. @@ -171,7 +151,7 @@ def productmap(func: F, variables: list[str]) -> F: Args: func: The function to be dispatched. - variables: List with names of arguments that over which the Cartesian product + variables: Tuple with names of arguments that over which the Cartesian product should be formed. Returns: @@ -185,25 +165,23 @@ def productmap(func: F, variables: list[str]) -> F: described above but there might be additional dimensions. """ - func = allow_args(func) # jax.vmap cannot deal with keyword-only arguments - - duplicates = {v for v in variables if variables.count(v) > 1} - if duplicates: + if duplicates := find_duplicates(variables): raise ValueError( f"Same argument provided more than once in variables: {duplicates}", ) - signature = inspect.signature(func) - vmapped = _base_productmap(func, variables) + func_callable_with_args = allow_args(func) + + vmapped = _base_productmap(func_callable_with_args, variables) # This raises a mypy error but is perfectly fine to do. See # https://github.com/python/mypy/issues/12472 - vmapped.__signature__ = signature # type: ignore[attr-defined] + vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined] return allow_only_kwargs(vmapped) -def _base_productmap(func: F, product_axes: list[str]) -> F: +def _base_productmap(func: F, product_axes: tuple[str, ...]) -> F: """Map func over the Cartesian product of product_axes. Like vmap, this function does not preserve the function signature and does not allow @@ -211,7 +189,7 @@ def _base_productmap(func: F, product_axes: list[str]) -> F: Args: func: The function to be dispatched. Cannot have keyword-only arguments. - product_axes: List with names of arguments over which we apply vmap. + product_axes: Tuple with names of arguments over which we apply vmap. Returns: A callable with the same arguments as func. See `product_map` for details. diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 4606f514..ef951c0e 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -198,7 +198,7 @@ def create_compute_conditional_continuation_value( if continuous_choice_variables: utility_and_feasibility = productmap( func=utility_and_feasibility, - variables=continuous_choice_variables, + variables=tuple(continuous_choice_variables), ) @functools.wraps(utility_and_feasibility) @@ -236,7 +236,7 @@ def create_compute_conditional_continuation_policy( if continuous_choice_variables: utility_and_feasibility = productmap( func=utility_and_feasibility, - variables=continuous_choice_variables, + variables=tuple(continuous_choice_variables), ) @functools.wraps(utility_and_feasibility) diff --git a/src/lcm/grids.py b/src/lcm/grids.py index 12769e6c..a63e820c 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -10,6 +10,7 @@ from lcm import grid_helpers from lcm.exceptions import GridInitializationError, format_messages from lcm.typing import Scalar +from lcm.utils import find_duplicates class Grid(ABC): @@ -188,12 +189,12 @@ def _validate_discrete_grid(category_class: type) -> None: values = list(names_and_values.values()) - duplicated_values = [v for v in values if values.count(v) > 1] + duplicated_values = find_duplicates(values) if duplicated_values: error_messages.append( "Field values of the category_class passed to DiscreteGrid must be unique. " "The following values are duplicated: " - f"{set(duplicated_values)}" + f"{duplicated_values}" ) if values != list(range(len(values))): diff --git a/src/lcm/model_functions.py b/src/lcm/model_functions.py index cd107f94..632821b7 100644 --- a/src/lcm/model_functions.py +++ b/src/lcm/model_functions.py @@ -138,7 +138,7 @@ def get_multiply_weights(stochastic_variables): callable """ - arg_names = [f"weight_next_{var}" for var in stochastic_variables] + arg_names = tuple(f"weight_next_{var}" for var in stochastic_variables) @with_signature(args=arg_names) def _outer(*args, **kwargs): diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index d15074b9..8f6d1fe7 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -234,8 +234,8 @@ def solve_continuous_problem( """ _gridmapped = spacemap( func=compute_ccv, - product_vars=list(data_scs.choices), - combination_vars=list(data_scs.states), + product_vars=tuple(data_scs.choices), + combination_vars=tuple(data_scs.states), ) gridmapped = jax.jit(_gridmapped) @@ -372,7 +372,7 @@ def _generate_simulation_keys(key, ids): # ====================================================================================== -@partial(vmap_1d, variables=["ccv_policy", "discrete_argmax"]) +@partial(vmap_1d, variables=("ccv_policy", "discrete_argmax")) def filter_ccv_policy( ccv_policy, discrete_argmax, diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index b5a3f05b..25080635 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -104,7 +104,8 @@ def solve_continuous_problem( """ _gridmapped = spacemap( func=compute_ccv, - product_vars=list(state_choice_space.ordered_var_names), + product_vars=state_choice_space.ordered_var_names, + combination_vars=(), ) gridmapped = jax.jit(_gridmapped) diff --git a/src/lcm/utils.py b/src/lcm/utils.py new file mode 100644 index 00000000..86d2fd47 --- /dev/null +++ b/src/lcm/utils.py @@ -0,0 +1,12 @@ +from collections import Counter +from collections.abc import Iterable +from itertools import chain +from typing import TypeVar + +T = TypeVar("T") + + +def find_duplicates(*containers: Iterable[T]) -> set[T]: + combined = chain.from_iterable(containers) + counts = Counter(combined) + return {v for v, count in counts.items() if count > 1} diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 219f4983..d9c467f5 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -103,7 +103,7 @@ def test_productmap_with_all_arguments_mapped(func, args, grids, expected, reque def test_productmap_with_positional_args(setup_productmap_f): - decorated = productmap(f, ["a", "b", "c"]) + decorated = productmap(f, ("a", "b", "c")) match = ( "This function has been decorated so that it allows only kwargs, but was " "called with positional arguments." @@ -113,10 +113,10 @@ def test_productmap_with_positional_args(setup_productmap_f): def test_productmap_different_func_order(setup_productmap_f): - decorated_f = productmap(f, ["a", "b", "c"]) + decorated_f = productmap(f, ("a", "b", "c")) expected = decorated_f(**setup_productmap_f) - decorated_f2 = productmap(f2, ["a", "b", "c"]) + decorated_f2 = productmap(f2, ("a", "b", "c")) calculated_f2 = decorated_f2(**setup_productmap_f) aaae(calculated_f2, expected) @@ -125,7 +125,7 @@ def test_productmap_different_func_order(setup_productmap_f): def test_productmap_change_arg_order(setup_productmap_f, expected_productmap_f): expected = jnp.transpose(expected_productmap_f, (1, 0, 2)) - decorated = productmap(f, ["b", "a", "c"]) + decorated = productmap(f, ("b", "a", "c")) calculated = decorated(**setup_productmap_f) aaae(calculated, expected) @@ -142,7 +142,7 @@ def test_productmap_with_all_arguments_mapped_some_len_one(): expected = allow_args(f)(*helper).reshape(1, 1, 5) - decorated = productmap(f, ["a", "b", "c"]) + decorated = productmap(f, ("a", "b", "c")) calculated = decorated(**grids) aaae(calculated, expected) @@ -154,7 +154,7 @@ def test_productmap_with_all_arguments_mapped_some_scalar(): "c": jnp.linspace(1, 5, 5), } - decorated = productmap(f, ["a", "b", "c"]) + decorated = productmap(f, ("a", "b", "c")) with pytest.raises(ValueError, match="vmap was requested to map its argument"): decorated(**grids) @@ -170,7 +170,7 @@ def test_productmap_with_some_arguments_mapped(): expected = allow_args(f)(*helper).reshape(10, 5) - decorated = productmap(f, ["a", "c"]) + decorated = productmap(f, ("a", "c")) calculated = decorated(**grids) aaae(calculated, expected) @@ -178,7 +178,7 @@ def test_productmap_with_some_arguments_mapped(): def test_productmap_with_some_argument_mapped_twice(): error_msg = "Same argument provided more than once." with pytest.raises(ValueError, match=error_msg): - productmap(f, ["a", "a", "c"]) + productmap(f, ("a", "a", "c")) # ====================================================================================== @@ -233,8 +233,8 @@ def test_spacemap_all_arguments_mapped( decorated = spacemap( g, - list(product_vars), - list(combination_vars), + tuple(product_vars), + tuple(combination_vars), ) calculated = decorated(**product_vars, **combination_vars) @@ -245,12 +245,12 @@ def test_spacemap_all_arguments_mapped( ("error_msg", "product_vars", "combination_vars"), [ ( - "Product and combination variables must be disjoint. Overlap: {'a'}", + "Same argument provided more than once in product variables or combination", ["a", "b"], ["a", "c", "d"], ), ( - "Same argument provided more than once in product variables: {'a'}", + "Same argument provided more than once in product variables or combination", ["a", "a", "b"], ["c", "d"], ), diff --git a/tests/test_function_representation.py b/tests/test_function_representation.py index 5d44104b..6444b83f 100644 --- a/tests/test_function_representation.py +++ b/tests/test_function_representation.py @@ -83,7 +83,7 @@ def test_function_evaluator(): # create a value function array discrete_part = jnp.arange(4).repeat(6 * 7).reshape((2, 2, 6, 7)) * 100 - cont_func = productmap(lambda x, y: x + y, ["x", "y"]) + cont_func = productmap(lambda x, y: x + y, ("x", "y")) cont_part = cont_func(x=jnp.linspace(100, 1100, 6), y=jnp.linspace(-3, 3, 7)) vf_arr = discrete_part + cont_part @@ -181,7 +181,7 @@ def test_get_interpolator(): def _utility(wealth, working): return 2 * wealth - working - prod_utility = productmap(_utility, variables=["wealth", "working"]) + prod_utility = productmap(_utility, variables=("wealth", "working")) values = prod_utility( wealth=jnp.arange(4, dtype=float), @@ -260,7 +260,7 @@ def test_get_interpolator_illustrative(): def f(a, b): return a - b - prod_f = productmap(f, variables=["a", "b"]) + prod_f = productmap(f, variables=("a", "b")) values = prod_f(a=jnp.arange(2, dtype=float), b=jnp.arange(3, dtype=float)) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..de5dd88c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,17 @@ +from lcm.utils import find_duplicates + + +def test_find_duplicates_singe_container_no_duplicates(): + assert find_duplicates([1, 2, 3, 4, 5]) == set() + + +def test_find_duplicates_single_container_with_duplicates(): + assert find_duplicates([1, 2, 3, 4, 5, 5]) == {5} + + +def test_find_duplicates_multiple_containers_no_duplicates(): + assert find_duplicates([1, 2, 3, 4, 5], [6, 7, 8, 9, 10]) == set() + + +def test_find_duplicates_multiple_containers_with_duplicates(): + assert find_duplicates([1, 2, 3, 4, 5, 5], [6, 7, 8, 9, 10, 5]) == {5}