@@ -86,6 +86,10 @@ def circular_mean(self, ary):  # pylint: disable=no-self-use
8686        """ 
8787        return  circmean (ary , high = np .pi , low = - np .pi )
8888
89+     def  _circular_standardize (self , ary ):  # pylint: disable=no-self-use 
90+         """Standardize circular data to the interval [-pi, pi].""" 
91+         return  np .mod (ary  +  np .pi , 2  *  np .pi ) -  np .pi 
92+ 
8993    def  quantile (self , ary , quantile , ** kwargs ):  # pylint: disable=no-self-use 
9094        """Compute the quantile of an array of samples. 
9195
@@ -226,20 +230,9 @@ def _histogram(self, ary, bins=None, range=None, weights=None, density=None):
226230            bins  =  self ._get_bins (ary )
227231        return  np .histogram (ary , bins = bins , range = range , weights = weights , density = density )
228232
229-     def  _hdi_linear_nearest_common (self , ary , prob , skipna , circular ):
230-         ary  =  ary .flatten ()
231-         if  skipna :
232-             nans  =  np .isnan (ary )
233-             if  not  nans .all ():
234-                 ary  =  ary [~ nans ]
233+     def  _hdi_linear_nearest_common (self , ary , prob ):  # pylint: disable=no-self-use 
235234        n  =  len (ary )
236235
237-         mean  =  None 
238-         if  circular :
239-             mean  =  self .circular_mean (ary )
240-             ary  =  ary  -  mean 
241-             ary  =  np .arctan2 (np .sin (ary ), np .cos (ary ))
242- 
243236        ary  =  np .sort (ary )
244237        interval_idx_inc  =  int (np .floor (prob  *  n ))
245238        n_intervals  =  n  -  interval_idx_inc 
@@ -249,62 +242,147 @@ def _hdi_linear_nearest_common(self, ary, prob, skipna, circular):
249242            raise  ValueError ("Too few elements for interval calculation. " )
250243
251244        min_idx  =  np .argmin (interval_width )
245+         hdi_interval  =  ary [[min_idx , min_idx  +  interval_idx_inc ]]
252246
253-         return  ary ,  mean ,  min_idx ,  interval_idx_inc 
247+         return  hdi_interval 
254248
255249    def  _hdi_nearest (self , ary , prob , circular , skipna ):
256250        """Compute HDI over the flattened array as closest samples that contain the given prob.""" 
257-         ary , mean , min_idx , interval_idx_inc  =  self ._hdi_linear_nearest_common (
258-             ary , prob , skipna , circular 
259-         )
260- 
261-         hdi_min  =  ary [min_idx ]
262-         hdi_max  =  ary [min_idx  +  interval_idx_inc ]
251+         ary  =  ary .flatten ()
252+         if  skipna :
253+             nans  =  np .isnan (ary )
254+             if  not  nans .all ():
255+                 ary  =  ary [~ nans ]
263256
264257        if  circular :
265-             hdi_min  =  hdi_min   +   mean 
266-             hdi_max  =  hdi_max   +  mean 
267-              hdi_min   =   np . arctan2 ( np . sin ( hdi_min ),  np . cos ( hdi_min )) 
268-              hdi_max   =   np . arctan2 ( np . sin ( hdi_max ),  np . cos ( hdi_max ) )
258+             mean  =  self . circular_mean ( ary ) 
259+             ary  =  self . _circular_standardize ( ary   -  mean ) 
260+ 
261+         hdi_interval   =   self . _hdi_linear_nearest_common ( ary ,  prob )
269262
270-         hdi_interval  =  np .array ([hdi_min , hdi_max ])
263+         if  circular :
264+             hdi_interval  =  self ._circular_standardize (hdi_interval  +  mean )
271265
272266        return  hdi_interval 
273267
274-     def  _hdi_multimodal (self , ary , prob , skipna , max_modes ):
268+     def  _hdi_multimodal_continuous (
269+         self , ary , prob , skipna , max_modes , circular , from_sample = False , ** kwargs 
270+     ):
275271        """Compute HDI if the distribution is multimodal.""" 
276272        ary  =  ary .flatten ()
277273        if  skipna :
278274            ary  =  ary [~ np .isnan (ary )]
279275
280-         if  ary .dtype .kind  ==  "f" :
281-             bins , density , _  =  self .kde (ary )
282-             lower , upper  =  bins [0 ], bins [- 1 ]
283-             range_x  =  upper  -  lower 
284-             dx  =  range_x  /  len (density )
276+         bins , density , _  =  self .kde (ary , circular = circular , ** kwargs )
277+         if  from_sample :
278+             ary_density  =  np .interp (ary , bins , density )
279+             hdi_intervals , interval_probs  =  self ._hdi_from_point_densities (
280+                 ary , ary_density , prob , circular 
281+             )
285282        else :
286-             bins  =  self ._get_bins (ary )
287-             density , _  =  self ._histogram (ary , bins = bins , density = True )
288-             dx  =  np .diff (bins )[0 ]
289- 
290-         density  *=  dx 
291- 
292-         idx  =  np .argsort (- density )
293-         intervals  =  bins [idx ][density [idx ].cumsum () <=  prob ]
294-         intervals .sort ()
295- 
296-         intervals_splitted  =  np .split (intervals , np .where (np .diff (intervals ) >=  dx  *  1.1 )[0 ] +  1 )
297- 
298-         hdi_intervals  =  np .full ((max_modes , 2 ), np .nan )
299-         for  i , interval  in  enumerate (intervals_splitted ):
300-             if  i  ==  max_modes :
301-                 warnings .warn (
302-                     f"found more modes than { max_modes }  , returning only the first { max_modes }   modes" 
303-                 )
304-                 break 
305-             if  interval .size  ==  0 :
306-                 hdi_intervals [i ] =  np .asarray ([bins [0 ], bins [0 ]])
307-             else :
308-                 hdi_intervals [i ] =  np .asarray ([interval [0 ], interval [- 1 ]])
309- 
310-         return  np .array (hdi_intervals )
283+             dx  =  (bins [- 1 ] -  bins [0 ]) /  (len (bins ) -  1 )
284+             bin_probs  =  density  *  dx 
285+ 
286+             hdi_intervals , interval_probs  =  self ._hdi_from_bin_probabilities (
287+                 bins , bin_probs , prob , circular , dx 
288+             )
289+ 
290+         return  self ._pad_hdi_to_maxmodes (hdi_intervals , interval_probs , max_modes )
291+ 
292+     def  _hdi_multimodal_discrete (self , ary , prob , max_modes , bins = None ):
293+         """Compute HDI if the distribution is multimodal.""" 
294+         ary  =  ary .flatten ()
295+ 
296+         if  bins  is  None :
297+             bins , counts  =  np .unique (ary , return_counts = True )
298+             bin_probs  =  counts  /  len (ary )
299+             dx  =  1 
300+         else :
301+             counts , edges  =  self ._histogram (ary , bins = bins )
302+             bins  =  0.5  *  (edges [1 :] +  edges [:- 1 ])
303+             bin_probs  =  counts  /  counts .sum ()
304+             dx  =  bins [1 ] -  bins [0 ]
305+ 
306+         hdi_intervals , interval_probs  =  self ._hdi_from_bin_probabilities (
307+             bins , bin_probs , prob , False , dx 
308+         )
309+ 
310+         return  self ._pad_hdi_to_maxmodes (hdi_intervals , interval_probs , max_modes )
311+ 
312+     def  _hdi_from_point_densities (self , points , densities , prob , circular ):
313+         if  circular :
314+             points  =  self ._circular_standardize (points )
315+ 
316+         sorted_idx  =  np .argsort (points )
317+         points  =  points [sorted_idx ]
318+         densities  =  densities [sorted_idx ]
319+ 
320+         # find idx of points in the interval 
321+         interval_size  =  int (np .ceil (prob  *  len (points )))
322+         sorted_idx  =  np .argsort (densities )[::- 1 ]
323+         idx_in_interval  =  sorted_idx [:interval_size ]
324+         idx_in_interval .sort ()
325+ 
326+         # find idx of interval bounds 
327+         probs_in_interval  =  np .full (idx_in_interval .shape , 1  /  len (points ))
328+         interval_bounds_idx , interval_probs  =  self ._interval_points_to_bounds (
329+             idx_in_interval , probs_in_interval , 1 , circular , period = len (points )
330+         )
331+ 
332+         return  points [interval_bounds_idx ], interval_probs 
333+ 
334+     def  _hdi_from_bin_probabilities (self , bins , bin_probs , prob , circular , dx ):
335+         if  circular :
336+             bins  =  self ._circular_standardize (bins )
337+             sorted_idx  =  np .argsort (bins )
338+             bins  =  bins [sorted_idx ]
339+             bin_probs  =  bin_probs [sorted_idx ]
340+ 
341+         # find idx of bins in the interval 
342+         sorted_idx  =  np .argsort (bin_probs )[::- 1 ]
343+         cum_probs  =  bin_probs [sorted_idx ].cumsum ()
344+         interval_size  =  np .searchsorted (cum_probs , prob , side = "left" ) +  1 
345+         idx_in_interval  =  sorted_idx [:interval_size ]
346+         idx_in_interval .sort ()
347+ 
348+         # get points in intervals 
349+         intervals  =  bins [idx_in_interval ]
350+         probs_in_interval  =  bin_probs [idx_in_interval ]
351+ 
352+         return  self ._interval_points_to_bounds (intervals , probs_in_interval , dx , circular )
353+ 
354+     def  _interval_points_to_bounds (self , points , probs , dx , circular , period = 2  *  np .pi ):  # pylint: disable=no-self-use 
355+         cum_probs  =  probs .cumsum ()
356+ 
357+         is_bound  =  np .diff (points ) >  dx  *  1.01 
358+         is_lower_bound  =  np .insert (is_bound , 0 , True )
359+         is_upper_bound  =  np .append (is_bound , True )
360+         interval_bounds  =  np .column_stack ([points [is_lower_bound ], points [is_upper_bound ]])
361+         interval_probs  =  (
362+             cum_probs [is_upper_bound ] -  cum_probs [is_lower_bound ] +  probs [is_lower_bound ]
363+         )
364+ 
365+         if  (
366+             circular 
367+             and  np .mod (dx  *  1.01  +  interval_bounds [- 1 , - 1 ] -  interval_bounds [0 , 0 ], period )
368+             <=  dx  *  1.01 
369+         ):
370+             interval_bounds [- 1 , 1 ] =  interval_bounds [0 , 1 ]
371+             interval_bounds  =  interval_bounds [1 :, :]
372+             interval_probs [- 1 ] +=  interval_probs [0 ]
373+             interval_probs  =  interval_probs [1 :]
374+ 
375+         return  interval_bounds , interval_probs 
376+ 
377+     def  _pad_hdi_to_maxmodes (self , hdi_intervals , interval_probs , max_modes ):  # pylint: disable=no-self-use 
378+         if  hdi_intervals .shape [0 ] >  max_modes :
379+             warnings .warn (
380+                 f"found more modes than { max_modes }  , returning only the { max_modes }   highest " 
381+                 "probability modes" 
382+             )
383+             hdi_intervals  =  hdi_intervals [np .argsort (interval_probs )[::- 1 ][:max_modes ], :]
384+         elif  hdi_intervals .shape [0 ] <  max_modes :
385+             hdi_intervals  =  np .vstack (
386+                 [hdi_intervals , np .full ((max_modes  -  hdi_intervals .shape [0 ], 2 ), np .nan )]
387+             )
388+         return  hdi_intervals 
0 commit comments