Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 18 additions & 26 deletions src/dcegm/asset_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from jax import vmap

from dcegm.law_of_motion import (
calc_assets_beginning_of_period_2cont_vec,
calc_beginning_of_period_assets_1cont_vec,
calc_beginning_of_period_assets_for_single_state,
)


Expand Down Expand Up @@ -37,30 +36,23 @@ def adjust_observed_assets(observed_states_dict, params, model_class, aux_outs=F
second_cont_state_vars = observed_states_dict[second_cont_state_name]
observed_states_dict_int.pop(second_cont_state_name)

adjusted_assets = vmap(
calc_assets_beginning_of_period_2cont_vec,
in_axes=(0, 0, 0, None, None, None, None),
)(
observed_states_dict_int,
second_cont_state_vars,
assets_end_last_period,
jnp.array(0.0, dtype=jnp.float64),
params,
model_funcs["compute_assets_begin_of_period"],
aux_outs,
)

all_states = {
**observed_states_dict_int,
"continuous_state": second_cont_state_vars,
}
else:
adjusted_assets = vmap(
calc_beginning_of_period_assets_1cont_vec,
in_axes=(0, 0, None, None, None, None),
)(
observed_states_dict,
assets_end_last_period,
jnp.array(0.0, dtype=jnp.float64),
params,
model_funcs["compute_assets_begin_of_period"],
aux_outs,
)
all_states = observed_states_dict_int

adjusted_assets = vmap(
calc_beginning_of_period_assets_for_single_state,
in_axes=(0, 0, None, None, None, None),
)(
all_states,
assets_end_last_period,
jnp.array(0.0, dtype=jnp.float64),
params,
model_funcs["compute_assets_begin_of_period"],
aux_outs,
)

return adjusted_assets
2 changes: 1 addition & 1 deletion src/dcegm/final_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def calc_value_and_budget_for_each_gridpoint(
wealth_final_period = calc_assets_beginning_of_period_2cont_vec(
state_vec=state_vec,
continuous_state_beginning_of_period=second_continuous_state,
asset_grid_point_end_of_previous_period=asset_grid_point_end_of_previous_period,
asset_end_of_previous_period=asset_grid_point_end_of_previous_period,
income_shock_draw=jnp.array(0.0),
params=params,
compute_assets_begin_of_period=compute_assets_begin_of_period,
Expand Down
11 changes: 10 additions & 1 deletion src/dcegm/interfaces/model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,14 @@ def __init__(
debug_info: str = None,
model_save_path: str = None,
model_load_path: str = None,
use_stochastic_sparsity: bool = False,
):
"""Setup the model and check if load or save is required."""
"""Setup the model and check if load or save is required.

Args:
use_stochastic_sparsity (bool, optional): EXPERIMENTAL: Use stochastic transition sparsity.

"""

if model_load_path is not None:
model_dict = load_model_dict(
Expand All @@ -61,6 +67,7 @@ def __init__(
stochastic_states_transitions=stochastic_states_transitions,
shock_functions=shock_functions,
path=model_load_path,
use_stochastic_sparsity=use_stochastic_sparsity,
)
elif model_save_path is not None:
model_dict = create_model_dict_and_save(
Expand All @@ -74,6 +81,7 @@ def __init__(
shock_functions=shock_functions,
path=model_save_path,
debug_info=debug_info,
use_stochastic_sparsity=use_stochastic_sparsity,
)
else:
model_dict = create_model_dict(
Expand All @@ -86,6 +94,7 @@ def __init__(
stochastic_states_transitions=stochastic_states_transitions,
shock_functions=shock_functions,
debug_info=debug_info,
use_stochastic_sparsity=use_stochastic_sparsity,
)

self.model_specs = jax.tree_util.tree_map(try_jax_array, model_specs)
Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/interpolation/interp2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def interp2d_policy_on_wealth_and_regular_grid(
to interpolate.

Returns:
float: The interpolated value of the policy function at the given
jnp.ndarray | float: The interpolated value of the policy function at the given
(regular, wealth) point.

"""
Expand Down
196 changes: 76 additions & 120 deletions src/dcegm/law_of_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,80 +23,93 @@ def calc_cont_grids_next_period(
income_shock_draws_unscaled * income_shock_std + income_shock_mean
)

# Generate result dict
cont_grids_next_period = {}

if continuous_states_info["second_continuous_exists"]:
continuous_state_next_period = calculate_continuous_state(
discrete_states_beginning_of_period=state_space_dict,
continuous_grid=continuous_states_info["second_continuous_grid"],
params=params,
compute_continuous_state=model_funcs["next_period_continuous_state"],
)
# Fill in result dict
cont_grids_next_period["second_continuous"] = continuous_state_next_period

# Extra dimension for continuous state
assets_beginning_of_next_period = calc_assets_beginning_of_period_2cont(
discrete_states_beginning_of_next_period=state_space_dict,
continuous_state_beginning_of_next_period=continuous_state_next_period,
assets_grid_end_of_period=continuous_states_info[
"assets_grid_end_of_period"
],
income_shocks=income_shocks_scaled,
params=params,
compute_assets_begin_of_period=model_funcs[
"compute_assets_begin_of_period"
],
)

cont_grids_next_period = {
"assets_begin_of_period": assets_beginning_of_next_period,
"second_continuous": continuous_state_next_period,
# Prepare dict used to calculate beginning of period assets
state_specific_grids = {
"states": state_space_dict,
"continuous_state": continuous_state_next_period,
}

else:
assets_begin_of_next_period = calc_beginning_of_period_assets_1cont(
discrete_states_beginning_of_period=state_space_dict,
assets_grid_end_of_period=continuous_states_info[
"assets_grid_end_of_period"
],
income_shocks_current_period=income_shocks_scaled,
state_specific_grids = {
"states": state_space_dict,
}

def fix_assets_and_shocks_for_broadcast(
states,
asset_end_of_previous_period,
income_draw,
):
assets_begin_of_period = calc_beginning_of_period_assets_for_single_state(
state_vec=states,
asset_end_of_previous_period=asset_end_of_previous_period,
income_shock_draw=income_draw,
params=params,
compute_assets_begin_of_period=model_funcs[
"compute_assets_begin_of_period"
],
aux_outs=False,
)
cont_grids_next_period = {
"assets_begin_of_period": assets_begin_of_next_period,
}
return assets_begin_of_period

return cont_grids_next_period


def calc_beginning_of_period_assets_1cont(
discrete_states_beginning_of_period,
assets_grid_end_of_period,
income_shocks_current_period,
params,
compute_assets_begin_of_period,
):
assets_begin_of_period = vmap(
broadcast_function = lambda states: vmap(
vmap(
vmap(
calc_beginning_of_period_assets_1cont_vec,
in_axes=(None, None, 0, None, None, None), # income shocks
),
in_axes=(None, 0, None, None, None, None), # assets
fix_assets_and_shocks_for_broadcast,
in_axes=(None, None, 0), # income shocks
),
in_axes=(0, None, None, None, None, None), # discrete states
in_axes=(None, 0, None), # assets
)(
discrete_states_beginning_of_period,
assets_grid_end_of_period,
income_shocks_current_period,
params,
compute_assets_begin_of_period,
False,
states,
continuous_states_info["assets_grid_end_of_period"],
income_shocks_scaled,
)
return assets_begin_of_period

final_args = ()
# Default is no chaining of vmaps. Then I add consequently vmap over specific grids
vmap_chain = broadcast_function

for grid_name in state_specific_grids.keys():
if grid_name != "states":
# Use default argument to capture current values
vmap_chain = add_vmap_chain_for_grid(vmap_chain, grid_name)
final_args += (state_specific_grids[grid_name],)

final_args = (state_specific_grids["states"],) + final_args
assets_begin_of_next_period = vmap(vmap_chain)(*final_args)
cont_grids_next_period["assets_begin_of_period"] = assets_begin_of_next_period
return cont_grids_next_period


def calc_beginning_of_period_assets_1cont_vec(
def add_vmap_chain_for_grid(inner_func, gname):
"""The function adds a vmap layer for a specific grid.

It vmaps over the remaining dimension of the grid. So if we have a grid that is
(n_discrete_states, n_grid_points), we can later vmap over the discrete states and
this function will add the n_grid_points dimension to be vmapped over. The function
only expects later the grid to arrive in n_grid_points. So we can also use the
function in the final period calculation.

"""

def grid_wrapper(states, new_state_grid):
all_states = {**states, gname: new_state_grid}
return inner_func(all_states)

return vmap(grid_wrapper, in_axes=(None, 0))


def calc_beginning_of_period_assets_for_single_state(
state_vec,
asset_end_of_previous_period,
income_shock_draw,
Expand Down Expand Up @@ -124,22 +137,23 @@ def calc_beginning_of_period_assets_1cont_vec(
def calc_assets_beginning_of_period_2cont_vec(
state_vec,
continuous_state_beginning_of_period,
asset_grid_point_end_of_previous_period,
asset_end_of_previous_period,
income_shock_draw,
params,
compute_assets_begin_of_period,
aux_outs,
):

out_budget = compute_assets_begin_of_period(
all_states = {
**state_vec,
continuous_state=continuous_state_beginning_of_period,
asset_end_of_previous_period=asset_grid_point_end_of_previous_period,
income_shock_previous_period=income_shock_draw,
"continuous_state": continuous_state_beginning_of_period,
}
checked_out = calc_beginning_of_period_assets_for_single_state(
state_vec=all_states,
asset_end_of_previous_period=asset_end_of_previous_period,
income_shock_draw=income_shock_draw,
params=params,
)
checked_out = check_budget_equation_and_return_wealth_plus_optional_aux(
out_budget, optional_aux=aux_outs
compute_assets_begin_of_period=compute_assets_begin_of_period,
aux_outs=aux_outs,
)
return checked_out

Expand Down Expand Up @@ -179,39 +193,6 @@ def calc_continuous_state_for_each_grid_point(
return out


def calc_assets_beginning_of_period_2cont(
discrete_states_beginning_of_next_period,
continuous_state_beginning_of_next_period,
assets_grid_end_of_period,
income_shocks,
params,
compute_assets_begin_of_period,
):

assets_begin_of_period = vmap(
vmap(
vmap(
vmap(
calc_assets_beginning_of_period_2cont_vec,
in_axes=(None, None, None, 0, None, None, None), # income shocks
),
in_axes=(None, None, 0, None, None, None, None), # assets
),
in_axes=(None, 0, None, None, None, None, None), # continuous state
),
in_axes=(0, 0, None, None, None, None, None), # discrete states
)(
discrete_states_beginning_of_next_period,
continuous_state_beginning_of_next_period,
assets_grid_end_of_period,
income_shocks,
params,
compute_assets_begin_of_period,
False,
)
return assets_begin_of_period


# =====================================================================================
# Simulation
# =====================================================================================
Expand All @@ -226,7 +207,7 @@ def calculate_assets_begin_of_period_for_all_agents(
):
"""Simulation."""
assets_begin_of_next_period = vmap(
calc_beginning_of_period_assets_1cont_vec,
calc_beginning_of_period_assets_for_single_state,
in_axes=(0, 0, 0, None, None, None),
)(
states_beginning_of_period,
Expand Down Expand Up @@ -256,28 +237,3 @@ def calculate_second_continuous_state_for_all_agents(
compute_continuous_state,
)
return continuous_state_beginning_of_next_period


def calc_assets_begin_of_period_for_all_agents(
states_beginning_of_period,
continuous_state_beginning_of_period,
assets_end_of_period,
income_shocks_of_period,
params,
compute_assets_begin_of_period,
):
"""Simulation."""

assets_begin_of_next_period, aux_dict = vmap(
calc_assets_beginning_of_period_2cont_vec,
in_axes=(0, 0, 0, 0, None, None, None),
)(
states_beginning_of_period,
continuous_state_beginning_of_period,
assets_end_of_period,
income_shocks_of_period,
params,
compute_assets_begin_of_period,
True,
)
return assets_begin_of_next_period, aux_dict
4 changes: 2 additions & 2 deletions src/dcegm/pre_processing/batches/last_two_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ def add_last_two_period_information(
"child_states_second_last_period": child_states_second_last_period,
}

state_choice_space_dict = model_structure["state_choice_space_dict"]
# Also add state choice mat as dictionary for each of the two periods
for idx, period_name in [
(idx_state_choice_final_period, "final"),
(idx_state_choice_second_last_period, "second_last"),
]:
last_two_period_info[f"state_choice_mat_{period_name}_period"] = {
key: state_choice_space[:, i][idx]
for i, key in enumerate(discrete_states_names + ["choice"])
key: var[idx] for key, var in state_choice_space_dict.items()
}
return last_two_period_info
Loading
Loading