44from utilities_ import *
55from weighted_clustering import *
66from tqdm import tqdm , trange
7+ from sklearn .metrics import pairwise_distances
78
89'''TemporalGraph class
910minimal usage example:
@@ -35,7 +36,7 @@ class TemporalGraph():
3536 ----------
3637 time : ndarray
3738 time array (1 dim)
38- data : ndarry
39+ data : ndarray
3940 data array (n dim)
4041 clusterer : sklearn clusterer
4142 the clusterer to use for the slice-wise clustering, must accept sample_weights
@@ -46,8 +47,24 @@ class TemporalGraph():
4647 show_outliers : bool
4748 If true, include unclustered points in the graph
4849 slice_method : str
49- One of 'time' or 'data'. If time, generates N_checkpoints evenly spaced in time. If data, generates
50- N_checkpoints such that there are equal amounts of data between the points.
50+ One of 'time' or 'data'. If time, generates N_checkpoints evenly spaced in time. If data,
51+ generates N_checkpoints such that there are equal amounts of data between the points.
52+ rate_sensitivity : float
53+ A positive float, or -1. The rate parameter is raised to this parameter, so higher numbers
54+ means that the algorithm is more sensitive to changes in rate. If rate_sensivity == -1,
55+ then the rate parameter is taken log2.
56+ kernel : function
57+ A function with signiture f(t0, t, density, binwidth, epsilon=0.01, params=None).
58+ Two options are included in weighted_clustering.py, `weighted_clustering.square` and
59+ `weighted_clustering.gaussian`.
60+ kernel_parameters : tuple or None,
61+ Passed to `kernel` as params kwarg.
62+ precomputed_distances : ndarray
63+ an (n_data, n_data) array of pairwise distances between points. If None then it will
64+ be computed using `sklearn.metrics.pairwise_distances`.
65+ verbose : bool
66+ Does what you expect.
67+
5168
5269 G : networkx.classes.Digraph(Graph)
5370 The temporal graph itself.
@@ -66,6 +83,7 @@ def __init__(
6683 rate_sensitivity = 1 ,
6784 kernel = gaussian ,
6885 kernel_params = None ,
86+ precomputed_distances = None ,
6987 verbose = False ,
7088 ):
7189 if np .size (time ) != np .shape (data )[0 ]:
@@ -77,6 +95,9 @@ def __init__(
7795
7896 self .time = np .array (time )
7997 self .N_data = np .size (time )
98+ if len (data .shape ) == 1 :
99+ data = data .reshape (- 1 ,1 )
100+ self .n_components = data .shape [1 ]
80101 self .data = data
81102 self .checkpoints = checkpoints
82103 if slice_method in ['time' ,'data' ]:
@@ -107,10 +128,15 @@ def __init__(
107128 self .verbose = verbose
108129 self .disable = not verbose # tqdm
109130 self .show_outliers = False
131+ self .distances = precomputed_distances
132+ if precomputed_distances is None :
133+ self .distances = pairwise_distances (data )
110134
111135 def _compute_checkpoints (self ):
112136 if self .slice_method == 'data' :
113- checkpoints = self .time [np .linspace (0 , N_data , self .N_checkpoints + 2 )[1 :- 1 ]]
137+ idx = np .linspace (0 , self .N_data , self .N_checkpoints + 2 )[1 :- 1 ]
138+ idx = np .array ([int (x ) for x in idx ])
139+ checkpoints = self .time [idx ]
114140 if self .slice_method == 'time' :
115141 checkpoints = np .linspace (np .amin (self .time ), np .amax (self .time ), self .N_checkpoints + 2 )[1 :- 1 ]
116142 self .checkpoints = checkpoints
@@ -121,13 +147,29 @@ def _compute_densities(self):
121147 self ._compute_checkpoints ()
122148 if self .verbose :
123149 print ("Computing spatial density..." )
150+ data_width = np .mean (
151+ [np .amax (self .data [:,k ])- np .amin (self .data [:,k ])
152+ for k in range (self .data .shape [1 ])]
153+ )
124154 rates = compute_point_rates (
125155 self .data ,
126156 self .time ,
157+ self .distances ,
127158 sensitivity = self .sensitivity ,
159+ width = data_width / 10 ,
128160 )
129- self .densities = rates
130- return rates
161+ iso_idx = (rates == np .inf )
162+ nisolated = np .size ((iso_idx ).nonzero ())
163+ if nisolated != 0 :
164+ print (f'Warning: You have { nisolated } isolated points. If this is a small number, its probably fine.' )
165+ densities = 1 / rates
166+ densities = sigmoid (densities , np .median (densities ))
167+ if self .sensitivity == - 1 :
168+ self .densities = 1 / (1 - np .log2 (densities ))
169+ else :
170+ self .densities = densities ** self .sensitivity
171+ self .densities [iso_idx ] = 0
172+ return self .densities
131173
132174 def _cluster (self ):
133175 if self .densities is None :
@@ -288,19 +330,25 @@ def populate_edge_attrs(self):
288330 self .G [u ][v ]['dst_weight' ] = percentage_inweight
289331
290332 def populate_node_attrs (self , cmap = None , labels = None ):
333+ pos = False #todo fix
334+ if self .n_components == 2 :
335+ pos = True
291336 # Add colours and cluster name labels to the vertices.
292337 t_attrs = nx .get_node_attributes (self .G , 'slice_no' )
293338 cl_attrs = nx .get_node_attributes (self .G , 'cluster_no' )
294- avg_xpos = compute_cluster_yaxis (self .clusters , self .data [:,0 ])
295- avg_ypos = compute_cluster_yaxis (self .clusters , self .data [:,1 ])
339+ if pos :
340+ avg_xpos = compute_cluster_yaxis (self .clusters , self .data [:,0 ])
341+ avg_ypos = compute_cluster_yaxis (self .clusters , self .data [:,1 ])
296342 clr_list = {}
297343 size_list = {}
298344 pos_list = {}
299345 for node in self .G .nodes ():
300346 t_idx = t_attrs [node ]
301347 cl_idx = cl_attrs [node ]
302- node_xpos = avg_xpos [t_idx ][cl_idx ]
303- node_ypos = avg_ypos [t_idx ][cl_idx ]
348+ if pos :
349+ node_xpos = avg_xpos [t_idx ][cl_idx ]
350+ node_ypos = avg_ypos [t_idx ][cl_idx ]
351+ pos_list [node ] = (node_xpos , node_ypos )
304352 if cmap :
305353 clr = cmap (node_xpos , node_ypos )/ 255
306354 else :
@@ -309,11 +357,12 @@ def populate_node_attrs(self, cmap=None, labels=None):
309357
310358 size = np .size (self .get_vertex_data (node ))
311359 size_list [node ] = size
312- pos_list [ node ] = ( node_xpos , node_ypos )
360+
313361
314362 nx .set_node_attributes (self .G , clr_list , "colour" )
315363 nx .set_node_attributes (self .G , size_list , "count" )
316- nx .set_node_attributes (self .G , pos_list , "pos" )
364+ if pos :
365+ nx .set_node_attributes (self .G , pos_list , "pos" )
317366
318367 def get_vertex_data (self , node , ghost = False ):
319368 t_idx = self .G .nodes ()[node ]['slice_no' ]
@@ -336,19 +385,17 @@ def generate_plot(self, label_edges = True, threshold = 0.48, vertices = None):
336385 nx .draw_networkx_edges (
337386 G , pos , edgelist = esmall , width = 0.5 * edge_width , alpha = 0.5 , edge_color = "b" , style = "dashed"
338387 )
339- nx .draw_networkx_edges (G , pos , edgelist = elarge , width = 3 )
388+ nx .draw_networkx_edges (G , pos , edgelist = elarge , width = 1 , arrows = False )
340389 if label_edges :
341390 edge_labels = nx .get_edge_attributes (G , "weight" )
342391 nx .draw_networkx_edge_labels (G , pos , edge_labels )
343392
344- node_size = [np .size (self .get_vertex_data (node )) for node in vertices ]
345- avg_xpos = compute_cluster_yaxis (self .clusters , self .data [:,0 ])
346- avg_ypos = compute_cluster_yaxis (self .clusters , self .data [:,1 ])
393+ node_size = [5 * np .log2 (np .size (self .get_vertex_data (node ))) for node in vertices ]
347394 clr_dict = nx .get_node_attributes (self .G , 'colour' )
348395 node_clr = [clr_dict [node ] for node in vertices ]
349396
350- nx .draw_networkx_nodes (G , pos , node_size = node_size ,node_color = node_clr )
351- nx .draw_networkx_labels (G , pos )
397+ nx .draw_networkx_nodes (G , pos , node_size = node_size , node_color = node_clr )
398+ # nx.draw_networkx_labels(G, pos)
352399 ax = plt .gca ()
353400
354401 return ax
0 commit comments