diff --git a/soepy/shared/non_consumption_utility.py b/soepy/shared/non_consumption_utility.py new file mode 100644 index 0000000..2551769 --- /dev/null +++ b/soepy/shared/non_consumption_utility.py @@ -0,0 +1,89 @@ +import jax.numpy as jnp + + +def calculate_non_consumption_utility( + model_params, + states, + child_bins, +): + """Calculate non-pecuniary utility contribution. + + Parameters + ---------- + states : np.ndarray + Shape (n_states, n_state_vars) matrix of states + child_bins : np.ndarray + Shape (n_states,) array with child bin indices for each state + + Returns + ------- + non_consumption_utility : np.ndarray + Shape (n_states, 3) matrix with utilities for [no work, part-time, full-time] + """ + educ = states[:, 1] + unobs_types = states[:, 5] + + # Base utilities + util_pt = model_params.theta_p[unobs_types] + util_ft = model_params.theta_f[unobs_types] + + # Binary selectors (implicitly cast to 0/1 in arithmetic) + b0 = child_bins == 0 + b1 = child_bins == 1 + b2 = child_bins == 2 + b3 = child_bins == 3 + b4 = child_bins > 3 + + # Education-dependent components + no_kids_f = model_params.no_kids_f[educ] + no_kids_p = model_params.no_kids_p[educ] + yes_kids_f = model_params.yes_kids_f[educ] + yes_kids_p = model_params.yes_kids_p[educ] + + # Part-time utility (always includes full time base utility) + util_pt += ( + b0 * (no_kids_f + no_kids_p) + + b1 + * ( + yes_kids_f + + yes_kids_p + + model_params.child_0_2_f + + model_params.child_0_2_p + ) + + b2 + * ( + yes_kids_f + + yes_kids_p + + model_params.child_3_5_f + + model_params.child_3_5_p + ) + + b3 + * ( + yes_kids_f + + yes_kids_p + + model_params.child_6_10_f + + model_params.child_6_10_p + ) + + b4 * (yes_kids_f + yes_kids_p) + ) + + # Full-time utility + util_ft += ( + b0 * no_kids_f + + b1 * (yes_kids_f + model_params.child_0_2_f) + + b2 * (yes_kids_f + model_params.child_3_5_f) + + b3 * (yes_kids_f + model_params.child_6_10_f) + + b4 * yes_kids_f + ) + + # Stack: [no work, part-time, full-time] + non_consumption_utility = jnp.stack( + ( + jnp.zeros_like(util_pt), + util_pt, + util_ft, + ), + axis=1, + ) + + return jnp.exp(non_consumption_utility) diff --git a/soepy/shared/numerical_integration.py b/soepy/shared/numerical_integration.py index 18e9914..e42ee36 100644 --- a/soepy/shared/numerical_integration.py +++ b/soepy/shared/numerical_integration.py @@ -1,19 +1,17 @@ import numpy as np from scipy.special import roots_hermite -from soepy.shared.shared_auxiliary import draw_disturbances - -def get_integration_draws_and_weights(model_spec, model_params): +def get_integration_draws_and_weights(model_spec): if model_spec.integration_method == "quadrature": # Draw standard points and corresponding weights standard_draws, draw_weights_emax = roots_hermite(model_spec.num_draws_emax) # Rescale draws and weights - draws_emax = standard_draws * np.sqrt(2) * model_params.shock_sd + draws_emax = standard_draws * np.sqrt(2) draw_weights_emax *= 1 / np.sqrt(np.pi) elif model_spec.integration_method == "monte_carlo": - draws_emax = draw_disturbances( - model_spec.seed_emax, 1, model_spec.num_draws_emax, model_params + draws_emax = draw_zero_one_distributed_shocks( + model_spec.seed_emax, 1, model_spec.num_draws_emax )[0] draw_weights_emax = ( np.ones(model_spec.num_draws_emax) / model_spec.num_draws_emax @@ -24,3 +22,18 @@ def get_integration_draws_and_weights(model_spec, model_params): ) return draws_emax, draw_weights_emax + + +def draw_zero_one_distributed_shocks(seed, num_periods, num_draws): + """Creates desired number of draws of a multivariate standard normal + distribution. + + """ + np.random.seed(seed) + + mean = 0 + + # Create draws from the standard normal distribution + draws = np.random.normal(mean, 1, (num_periods, num_draws)) + + return draws diff --git a/soepy/shared/shared_auxiliary.py b/soepy/shared/shared_auxiliary.py deleted file mode 100644 index bf39ad4..0000000 --- a/soepy/shared/shared_auxiliary.py +++ /dev/null @@ -1,186 +0,0 @@ -import jax.numpy as jnp -import numpy as np - - -def draw_disturbances(seed, num_periods, num_draws, model_params): - """Creates desired number of draws of a multivariate standard normal - distribution. - - """ - np.random.seed(seed) - - mean = 0 - - # Create draws from the standard normal distribution - draws = np.random.normal(mean, model_params.shock_sd, (num_periods, num_draws)) - - return draws - - -def calculate_log_wage(model_params, states, is_expected): - """Calculate utility components for all choices given state, period, and shocks. - - Parameters - ---------- - model_params : namedtuple - Contains all parameters of the model including information on dimensions - (number of periods, agents, random draws, etc.) and coefficients to be - estimated. - states : np.ndarray - Array with shape (num_states, 5) containing period, years of schooling, - the lagged choice, the years of experience in part-time, and the - years of experience in full-time employment. - is_expected: bool - A boolean indicator that differentiates between the human capital accumulation - process that agents expect (is_expected = True) and that the market generates - (is_expected = False) - - Returns - ------- - log_wage_systematic : array - One dimensional array with length num_states containing the part of the wages - at the respective state space point that do not depend on the agent's choice, - nor on the random shock. - non_consumption_utilities : np.ndarray - Array of dimension (num_states, num_choices) containing the utility - contribution of non-pecuniary factors. - - """ - if is_expected: - # Calculate biased part-time expectation by using ratio from expected data and - # structural paramteters - gamma_p = ( - model_params.gamma_p_bias / (model_params.gamma_p / model_params.gamma_f) - ) * model_params.gamma_p - else: - gamma_p = model_params.gamma_p - - log_wage_systematic = calculate_log_wage_systematic( - model_params.gamma_0, - model_params.gamma_f, - gamma_p, - states, - ) - - return log_wage_systematic - - -def calculate_log_wage_systematic(gamma_0, gamma_f, gamma_p, states): - """Calculate systematic wages, i.e., wages net of shock, for all states.""" - - exp_p_states, exp_f_states = states[:, 3], states[:, 4] - - log_exp_p = jnp.log(exp_p_states + 1) - log_exp_f = jnp.log(exp_f_states + 1) - - # Assign wage returns - gamma_0_edu = gamma_0[states[:, 1]] - gamma_f_edu = gamma_f[states[:, 1]] - gamma_p_edu = gamma_p[states[:, 1]] - - # Calculate wage in the given state - log_wage_systematic = ( - gamma_0_edu + gamma_f_edu * log_exp_f + gamma_p_edu * log_exp_p - ) - return log_wage_systematic - - -def calculate_non_consumption_utility( - model_params, - states, - child_bins, -): - """Calculate non-pecuniary utility contribution. - - Parameters - ---------- - states : np.ndarray - Shape (n_states, n_state_vars) matrix of states - child_bins : np.ndarray - Shape (n_states,) array with child bin indices for each state - - Returns - ------- - non_consumption_utility : np.ndarray - Shape (n_states, 3) matrix with utilities for [no work, part-time, full-time] - """ - theta_p = model_params.theta_p - theta_f = model_params.theta_f - no_kids_f = model_params.no_kids_f - no_kids_p = model_params.no_kids_p - yes_kids_f = model_params.yes_kids_f - yes_kids_p = model_params.yes_kids_p - child_0_2_f = model_params.child_0_2_f - child_0_2_p = model_params.child_0_2_p - child_3_5_f = model_params.child_3_5_f - child_3_5_p = model_params.child_3_5_p - child_6_10_f = model_params.child_6_10_f - child_6_10_p = model_params.child_6_10_p - - n_states = states.shape[0] - educ_levels = states[:, 1] # Extract education level for all states - - # Initialize output: column 0 (no work) = 0, columns 1-2 get theta values - non_consumption_utility = np.zeros((n_states, 3)) - non_consumption_utility[:, 1] = theta_p[states[:, 5]] # part-time base utility - non_consumption_utility[:, 2] = theta_f[states[:, 5]] # full-time base utility - - # Create masks for each child bin - no_kids_mask = child_bins == 0 - child_0_2_mask = child_bins == 1 - child_3_5_mask = child_bins == 2 - child_6_10_mask = child_bins == 3 - older_kids_mask = child_bins > 3 - - # No kids (child_bin == 0) - non_consumption_utility[no_kids_mask, 1] += ( - no_kids_f[educ_levels[no_kids_mask]] + no_kids_p[educ_levels[no_kids_mask]] - ) - non_consumption_utility[no_kids_mask, 2] += no_kids_f[educ_levels[no_kids_mask]] - - # Child 0-2 (child_bin == 1) - non_consumption_utility[child_0_2_mask, 1] += ( - yes_kids_f[educ_levels[child_0_2_mask]] - + yes_kids_p[educ_levels[child_0_2_mask]] - + child_0_2_f - + child_0_2_p - ) - non_consumption_utility[child_0_2_mask, 2] += ( - yes_kids_f[educ_levels[child_0_2_mask]] + child_0_2_f - ) - - # Child 3-5 (child_bin == 2) - non_consumption_utility[child_3_5_mask, 1] += ( - yes_kids_f[educ_levels[child_3_5_mask]] - + yes_kids_p[educ_levels[child_3_5_mask]] - + child_3_5_f - + child_3_5_p - ) - non_consumption_utility[child_3_5_mask, 2] += ( - yes_kids_f[educ_levels[child_3_5_mask]] + child_3_5_f - ) - - # Child 6-10 (child_bin == 3) - non_consumption_utility[child_6_10_mask, 1] += ( - yes_kids_f[educ_levels[child_6_10_mask]] - + yes_kids_p[educ_levels[child_6_10_mask]] - + child_6_10_f - + child_6_10_p - ) - non_consumption_utility[child_6_10_mask, 2] += ( - yes_kids_f[educ_levels[child_6_10_mask]] + child_6_10_f - ) - - # Older kids (child_bin > 3) - non_consumption_utility[older_kids_mask, 1] += ( - yes_kids_f[educ_levels[older_kids_mask]] - + yes_kids_p[educ_levels[older_kids_mask]] - ) - non_consumption_utility[older_kids_mask, 2] += yes_kids_f[ - educ_levels[older_kids_mask] - ] - - # Apply exponential transformation - non_consumption_utility = np.exp(non_consumption_utility) - - return non_consumption_utility diff --git a/soepy/shared/wages.py b/soepy/shared/wages.py new file mode 100644 index 0000000..64488dc --- /dev/null +++ b/soepy/shared/wages.py @@ -0,0 +1,70 @@ +import numpy as np +from jax import numpy as jnp + + +def calculate_log_wage(model_params, states, is_expected): + """Calculate utility components for all choices given state, period, and shocks. + + Parameters + ---------- + model_params : namedtuple + Contains all parameters of the model including information on dimensions + (number of periods, agents, random draws, etc.) and coefficients to be + estimated. + states : np.ndarray + Array with shape (num_states, 5) containing period, years of schooling, + the lagged choice, the years of experience in part-time, and the + years of experience in full-time employment. + is_expected: bool + A boolean indicator that differentiates between the human capital accumulation + process that agents expect (is_expected = True) and that the market generates + (is_expected = False) + + Returns + ------- + log_wage_systematic : array + One dimensional array with length num_states containing the part of the wages + at the respective state space point that do not depend on the agent's choice, + nor on the random shock. + non_consumption_utilities : np.ndarray + Array of dimension (num_states, num_choices) containing the utility + contribution of non-pecuniary factors. + + """ + if is_expected: + # Calculate biased part-time expectation by using ratio from expected data and + # structural paramteters + gamma_p = ( + model_params.gamma_p_bias / (model_params.gamma_p / model_params.gamma_f) + ) * model_params.gamma_p + else: + gamma_p = model_params.gamma_p + + log_wage_systematic = calculate_log_wage_systematic( + model_params.gamma_0, + model_params.gamma_f, + gamma_p, + states, + ) + + return log_wage_systematic + + +def calculate_log_wage_systematic(gamma_0, gamma_f, gamma_p, states): + """Calculate systematic wages, i.e., wages net of shock, for all states.""" + + exp_p_states, exp_f_states = states[:, 3], states[:, 4] + + log_exp_p = jnp.log(exp_p_states + 1) + log_exp_f = jnp.log(exp_f_states + 1) + + # Assign wage returns + gamma_0_edu = jnp.array(gamma_0)[states[:, 1]] + gamma_f_edu = jnp.array(gamma_f)[states[:, 1]] + gamma_p_edu = jnp.array(gamma_p)[states[:, 1]] + + # Calculate wage in the given state + log_wage_systematic = ( + gamma_0_edu + gamma_f_edu * log_exp_f + gamma_p_edu * log_exp_p + ) + return log_wage_systematic diff --git a/soepy/simulate/simulate_auxiliary.py b/soepy/simulate/simulate_auxiliary.py index 5ba520b..9aea3d3 100644 --- a/soepy/simulate/simulate_auxiliary.py +++ b/soepy/simulate/simulate_auxiliary.py @@ -4,9 +4,9 @@ from soepy.exogenous_processes.determine_lagged_choice import lagged_choice_initial from soepy.shared.non_employment import calc_erziehungsgeld from soepy.shared.non_employment import calculate_non_employment_consumption_resources -from soepy.shared.shared_auxiliary import calculate_log_wage -from soepy.shared.shared_auxiliary import draw_disturbances +from soepy.shared.numerical_integration import draw_zero_one_distributed_shocks from soepy.shared.shared_constants import HOURS +from soepy.shared.wages import calculate_log_wage from soepy.simulate.constants_sim import DATA_FORMATS_SIM from soepy.simulate.constants_sim import DATA_FORMATS_SPARSE from soepy.simulate.constants_sim import DATA_LABLES_SIM @@ -362,10 +362,10 @@ def prepare_simulation_data( ) # Draw shocks - attrs_spec = ["seed_sim", "num_periods", "num_agents_sim"] - draws_sim = draw_disturbances( - *[getattr(model_spec, attr) for attr in attrs_spec], model_params + draws_sim = draw_zero_one_distributed_shocks( + model_spec.seed_sim, model_spec.num_periods, model_spec.num_agents_sim ) + draws_sim *= model_params.shock_sd # Calculate utility components log_wage_systematic = calculate_log_wage(model_params, states, is_expected) diff --git a/soepy/simulate/simulate_python.py b/soepy/simulate/simulate_python.py index 96c7143..f09a3f0 100644 --- a/soepy/simulate/simulate_python.py +++ b/soepy/simulate/simulate_python.py @@ -10,6 +10,7 @@ from soepy.pre_processing.model_processing import read_model_spec_init from soepy.simulate.simulate_auxiliary import pyth_simulate from soepy.solve.create_state_space import create_state_space_objects +from soepy.solve.solve_python import get_solve_function from soepy.solve.solve_python import pyth_solve @@ -62,21 +63,21 @@ def simulate( # Simulate agents experiences according to parameters in the model specification df = pyth_simulate( - model_params, - model_spec, - states, - indexer, - emaxs, - covariates, - non_consumption_utilities, - child_age_update_rule, - prob_educ_level, - prob_child_age, - prob_partner_present, - prob_exp_ft, - prob_exp_pt, - prob_child, - prob_partner, + model_params=model_params, + model_spec=model_spec, + states=states, + indexer=indexer, + emaxs=emaxs, + covariates=covariates, + non_consumption_utilities=non_consumption_utilities, + child_age_update_rule=child_age_update_rule, + prob_educ_level=prob_educ_level, + prob_child_age=prob_child_age, + prob_partner_present=prob_partner_present, + prob_exp_ft=prob_exp_ft, + prob_exp_pt=prob_exp_pt, + prob_child=prob_child, + prob_partner=prob_partner, is_expected=False, data_sparse=data_sparse, ).set_index(["Identifier", "Period"]) @@ -120,32 +121,59 @@ def get_simulate_func( child_state_indexes, ) = create_state_space_objects(model_spec) - partial_simulate = partial( - partiable_simulate, - states, - indexer, - covariates, - child_age_update_rule, - child_state_indexes, - prob_educ_level, - prob_child_age, - prob_partner_present, - prob_exp_ft, - prob_exp_pt, - prob_child, - prob_partner, - is_expected, - data_sparse, + solve_func = get_solve_function( + states=states, + covariates=covariates, + child_state_indexes=child_state_indexes, + model_spec=model_spec, + prob_child=prob_child, + prob_partner=prob_partner, + is_expected=is_expected, ) - return partial_simulate + + def simulate_func(model_params_init_file_name, model_spec_init_file_name): + # Read in model specification from yaml file + model_params_df, model_params = read_model_params_init( + model_params_init_file_name + ) + + model_spec = read_model_spec_init(model_spec_init_file_name, model_params_df) + + # Obtain model solution + non_consumption_utilities, emaxs = solve_func(model_params) + + # Simulate agents experiences according to parameters in the model specification + df = pyth_simulate( + model_params=model_params, + model_spec=model_spec, + states=states, + indexer=indexer, + emaxs=emaxs, + covariates=covariates, + non_consumption_utilities=non_consumption_utilities, + child_age_update_rule=child_age_update_rule, + prob_educ_level=prob_educ_level, + prob_child_age=prob_child_age, + prob_partner_present=prob_partner_present, + prob_exp_ft=prob_exp_ft, + prob_exp_pt=prob_exp_pt, + prob_child=prob_child, + prob_partner=prob_partner, + is_expected=False, + data_sparse=data_sparse, + ).set_index(["Identifier", "Period"]) + + return df + + return simulate_func def partiable_simulate( + solve_func, states, indexer, covariates, child_age_update_rule, - child_state_indexes, prob_educ_level, prob_child_age, prob_partner_present, @@ -153,7 +181,6 @@ def partiable_simulate( prob_exp_pt, prob_child, prob_partner, - is_expected, data_sparse, model_params_init_file_name, model_spec_init_file_name, @@ -164,16 +191,7 @@ def partiable_simulate( model_spec = read_model_spec_init(model_spec_init_file_name, model_params_df) # Obtain model solution - non_consumption_utilities, emaxs = pyth_solve( - states, - covariates, - child_state_indexes, - model_params, - model_spec, - prob_child, - prob_partner, - is_expected, - ) + non_consumption_utilities, emaxs = solve_func(model_params) # Simulate agents experiences according to parameters in the model specification df = pyth_simulate( diff --git a/soepy/solve/create_state_space.py b/soepy/solve/create_state_space.py index 722288b..b64bbaa 100644 --- a/soepy/solve/create_state_space.py +++ b/soepy/solve/create_state_space.py @@ -84,32 +84,34 @@ def pyth_create_state_space(model_spec): # Assumption: 1st kid is born no earlier than age 17. # Can be relaxed, e.g., we assume that 1st kid can arrive earliest when # a woman is 16 years old, the condition becomes: - if age_kid - model_spec.child_age_init_max > period: - continue - - if ( - period > model_spec.last_child_bearing_period - and 0 - <= age_kid - <= min(period - (model_spec.last_child_bearing_period + 1), 10) - ): - continue + # if age_kid - model_spec.child_age_init_max > period: + # continue + # + # if ( + # period > model_spec.last_child_bearing_period + # and 0 + # <= age_kid + # <= min(period - (model_spec.last_child_bearing_period + 1), 10) + # ): + # continue age_idx = _kid_age_to_index(age_kid, n_kids_ages) for educ_level in range(model_spec.num_educ_levels): - edu_years = model_spec.educ_years[educ_level] + # edu_years = model_spec.educ_years[educ_level] - # has she completed education already? - if edu_years > period: - continue + # # has she completed education already? + # if edu_years > period: + # continue + + last_period = model_spec.num_periods - 1 # Basic feasibility region for experiences (vectorized): - # exp_f + exp_p <= period + 2*init_exp_max - edu_years - max_total = period + 2 * model_spec.init_exp_max - edu_years - # also must be <= period + init_exp_max individually (this is already implied by max_exp axis), + # exp_f + exp_p <= period + 2*init_exp_max + max_total = last_period + 2 * model_spec.init_exp_max + # also must be <= period + init_exp_max individually, # but original additionally checks exp_f > period + init_exp_max etc. - max_ind = period + model_spec.init_exp_max + max_ind = last_period + model_spec.init_exp_max feasible = (EP + EF) <= max_total feasible &= (EF <= max_ind) & (EP <= max_ind) @@ -118,50 +120,47 @@ def pyth_create_state_space(model_spec): fp = EP[feasible] ff = EF[feasible] - if fp.size == 0: + # if period == edu_years: + # # Entry-period: original code adds states for ALL lagged choices, no extra restrictions. + # lagged = np.tile( + # np.arange(NUM_CHOICES, dtype=np.int32), fp.size + # ) + # exp_p_rep = np.repeat(fp, NUM_CHOICES) + # exp_f_rep = np.repeat(ff, NUM_CHOICES) + # + # else: + # Non-entry periods: apply the lagged-choice restrictions (vectorized). + # max_ft = period + model_spec.init_exp_max - edu_years + # max_pt = period + model_spec.init_exp_max - edu_years + + # allowed[c, j] means for pair c (fp[c], ff[c]) lagged choice j is allowed + allowed = np.ones((fp.size, NUM_CHOICES), dtype=bool) + + # # only worked full-time -> lagged must be 2 + # mask_only_ft = ff == max_ft + # allowed[mask_only_ft, :] = False + # allowed[mask_only_ft, 2] = True + # + # # only worked part-time -> lagged must be 1 + # mask_only_pt = fp == max_pt + # allowed[mask_only_pt, :] = False + # allowed[mask_only_pt, 1] = True + # + # # never worked full-time -> cannot have lagged 2 + # allowed[(ff == 0), 2] = False + # # never worked part-time -> cannot have lagged 1 + # allowed[(fp == 0), 1] = False + + # always employed -> cannot have lagged 0 + # allowed[(fp + ff == max_total), 0] = False + + # Build rows by expanding only allowed (pair, lagged) combinations + pair_idx, lagged = np.nonzero(allowed) + if lagged.size == 0: continue - - if period == edu_years: - # Entry-period: original code adds states for ALL lagged choices, no extra restrictions. - lagged = np.tile( - np.arange(NUM_CHOICES, dtype=np.int32), fp.size - ) - exp_p_rep = np.repeat(fp, NUM_CHOICES) - exp_f_rep = np.repeat(ff, NUM_CHOICES) - - else: - # Non-entry periods: apply the lagged-choice restrictions (vectorized). - max_ft = period + model_spec.init_exp_max - edu_years - max_pt = period + model_spec.init_exp_max - edu_years - - # allowed[c, j] means for pair c (fp[c], ff[c]) lagged choice j is allowed - allowed = np.ones((fp.size, NUM_CHOICES), dtype=bool) - - # only worked full-time -> lagged must be 2 - mask_only_ft = ff == max_ft - allowed[mask_only_ft, :] = False - allowed[mask_only_ft, 2] = True - - # only worked part-time -> lagged must be 1 - mask_only_pt = fp == max_pt - allowed[mask_only_pt, :] = False - allowed[mask_only_pt, 1] = True - - # never worked full-time -> cannot have lagged 2 - allowed[(ff == 0), 2] = False - # never worked part-time -> cannot have lagged 1 - allowed[(fp == 0), 1] = False - - # always employed -> cannot have lagged 0 - allowed[(fp + ff == max_total), 0] = False - - # Build rows by expanding only allowed (pair, lagged) combinations - pair_idx, lagged = np.nonzero(allowed) - if lagged.size == 0: - continue - exp_p_rep = fp[pair_idx] - exp_f_rep = ff[pair_idx] - lagged = lagged.astype(np.int32, copy=False) + exp_p_rep = fp[pair_idx] + exp_f_rep = ff[pair_idx] + lagged = lagged.astype(np.int32, copy=False) n = lagged.size @@ -181,9 +180,6 @@ def pyth_create_state_space(model_spec): # Fill indexer with consecutive ids [i, i+n) ids = np.arange(i, i + n, dtype=np.int32) - # IMPORTANT: age axis index must follow the original -1 -> last convention - age_index_for_indexer = age_idx # already mapped - indexer[ period, educ_level, @@ -191,7 +187,7 @@ def pyth_create_state_space(model_spec): exp_p_rep, exp_f_rep, type_, - age_index_for_indexer, + age_idx, partner, ] = ids @@ -222,11 +218,12 @@ def create_child_indexes(states, indexer, model_spec, child_age_update_rule): type_ = states[:, 5] age_kid_val = states[:, 6] partner = states[:, 7] + max_exp = np.max(exp_p) n_kid_ages = indexer.shape[6] # kid-age axis length age_idx = np.where(age_kid_val == -1, n_kid_ages - 1, age_kid_val) - parent_idx = np.where(period < (model_spec.num_periods - 1)) + parent_idx = np.where(period < (model_spec.num_periods - 1))[0] next_period = period[parent_idx] + 1 @@ -248,8 +245,8 @@ def create_child_indexes(states, indexer, model_spec, child_age_update_rule): # - choice 1: exp_p + 1 # - choice 2: exp_f + 1 for choice in range(NUM_CHOICES): - exp_part = exp_p[parent_idx] + (choice == 1) - exp_full = exp_f[parent_idx] + (choice == 2) + exp_part = np.minimum(exp_p[parent_idx] + (choice == 1), max_exp) + exp_full = np.minimum(exp_f[parent_idx] + (choice == 2), max_exp) # branch 0: use updated child age (rule), partner shock in {0,1} child_indexes[parent_idx, choice, 0, 0] = indexer[ diff --git a/soepy/solve/emaxs.py b/soepy/solve/emaxs.py index 3807208..ca1bdb3 100644 --- a/soepy/solve/emaxs.py +++ b/soepy/solve/emaxs.py @@ -81,13 +81,8 @@ def construct_emax( hours, mu, non_employment_consumption_resources, - deductions_spec, - income_tax_spec, - child_care_costs, - index_child_care_costs, - male_wages, - child_benefits, - equivalence_scales, + covariates, + model_spec, tax_splitting, ): """Simulate expected maximum utility for a given distribution of the unobservables. @@ -154,13 +149,17 @@ def max_aggregated_utilities_broadcast( non_consumption_utilities_choices, emax_choices, non_employment_consumption_resources_choices, - male_wage, - child_benefit, - equivalence, - index_child_care_cost, + covariate, draw, draw_weight, ): + # Corresponding equivalence scale for period states + male_wage = covariate[1] + equivalence_scale = covariate[2] + child_benefit = covariate[3] + + index_child_care_cost = jnp.where(covariate[0] > 2, 0, covariate[0]).astype(int) + return _get_max_aggregated_utilities( delta=delta, log_wage_systematic=log_wage_systematic_choices, @@ -171,31 +170,28 @@ def max_aggregated_utilities_broadcast( hours=hours, mu=mu, non_employment_consumption_resources=non_employment_consumption_resources_choices, - deductions_spec=deductions_spec, - income_tax_spec=income_tax_spec, + deductions_spec=model_spec.ssc_deductions, + income_tax_spec=model_spec.tax_params, male_wage=male_wage, child_benefits=child_benefit, - equivalence=equivalence, + equivalence=equivalence_scale, tax_splitting=tax_splitting, - child_care_costs=jnp.asarray(child_care_costs), + child_care_costs=model_spec.child_care_costs, child_care_bin=index_child_care_cost, ) emaxs_current_states = jax.vmap( jax.vmap( max_aggregated_utilities_broadcast, - in_axes=(None, None, None, None, None, None, None, None, 0, 0), + in_axes=(None, None, None, None, None, 0, 0), ), - in_axes=(0, 0, 0, 0, 0, 0, 0, 0, None, None), + in_axes=(0, 0, 0, 0, 0, None, None), )( log_wages_systematic, non_consumption_utilities, emax, non_employment_consumption_resources, - male_wages, - child_benefits, - equivalence_scales, - index_child_care_costs, + covariates, draws, draw_weights, ) diff --git a/soepy/solve/solve_python.py b/soepy/solve/solve_python.py index dedd378..d3bd1b3 100644 --- a/soepy/solve/solve_python.py +++ b/soepy/solve/solve_python.py @@ -2,11 +2,12 @@ import jax.numpy as jnp import numpy as np +from soepy.shared.non_consumption_utility import calculate_non_consumption_utility from soepy.shared.non_employment import calculate_non_employment_consumption_resources from soepy.shared.numerical_integration import get_integration_draws_and_weights -from soepy.shared.shared_auxiliary import calculate_log_wage -from soepy.shared.shared_auxiliary import calculate_non_consumption_utility from soepy.shared.shared_constants import HOURS +from soepy.shared.shared_constants import NUM_CHOICES +from soepy.shared.wages import calculate_log_wage from soepy.solve.emaxs import construct_emax from soepy.solve.validation_solve import construct_emax_validation @@ -62,27 +63,16 @@ def pyth_solve( num_choices contains continuation values of the state space point. Lat element contains the expected maximum value function of the state space point. """ - - draws_emax, draw_weights_emax = get_integration_draws_and_weights( - model_spec, model_params - ) - - # Solve the model in a backward induction procedure - # Error term for continuation values is integrated out - # numerically in a Monte Carlo procedure - emaxs, non_consumption_utilities = pyth_backward_induction( - model_spec, - model_spec.tax_splitting, - model_params, - states, - child_state_indexes, - draws_emax, - draw_weights_emax, - covariates, - prob_child, - prob_partner, - is_expected, + solve_func = get_solve_function( + states=states, + covariates=covariates, + child_state_indexes=child_state_indexes, + model_spec=model_spec, + prob_child=prob_child, + prob_partner=prob_partner, + is_expected=is_expected, ) + non_consumption_utilities, emaxs = solve_func(model_params) # Return function output return ( @@ -91,18 +81,122 @@ def pyth_solve( ) -def pyth_backward_induction( - model_spec, - tax_splitting, - model_params, +def get_solve_function( states, + covariates, child_state_indexes, + model_spec, + prob_child, + prob_partner, + is_expected, +): + """Return the solve function used in the model.""" + # Draw integration draws and weights for EMAX calculation + unscaled_draws_emax, draw_weights_emax = get_integration_draws_and_weights( + model_spec + ) + + tax_splitting = model_spec.tax_splitting + + # Make all arrays in model params jax arrays + # Transform model specs and model params to jax arrays + model_spec = jax.tree_util.tree_map(lambda x: try_array(x), model_spec) + + hours = jnp.array(HOURS) + + n_periods = model_spec.num_periods + n_states_per_period = int(states.shape[0] / n_periods) + + # Reshape into period-major blocks. + states_pp = jnp.asarray( + states.reshape(n_periods, n_states_per_period, states.shape[1]) + ) + covariates_pp = jnp.asarray( + covariates.reshape(n_periods, n_states_per_period, covariates.shape[1]) + ) + + child_state_indexes_pp = jnp.asarray( + child_state_indexes.reshape( + n_periods, + n_states_per_period, + child_state_indexes.shape[1], + child_state_indexes.shape[2], + child_state_indexes.shape[3], + ) + ) + + # Convert global child indices to *local indices of the next-period block*. + # This keeps the scan step free of any state-space indexing logic. + # + # For period t, child states live in period t+1 whose block starts at (t+1)*n_states_per_period. + child_state_indexes_local_pp = jnp.asarray( + child_state_indexes_pp + - (np.arange(n_periods)[:, None, None, None, None] + 1) * n_states_per_period + ) + + unscaled_draws_emax = jnp.asarray(unscaled_draws_emax) + draw_weights_emax = jnp.asarray(draw_weights_emax) + prob_child = jnp.asarray(prob_child) + prob_partner = jnp.asarray(prob_partner) + + # Generate closure + def func_to_jit( + params_arg, + states_arg, + covariates_arg, + child_state_indexes_local_arg, + unscaled_draws_emax_arg, + draw_weights_emax_arg, + prob_child_arg, + prob_partner_arg, + ): + return pyth_backward_induction( + model_params=params_arg, + states_per_period=states_arg, + covariates_per_period=covariates_arg, + child_state_indexes_local_per_period=child_state_indexes_local_arg, + draws=unscaled_draws_emax_arg * params_arg.shock_sd, + draw_weights=draw_weights_emax_arg, + prob_child=prob_child_arg, + prob_partner=prob_partner_arg, + model_spec=model_spec, + hours=hours, + is_expected=is_expected, + tax_splitting=tax_splitting, + ) + + # Create solve function to jit + def solve_function(params): + params_int = jax.tree_util.tree_map(lambda x: jnp.asarray(x), params) + + non_consumption_utilities, emaxs = jax.jit(func_to_jit)( + params_arg=params_int, + states_arg=states_pp, + covariates_arg=covariates_pp, + child_state_indexes_local_arg=child_state_indexes_local_pp, + unscaled_draws_emax_arg=unscaled_draws_emax, + draw_weights_emax_arg=draw_weights_emax, + prob_child_arg=prob_child, + prob_partner_arg=prob_partner, + ) + return non_consumption_utilities, emaxs + + return solve_function + + +def pyth_backward_induction( + model_params, + states_per_period, + covariates_per_period, + child_state_indexes_local_per_period, draws, draw_weights, - covariates, prob_child, prob_partner, + model_spec, + hours, is_expected, + tax_splitting, ): """Get expected maximum value function at every state space point. Backward induction is performed all at once for all states in a given period. @@ -140,192 +234,152 @@ def pyth_backward_induction( as its first elements. The last row element corresponds to the maximum expected value function of the state. """ + # Convert inputs once to JAX arrays (scan body stays pure and compilation-friendly). + period_specific_objects = { + "states": states_per_period, + "covariates": covariates_per_period, + "child_state_indexes_local": child_state_indexes_local_per_period, + "prob_child": prob_child, + "prob_partner": prob_partner, + } + + # Reverse time for backward induction (scan goes forward over reversed time). + period_specific_objects_rev = jax.tree_util.tree_map( + lambda a: a[::-1], period_specific_objects + ) - hours = np.array(HOURS) - non_consumption_utilities = calculate_non_consumption_utility( - model_params, - states, - covariates[:, 0], + # Initial "next-period" emaxs: terminal continuation values are zero. + emaxs_next_init = jnp.zeros( + (states_per_period.shape[1], NUM_CHOICES + 1), dtype=float ) - emaxs = np.zeros((states.shape[0], non_consumption_utilities.shape[1] + 1)) - - partial_body = jax.jit( - lambda params, period_index, emaxs_childs, prob_child_period, prob_partner_period: period_body_backward_induction( - model_params=params, - state_period_index=period_index, - emaxs_child_states=emaxs_childs, - states=jnp.asarray(states), - covariates=jnp.asarray(covariates), - non_consumption_utilities=jnp.asarray(non_consumption_utilities), - prob_child_period=prob_child_period, - prob_partner_period=prob_partner_period, - draws=jnp.asarray(draws), - draw_weights=jnp.asarray(draw_weights), - model_spec=model_spec, - is_expected=is_expected, - hours=jnp.asarray(hours), - tax_splitting=tax_splitting, + def scan_step(emaxs_next, period_data): + """One backward-induction step over a single period block. + + Carry + ----- + emaxs_next : array + Emax array for the next period in time (already computed in the scan). + + period_data : dict (pytree) + Period-specific arrays for states, covariates, transition indices, and + probability objects. + + Returns + ------- + carry : array + The current period's emax array (becomes next carry). + out : array + The current period's emax array (collected over time by scan). + """ + states_period = period_data["states"] + covariates_period = period_data["covariates"] + child_state_indexes_local = period_data["child_state_indexes_local"] + prob_child_period = period_data["prob_child"] + prob_partner_period = period_data["prob_partner"] + + # Continuation values are the maximum value function of child states. + # The child maximum value function lives in the last column (index 3). + emaxs_child_states = emaxs_next[:, 3][child_state_indexes_local] + # --------------------------------------------------------------------- + # Period reward and expectation computation. + # --------------------------------------------------------------------- + # Probability that a child arrives + prob_child_period_states = prob_child_period[states_period[:, 1]] + + # Probability of partner states. + prob_partner_period_states = prob_partner_period[ + states_period[:, 1], states_period[:, 7] + ] + + # Period rewards + log_wage_systematic_period = calculate_log_wage( + model_params, states_period, is_expected + ) + np.log(model_spec.elasticity_scale) + + non_consumption_utilities_period = calculate_non_consumption_utility( + model_params, + states_period, + covariates_period[:, 0], ) - ) - min_ind_child_period = 0 - - # Loop backwards over all periods - for period in np.arange(model_spec.num_periods - 1, -1, -1, dtype=int): - bool_ind = states[:, 0] == period - state_period_index = np.where(bool_ind)[0] - # Continuation value calculation not performed for last period - # since continuation values are known to be zero - if period == model_spec.num_periods - 1: - emaxs_child_states = jnp.zeros( - shape=(state_period_index.shape[0], 3, 2, 2), dtype=float + + non_employment_consumption_resources_period = ( + calculate_non_employment_consumption_resources( + deductions_spec=model_spec.ssc_deductions, + income_tax_spec=model_spec.tax_params, + model_spec=model_spec, + states=states_period, + log_wage_systematic=log_wage_systematic_period, + male_wage=covariates_period[:, 1], + child_benefits=covariates_period[:, 3], + tax_splitting=model_spec.tax_splitting, + hours=hours, ) - # Assign for next period the min index of current period^as the min index of the child period - min_ind_child_period = state_period_index[0] - else: - child_states_ind_period = child_state_indexes[state_period_index] - emaxs_child_states = emaxs_period[:, 3][ - child_states_ind_period - min_ind_child_period - ] - # Assign for next period the min index of current period^as the min index of the child period - min_ind_child_period = state_period_index[0] - - emaxs_period = partial_body( - params=model_params, - period_index=state_period_index, - emaxs_childs=emaxs_child_states, - prob_child_period=prob_child[period], - prob_partner_period=prob_partner[period], ) - emaxs[state_period_index] = emaxs_period - - return emaxs, non_consumption_utilities + if model_spec.parental_leave_regime == "elterngeld": + emaxs_curr = construct_emax( + delta=model_params.delta, + log_wages_systematic=log_wage_systematic_period, + non_consumption_utilities=non_consumption_utilities_period, + draws=draws, + draw_weights=draw_weights, + emaxs_child_states=emaxs_child_states, + prob_child=prob_child_period_states, + prob_partner=prob_partner_period_states, + hours=hours, + mu=model_params.mu, + non_employment_consumption_resources=non_employment_consumption_resources_period, + covariates=covariates_period, + model_spec=model_spec, + tax_splitting=tax_splitting, + ) + elif model_spec.parental_leave_regime == "erziehungsgeld": + baby_child_period = (states_period[:, 6] == 0) | (states_period[:, 6] == 1) + + emaxs_curr = construct_emax_validation( + delta=model_params.delta, + baby_child=baby_child_period, + log_wages_systematic=log_wage_systematic_period, + non_consumption_utilities=non_consumption_utilities_period, + draws=draws, + draw_weights=draw_weights, + emaxs_child_states=emaxs_child_states, + prob_child=prob_child_period_states, + prob_partner=prob_partner_period_states, + hours=hours, + mu=model_params.mu, + non_employment_consumption_resources=non_employment_consumption_resources_period, + model_spec=model_spec, + covariates=covariates_period, + tax_splitting=tax_splitting, + ) + else: + raise ValueError( + f"Parental leave regime {model_spec.parental_leave_regime} not specified." + ) + # Current period becomes the next-period carry for the following (earlier) step. + return emaxs_curr, (emaxs_curr, non_consumption_utilities_period) -def period_body_backward_induction( - model_params, - state_period_index, - emaxs_child_states, - states, - covariates, - non_consumption_utilities, - prob_child_period, - prob_partner_period, - draws, - draw_weights, - model_spec, - is_expected, - hours, - tax_splitting, -): + # Run backward induction: outputs are in reverse time order (terminal -> first). + _, (emaxs_rev, non_consumption_utilities_rev) = jax.lax.scan( + scan_step, emaxs_next_init, period_specific_objects_rev + ) - deductions_spec = model_spec.ssc_deductions - tax_params = model_spec.tax_params - child_care_costs = model_spec.child_care_costs - - erziehungsgeld_inc_single = model_spec.erziehungsgeld_income_threshold_single - erziehungsgeld_inc_married = model_spec.erziehungsgeld_income_threshold_married - erziehungsgeld = model_spec.erziehungsgeld - - # Extract period information - # States and covariates - states_period = states[state_period_index] - covariates_period = covariates[state_period_index] - non_consumption_utilities_period = non_consumption_utilities[state_period_index] - - # Corresponding equivalence scale for period states - male_wage_period = covariates_period[:, 1] - equivalence_scale_period = covariates_period[:, 2] - child_benefits_period = covariates_period[:, 3] - - index_child_care_costs_period = jnp.where( - covariates_period[:, 0] > 2, 0, covariates_period[:, 0] - ).astype(int) - - # Probability that a child arrives - prob_child_period_states = prob_child_period[states_period[:, 1]] - - # Probability of partner states. - prob_partner_period_states = prob_partner_period[ - states_period[:, 1], states_period[:, 7] - ] - - # Period rewards - log_wage_systematic_period = calculate_log_wage( - model_params, states_period, is_expected - ) + np.log(model_spec.elasticity_scale) - - non_employment_consumption_resources_period = ( - calculate_non_employment_consumption_resources( - deductions_spec=model_spec.ssc_deductions, - income_tax_spec=model_spec.tax_params, - model_spec=model_spec, - states=states_period, - log_wage_systematic=log_wage_systematic_period, - male_wage=male_wage_period, - child_benefits=child_benefits_period, - tax_splitting=model_spec.tax_splitting, - hours=hours, - ) + # Flip back to chronological order and flatten to (num_states, NUM_CHOICES + 1). + emaxs_flat = jnp.flip(emaxs_rev, axis=0).reshape(-1, emaxs_rev.shape[-1]) + non_consumption_utilities = jnp.flip(non_consumption_utilities_rev, axis=0).reshape( + -1, + non_consumption_utilities_rev.shape[2], ) - if model_spec.parental_leave_regime == "elterngeld": - # Calculate emax for current period reached by the loop - emaxs_period = construct_emax( - model_params.delta, - log_wage_systematic_period, - non_consumption_utilities_period, - draws, - draw_weights, - emaxs_child_states, - prob_child_period_states, - prob_partner_period_states, - hours, - model_params.mu, - non_employment_consumption_resources_period, - deductions_spec, - tax_params, - child_care_costs, - index_child_care_costs_period, - male_wage_period, - child_benefits_period, - equivalence_scale_period, - tax_splitting, - ) - elif model_spec.parental_leave_regime == "erziehungsgeld": - - baby_child_period = (states_period[:, 6] == 0) | (states_period[:, 6] == 1) - # Calculate emax for current period reached by the loop - emaxs_period = construct_emax_validation( - model_params.delta, - baby_child_period, - log_wage_systematic_period, - non_consumption_utilities_period, - draws, - draw_weights, - emaxs_child_states, - prob_child_period_states, - prob_partner_period_states, - hours, - model_params.mu, - non_employment_consumption_resources_period, - deductions_spec, - tax_params, - child_care_costs, - index_child_care_costs_period, - male_wage_period, - child_benefits_period, - equivalence_scale_period, - erziehungsgeld_inc_single, - erziehungsgeld_inc_married, - erziehungsgeld, - tax_splitting, - ) + return non_consumption_utilities, emaxs_flat - else: - raise ValueError( - f"Parental leave regime {model_spec.parental_leave_regime} not specified." - ) - return emaxs_period +def try_array(x): + """Try to convert x to a jax array, otherwise return x unchanged.""" + try: + return jnp.asarray(x) + except Exception: + return x diff --git a/soepy/solve/validation_solve.py b/soepy/solve/validation_solve.py index 38d7227..bff6dde 100644 --- a/soepy/solve/validation_solve.py +++ b/soepy/solve/validation_solve.py @@ -108,16 +108,8 @@ def construct_emax_validation( hours, mu, non_employment_consumption_resources, - deductions_spec, - income_tax_spec, - child_care_costs, - index_child_care_costs, - male_wages, - child_benefits, - equivalence_scales, - erziehungsgeld_inc_single, - erziehungsgeld_inc_married, - erziehungsgeld, + model_spec, + covariates, tax_splitting, ): """Simulate expected maximum utility for a given distribution of the unobservables. The function calculates the @@ -132,21 +124,23 @@ def construct_emax_validation( emaxs_child_states, prob_child, prob_partner ) # (num_states, 3) - child_care_costs_j = jnp.asarray(child_care_costs) - def max_aggregated_utilities_broadcast( log_wage_systematic_choices, non_consumption_utilities_choices, emax_choices, non_employment_consumption_resources_choice, - male_wage, - child_benefit, - equivalence, - index_child_care_cost, + covariate, baby_child_scalar, draw, draw_weight, ): + # Corresponding equivalence scale for period states + male_wage = covariate[1] + equivalence_scale = covariate[2] + child_benefit = covariate[3] + + index_child_care_cost = jnp.where(covariate[0] > 2, 0, covariate[0]).astype(int) + return _get_max_aggregated_utilities_validation( delta=delta, baby_child=baby_child_scalar, @@ -158,17 +152,17 @@ def max_aggregated_utilities_broadcast( hours=hours, mu=mu, non_employment_consumption_resources=non_employment_consumption_resources_choice, - deductions_spec=deductions_spec, - income_tax_spec=income_tax_spec, + deductions_spec=model_spec.ssc_deductions, + income_tax_spec=model_spec.tax_params, male_wage=male_wage, child_benefits=child_benefit, - equivalence=equivalence, + equivalence=equivalence_scale, tax_splitting=tax_splitting, - child_care_costs=child_care_costs_j, + child_care_costs=model_spec.child_care_costs, child_care_bin=index_child_care_cost, - erziehungsgeld_inc_single=erziehungsgeld_inc_single, - erziehungsgeld_inc_married=erziehungsgeld_inc_married, - erziehungsgeld=erziehungsgeld, + erziehungsgeld_inc_single=model_spec.erziehungsgeld_income_threshold_single, + erziehungsgeld_inc_married=model_spec.erziehungsgeld_income_threshold_married, + erziehungsgeld=model_spec.erziehungsgeld, ) # Shape expectations matching your previous rewrite: @@ -176,18 +170,15 @@ def max_aggregated_utilities_broadcast( emaxs_current_states = jax.vmap( jax.vmap( max_aggregated_utilities_broadcast, - in_axes=(None, None, None, None, None, None, None, None, None, 0, 0), + in_axes=(None, None, None, None, None, None, 0, 0), ), - in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, None, None), + in_axes=(0, 0, 0, 0, 0, 0, None, None), )( log_wages_systematic, non_consumption_utilities, emax, non_employment_consumption_resources, - male_wages, - child_benefits, - equivalence_scales, - index_child_care_costs, + covariates, baby_child, draws, draw_weights, diff --git a/soepy/test/test_child_index.py b/soepy/test/test_child_index.py index 61cad42..74b3d44 100644 --- a/soepy/test/test_child_index.py +++ b/soepy/test/test_child_index.py @@ -8,16 +8,13 @@ from soepy.exogenous_processes.partner import gen_prob_partner from soepy.pre_processing.model_processing import read_model_params_init from soepy.pre_processing.model_processing import read_model_spec_init -from soepy.shared.non_employment import calculate_non_employment_consumption_resources from soepy.shared.numerical_integration import get_integration_draws_and_weights -from soepy.shared.shared_auxiliary import calculate_log_wage -from soepy.shared.shared_auxiliary import calculate_non_consumption_utility from soepy.soepy_config import TEST_RESOURCES_DIR from soepy.solve.covariates import construct_covariates from soepy.solve.create_state_space import create_child_indexes from soepy.solve.create_state_space import pyth_create_state_space from soepy.solve.emaxs import do_weighting_emax_scalar -from soepy.solve.solve_python import pyth_backward_induction +from soepy.solve.solve_python import get_solve_function @pytest.fixture(scope="module") @@ -69,26 +66,17 @@ def input_data(): states, indexer, model_spec, child_age_update_rule ) - draws_emax, draw_weights_emax = get_integration_draws_and_weights( - model_spec, model_params - ) - # Set draws to zero to isolate the effect of child indexing - # Solve the model in a backward induction procedure - # Error term for continuation values is integrated out - # numerically in a Monte Carlo procedure - emaxs, _ = pyth_backward_induction( - model_spec=model_spec, - tax_splitting=model_spec.tax_splitting, - model_params=model_params, - states=states, - child_state_indexes=child_state_indexes, - draws=draws_emax, - draw_weights=draw_weights_emax, - covariates=covariates, - prob_child=prob_child, - prob_partner=prob_partner, + # Solve function + solve_func = get_solve_function( + states, + covariates, + child_state_indexes, + model_spec, + prob_child, + prob_partner, is_expected=True, ) + _, emaxs = solve_func(model_params) return states, emaxs, child_state_indexes, prob_child, prob_partner diff --git a/soepy/test/test_construction_emax.py b/soepy/test/test_construction_emax.py index 87f8d67..b1649b7 100644 --- a/soepy/test/test_construction_emax.py +++ b/soepy/test/test_construction_emax.py @@ -9,15 +9,13 @@ from soepy.pre_processing.model_processing import read_model_params_init from soepy.pre_processing.model_processing import read_model_spec_init from soepy.shared.non_employment import calculate_non_employment_consumption_resources -from soepy.shared.numerical_integration import get_integration_draws_and_weights -from soepy.shared.shared_auxiliary import calculate_log_wage -from soepy.shared.shared_auxiliary import calculate_non_consumption_utility from soepy.shared.shared_constants import HOURS +from soepy.shared.wages import calculate_log_wage from soepy.soepy_config import TEST_RESOURCES_DIR from soepy.solve.covariates import construct_covariates from soepy.solve.create_state_space import create_child_indexes from soepy.solve.create_state_space import pyth_create_state_space -from soepy.solve.solve_python import pyth_backward_induction +from soepy.solve.solve_python import get_solve_function @pytest.fixture(scope="module") @@ -74,28 +72,17 @@ def input_data(): states, indexer, model_spec, child_age_update_rule ) - draws_emax, draw_weights_emax = get_integration_draws_and_weights( - model_spec, model_params - ) - - draws_emax *= 0 - - # Solve the model in a backward induction procedure - # Error term for continuation values is integrated out - # numerically in a Monte Carlo procedure - emaxs, non_consumption_utilities = pyth_backward_induction( - model_spec=model_spec, - tax_splitting=model_spec.tax_splitting, - model_params=model_params, - states=states, - child_state_indexes=child_state_indexes, - draws=draws_emax, - draw_weights=draw_weights_emax, - covariates=covariates, - prob_child=prob_child, - prob_partner=prob_partner, + # Solve function + solve_func = get_solve_function( + states, + covariates, + child_state_indexes, + model_spec, + prob_child, + prob_partner, is_expected=True, ) + non_consumption_utilities, emaxs = solve_func(model_params) log_wage_systematic = calculate_log_wage(model_params, states, True) non_employment_consumption_resources = ( diff --git a/soepy/test/test_continuation_emax.py b/soepy/test/test_continuation_emax.py index 289377e..7ac8d32 100644 --- a/soepy/test/test_continuation_emax.py +++ b/soepy/test/test_continuation_emax.py @@ -8,16 +8,11 @@ from soepy.exogenous_processes.partner import gen_prob_partner from soepy.pre_processing.model_processing import read_model_params_init from soepy.pre_processing.model_processing import read_model_spec_init -from soepy.shared.non_employment import calculate_non_employment_consumption_resources -from soepy.shared.numerical_integration import get_integration_draws_and_weights -from soepy.shared.shared_auxiliary import calculate_log_wage -from soepy.shared.shared_auxiliary import calculate_non_consumption_utility -from soepy.shared.shared_constants import HOURS from soepy.soepy_config import TEST_RESOURCES_DIR from soepy.solve.covariates import construct_covariates from soepy.solve.create_state_space import create_child_indexes from soepy.solve.create_state_space import pyth_create_state_space -from soepy.solve.solve_python import pyth_backward_induction +from soepy.solve.solve_python import get_solve_function @pytest.fixture(scope="module") @@ -68,27 +63,17 @@ def input_data(): child_state_indexes = create_child_indexes( states, indexer, model_spec, child_age_update_rule ) - - draws_emax, draw_weights_emax = get_integration_draws_and_weights( - model_spec, model_params - ) - - # Solve the model in a backward induction procedure - # Error term for continuation values is integrated out - # numerically in a Monte Carlo procedure - emaxs, _ = pyth_backward_induction( - model_spec=model_spec, - tax_splitting=model_spec.tax_splitting, - model_params=model_params, - states=states, - child_state_indexes=child_state_indexes, - draws=draws_emax, - draw_weights=draw_weights_emax, - covariates=covariates, - prob_child=prob_child, - prob_partner=prob_partner, + # Solve function + solve_func = get_solve_function( + states, + covariates, + child_state_indexes, + model_spec, + prob_child, + prob_partner, is_expected=True, ) + non_consumption_utilities, emaxs = solve_func(model_params) return ( emaxs, @@ -97,6 +82,7 @@ def input_data(): prob_child, prob_partner, child_age_update_rule, + child_state_indexes, ) @@ -108,6 +94,7 @@ def test_emaxs_married(input_data): prob_child, prob_partner, child_age_update_rule, + child_state_indexes, ) = input_data # Get states from period 1, type 1, married and no kid states_selected = states[ @@ -171,6 +158,7 @@ def test_emaxs_single(input_data): prob_child, prob_partner, child_age_update_rule, + child_state_indexes, ) = input_data # Get states from period 1, type 1, not married and no kid states_selected = states[ @@ -234,8 +222,9 @@ def test_emaxs_single_with_kid(input_data): prob_child, prob_partner, child_age_update_rule, + child_state_indexes, ) = input_data - # Get states from period 1, type 1, married and kids + # Get states from period 1, type 1, single and kids states_selected = states[ (states[:, 0] == 1) & (states[:, 6] != -1) & (states[:, 7] == 0) ] @@ -264,6 +253,7 @@ def test_emaxs_single_with_kid(input_data): partner_ind, ] + # No kid arrival states emax_cstate_sep_no_kid_arr = emaxs[ indexer[ period + 1, @@ -291,6 +281,7 @@ def test_emaxs_single_with_kid(input_data): 3, ] + # Kid arrival states emax_cstate_sep_kid = emaxs[ indexer[period + 1, educ_level, 0, exp_pt, exp_ft, type_1, 0, 0], 3 ] diff --git a/soepy/test/test_nonemployment.py b/soepy/test/test_nonemployment.py index bc42d66..84869db 100644 --- a/soepy/test/test_nonemployment.py +++ b/soepy/test/test_nonemployment.py @@ -7,9 +7,9 @@ from soepy.pre_processing.model_processing import read_model_params_init from soepy.pre_processing.model_processing import read_model_spec_init from soepy.shared.non_employment import calculate_non_employment_consumption_resources -from soepy.shared.shared_auxiliary import calculate_log_wage from soepy.shared.shared_constants import HOURS from soepy.shared.tax_and_transfers import calculate_net_income +from soepy.shared.wages import calculate_log_wage from soepy.soepy_config import TEST_RESOURCES_DIR from soepy.solve.create_state_space import create_state_space_objects diff --git a/soepy/test/test_state_space.py b/soepy/test/test_state_space.py index 3b981a4..9836c21 100644 --- a/soepy/test/test_state_space.py +++ b/soepy/test/test_state_space.py @@ -6,8 +6,8 @@ from soepy.pre_processing.model_processing import read_model_params_init from soepy.pre_processing.model_processing import read_model_spec_init from soepy.shared.non_employment import calculate_non_employment_consumption_resources -from soepy.shared.shared_auxiliary import calculate_log_wage from soepy.shared.shared_constants import HOURS +from soepy.shared.wages import calculate_log_wage from soepy.soepy_config import TEST_RESOURCES_DIR from soepy.solve.create_state_space import create_state_space_objects @@ -126,9 +126,10 @@ def test_child_update_rule_aging_child(input_data): child_age_update_rule, ) = input_data aging_child = (states[:, 6] > -1) & (states[:, 6] <= model_spec.child_age_max) - np.testing.assert_array_equal( - states[aging_child][:, 6] + 1, child_age_update_rule[aging_child] - ) + new_age = states[aging_child][:, 6] + 1 + # Children that are in the last tracked age are not there any more in the next period + new_age[new_age == model_spec.child_age_max + 1] = -1 + np.testing.assert_array_equal(new_age, child_age_update_rule[aging_child]) def test_work_choices(input_data): diff --git a/soepy/test/test_unit.py b/soepy/test/test_unit.py index c3b0c4d..df59bf8 100644 --- a/soepy/test/test_unit.py +++ b/soepy/test/test_unit.py @@ -156,37 +156,6 @@ def test_unit_data_frame_shape(): np.testing.assert_array_equal(df.shape[0], shape) -def test_unit_childbearing_age(): - """This test verifies that the state space does not contain newly born children - after the last childbearing period""" - expected = 0 - - model_spec = collections.namedtuple( - "model_spec", - "num_periods num_educ_levels num_types \ - last_child_bearing_period child_age_max \ - educ_years child_age_init_max init_exp_max", - ) - - num_periods = randint(1, 11) - last_child_bearing_period = randrange(num_periods) - model_spec = model_spec( - num_periods, 3, 2, last_child_bearing_period, 10, [0, 1, 2], 4, 4 - ) - - states, _ = pyth_create_state_space(model_spec) - - np.testing.assert_equal( - sum( - states[np.where(states[:, 0] == model_spec.last_child_bearing_period + 1)][ - :, 6 - ] - == 0 - ), - expected, - ) - - def test_no_children_no_exp(): """This test ensures that i) child age equals -1 in the entire simulates sample, diff --git a/soepy/test/test_utility_components.py b/soepy/test/test_utility_components.py index 3c4f9af..bdb615e 100644 --- a/soepy/test/test_utility_components.py +++ b/soepy/test/test_utility_components.py @@ -7,7 +7,7 @@ from soepy.pre_processing.model_processing import read_model_params_init from soepy.pre_processing.model_processing import read_model_spec_init -from soepy.shared.shared_auxiliary import calculate_log_wage +from soepy.shared.wages import calculate_log_wage from soepy.soepy_config import TEST_RESOURCES_DIR from soepy.solve.create_state_space import create_state_space_objects diff --git a/soepy/test/test_validation_childless.py b/soepy/test/test_validation_childless.py index 439706e..3323066 100644 --- a/soepy/test/test_validation_childless.py +++ b/soepy/test/test_validation_childless.py @@ -143,7 +143,7 @@ def test_childless(input_data): def test_childless_emax(input_data): not_having_kids = input_data["original_covs"][:, 3] == 0 - np.testing.assert_almost_equal( + np.testing.assert_array_almost_equal( input_data["original_emax"][not_having_kids, :], input_data["validation_emax"][not_having_kids, :], )