Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions soepy/shared/non_consumption_utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import jax.numpy as jnp


def calculate_non_consumption_utility(
model_params,
states,
child_bins,
):
"""Calculate non-pecuniary utility contribution.

Parameters
----------
states : np.ndarray
Shape (n_states, n_state_vars) matrix of states
child_bins : np.ndarray
Shape (n_states,) array with child bin indices for each state

Returns
-------
non_consumption_utility : np.ndarray
Shape (n_states, 3) matrix with utilities for [no work, part-time, full-time]
"""
educ = states[:, 1]
unobs_types = states[:, 5]

# Base utilities
util_pt = model_params.theta_p[unobs_types]
util_ft = model_params.theta_f[unobs_types]

# Binary selectors (implicitly cast to 0/1 in arithmetic)
b0 = child_bins == 0
b1 = child_bins == 1
b2 = child_bins == 2
b3 = child_bins == 3
b4 = child_bins > 3

# Education-dependent components
no_kids_f = model_params.no_kids_f[educ]
no_kids_p = model_params.no_kids_p[educ]
yes_kids_f = model_params.yes_kids_f[educ]
yes_kids_p = model_params.yes_kids_p[educ]

# Part-time utility (always includes full time base utility)
util_pt += (
b0 * (no_kids_f + no_kids_p)
+ b1
* (
yes_kids_f
+ yes_kids_p
+ model_params.child_0_2_f
+ model_params.child_0_2_p
)
+ b2
* (
yes_kids_f
+ yes_kids_p
+ model_params.child_3_5_f
+ model_params.child_3_5_p
)
+ b3
* (
yes_kids_f
+ yes_kids_p
+ model_params.child_6_10_f
+ model_params.child_6_10_p
)
+ b4 * (yes_kids_f + yes_kids_p)
)

# Full-time utility
util_ft += (
b0 * no_kids_f
+ b1 * (yes_kids_f + model_params.child_0_2_f)
+ b2 * (yes_kids_f + model_params.child_3_5_f)
+ b3 * (yes_kids_f + model_params.child_6_10_f)
+ b4 * yes_kids_f
)

# Stack: [no work, part-time, full-time]
non_consumption_utility = jnp.stack(
(
jnp.zeros_like(util_pt),
util_pt,
util_ft,
),
axis=1,
)

return jnp.exp(non_consumption_utility)
25 changes: 19 additions & 6 deletions soepy/shared/numerical_integration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import numpy as np
from scipy.special import roots_hermite

from soepy.shared.shared_auxiliary import draw_disturbances


def get_integration_draws_and_weights(model_spec, model_params):
def get_integration_draws_and_weights(model_spec):
if model_spec.integration_method == "quadrature":
# Draw standard points and corresponding weights
standard_draws, draw_weights_emax = roots_hermite(model_spec.num_draws_emax)
# Rescale draws and weights
draws_emax = standard_draws * np.sqrt(2) * model_params.shock_sd
draws_emax = standard_draws * np.sqrt(2)
draw_weights_emax *= 1 / np.sqrt(np.pi)
elif model_spec.integration_method == "monte_carlo":
draws_emax = draw_disturbances(
model_spec.seed_emax, 1, model_spec.num_draws_emax, model_params
draws_emax = draw_zero_one_distributed_shocks(
model_spec.seed_emax, 1, model_spec.num_draws_emax
)[0]
draw_weights_emax = (
np.ones(model_spec.num_draws_emax) / model_spec.num_draws_emax
Expand All @@ -24,3 +22,18 @@ def get_integration_draws_and_weights(model_spec, model_params):
)

return draws_emax, draw_weights_emax


def draw_zero_one_distributed_shocks(seed, num_periods, num_draws):
"""Creates desired number of draws of a multivariate standard normal
distribution.

"""
np.random.seed(seed)

mean = 0

# Create draws from the standard normal distribution
draws = np.random.normal(mean, 1, (num_periods, num_draws))

return draws
186 changes: 0 additions & 186 deletions soepy/shared/shared_auxiliary.py

This file was deleted.

70 changes: 70 additions & 0 deletions soepy/shared/wages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
from jax import numpy as jnp


def calculate_log_wage(model_params, states, is_expected):
"""Calculate utility components for all choices given state, period, and shocks.

Parameters
----------
model_params : namedtuple
Contains all parameters of the model including information on dimensions
(number of periods, agents, random draws, etc.) and coefficients to be
estimated.
states : np.ndarray
Array with shape (num_states, 5) containing period, years of schooling,
the lagged choice, the years of experience in part-time, and the
years of experience in full-time employment.
is_expected: bool
A boolean indicator that differentiates between the human capital accumulation
process that agents expect (is_expected = True) and that the market generates
(is_expected = False)

Returns
-------
log_wage_systematic : array
One dimensional array with length num_states containing the part of the wages
at the respective state space point that do not depend on the agent's choice,
nor on the random shock.
non_consumption_utilities : np.ndarray
Array of dimension (num_states, num_choices) containing the utility
contribution of non-pecuniary factors.

"""
if is_expected:
# Calculate biased part-time expectation by using ratio from expected data and
# structural paramteters
gamma_p = (
model_params.gamma_p_bias / (model_params.gamma_p / model_params.gamma_f)
) * model_params.gamma_p
else:
gamma_p = model_params.gamma_p

log_wage_systematic = calculate_log_wage_systematic(
model_params.gamma_0,
model_params.gamma_f,
gamma_p,
states,
)

return log_wage_systematic


def calculate_log_wage_systematic(gamma_0, gamma_f, gamma_p, states):
"""Calculate systematic wages, i.e., wages net of shock, for all states."""

exp_p_states, exp_f_states = states[:, 3], states[:, 4]

log_exp_p = jnp.log(exp_p_states + 1)
log_exp_f = jnp.log(exp_f_states + 1)

# Assign wage returns
gamma_0_edu = jnp.array(gamma_0)[states[:, 1]]
gamma_f_edu = jnp.array(gamma_f)[states[:, 1]]
gamma_p_edu = jnp.array(gamma_p)[states[:, 1]]

# Calculate wage in the given state
log_wage_systematic = (
gamma_0_edu + gamma_f_edu * log_exp_f + gamma_p_edu * log_exp_p
)
return log_wage_systematic
Loading
Loading