@@ -121,30 +121,43 @@ def fit(self, X, y=None, sample_weight=None):
121121 # Triangle must be cumulative and in "development" mode
122122 obj = self ._set_fit_groups (X ).incr_to_cum ().val_to_dev ().copy ()
123123 xp = obj .get_array_module ()
124+
124125 if self .fillna :
125126 tri_array = num_to_nan ((obj + self .fillna ).values )
126127 else :
127128 tri_array = num_to_nan (obj .values .copy ())
128- average_ = self ._validate_assumption (X , self .average , axis = 3 )[... , :X .shape [3 ]- 1 ]
129+
130+ average_ = self ._validate_assumption (X , self .average , axis = 3 )[
131+ ..., : X .shape [3 ] - 1
132+ ]
129133 self .average_ = average_ .flatten ()
130- n_periods_ = self ._validate_assumption (X , self .n_periods , axis = 3 )[... , :X .shape [3 ]- 1 ]
134+ n_periods_ = self ._validate_assumption (X , self .n_periods , axis = 3 )[
135+ ..., : X .shape [3 ] - 1
136+ ]
131137 x , y = tri_array [..., :- 1 ], tri_array [..., 1 :]
132138 exponent = xp .array (
133- [{"regression" : 0 , "volume" : 1 , "simple" : 2 }[x ]
134- for x in average_ [0 , 0 , 0 ]]
139+ [{"regression" : 0 , "volume" : 1 , "simple" : 2 }[x ] for x in average_ [0 , 0 , 0 ]]
135140 )
136141 exponent = xp .nan_to_num (exponent * (y * 0 + 1 ))
137142 link_ratio = y / x
138143
139144 if hasattr (X , "w_v2_" ):
140- self .w_v2_ = self ._set_weight_func (obj .age_to_age * X .w_v2_ ,obj .iloc [...,:- 1 ,:- 1 ])
145+ self .w_v2_ = self ._set_weight_func (
146+ factor = obj .age_to_age * X .w_v2_ ,
147+ # secondary_rank=obj.iloc[..., :-1, :-1]
148+ )
141149 else :
142- self .w_v2_ = self ._set_weight_func (obj .age_to_age ,obj .iloc [...,:- 1 ,:- 1 ])
150+ self .w_v2_ = self ._set_weight_func (
151+ factor = obj .age_to_age ,
152+ # secondary_rank=obj.iloc[..., :-1, :-1]
153+ )
154+
143155 self .w_ = self ._assign_n_periods_weight (
144156 obj , n_periods_
145157 ) * self ._drop_adjustment (obj , link_ratio )
146158 w = num_to_nan (self .w_ / (x ** (exponent )))
147159 params = WeightedRegression (axis = 2 , thru_orig = True , xp = xp ).fit (x , y , w )
160+
148161 if self .n_periods != 1 :
149162 params = params .sigma_fill (self .sigma_interpolation )
150163 else :
@@ -153,20 +166,23 @@ def fit(self, X, y=None, sample_weight=None):
153166 "of freedom to support calculation of all regression"
154167 " statistics. Only LDFs have been calculated."
155168 )
169+
156170 params .std_err_ = xp .nan_to_num (params .std_err_ ) + xp .nan_to_num (
157171 (1 - xp .nan_to_num (params .std_err_ * 0 + 1 ))
158172 * params .sigma_
159173 / xp .swapaxes (xp .sqrt (x ** (2 - exponent ))[..., 0 :1 , :], - 1 , - 2 )
160174 )
175+
161176 params = xp .concatenate ((params .slope_ , params .sigma_ , params .std_err_ ), 3 )
162177 params = xp .swapaxes (params , 2 , 3 )
163178 self .ldf_ = self ._param_property (obj , params , 0 )
164179 self .sigma_ = self ._param_property (obj , params , 1 )
165180 self .std_err_ = self ._param_property (obj , params , 2 )
166181 resid = - obj .iloc [..., :- 1 ] * self .ldf_ .values + obj .iloc [..., 1 :].values
167- std = xp .sqrt ((1 / num_to_nan (w )) * (self .sigma_ ** 2 ).values )
182+ std = xp .sqrt ((1 / num_to_nan (w )) * (self .sigma_ ** 2 ).values )
168183 resid = resid / num_to_nan (std )
169184 self .std_residuals_ = resid [resid .valuation < obj .valuation_date ]
185+
170186 return self
171187
172188 def transform (self , X ):
@@ -184,10 +200,21 @@ def transform(self, X):
184200 """
185201 X_new = X .copy ()
186202 X_new .group_index = self ._set_transform_groups (X_new )
187- triangles = ["std_err_" , "ldf_" , "sigma_" ,"std_residuals_" ,"average_" , "w_" , "sigma_interpolation" ,"w_v2_" ]
203+ triangles = [
204+ "std_err_" ,
205+ "ldf_" ,
206+ "sigma_" ,
207+ "std_residuals_" ,
208+ "average_" ,
209+ "w_" ,
210+ "sigma_interpolation" ,
211+ "w_v2_" ,
212+ ]
188213 for item in triangles :
189214 setattr (X_new , item , getattr (self , item ))
215+
190216 X_new ._set_slicers ()
217+
191218 return X_new
192219
193220 def _param_property (self , X , params , idx ):
@@ -202,4 +229,5 @@ def _param_property(self, X, params, idx):
202229 obj .is_cumulative = False
203230 obj .virtual_columns .columns = {}
204231 obj ._set_slicers ()
232+
205233 return obj
0 commit comments