1010from graphviz import Digraph
1111
1212from neuromllite .utils import evaluate
13-
13+
14+ from pyneuroml .pynml import convert_to_units
1415
1516class GraphVizHandler (DefaultNetworkHandler ):
1617
18+ CUTOFF_INH_SYN_MV = - 50 # erev below -50mV => inhibitory, above => excitatory
19+
1720 positions = {}
1821 pop_indices = {}
1922
@@ -22,6 +25,7 @@ class GraphVizHandler(DefaultNetworkHandler):
2225
2326 proj_weights = {}
2427 proj_shapes = {}
28+ proj_lines = {}
2529 proj_pre_pops = {}
2630 proj_post_pops = {}
2731 proj_conns = {}
@@ -51,8 +55,9 @@ def finalise_document(self):
5155 lweight = 0.5 + fweight * 2.0
5256
5357 if self .level >= 2 :
54- print ("%s: weight %s -> %s; fw: %s; lw: %s" % (projName , self .max_weight ,self .min_weight ,fweight ,lweight ))
58+ # print("%s: weight %s -> %s; fw: %s; lw: %s"%(projName, self.max_weight,self.min_weight,fweight,lweight))
5559 self .f .attr ('edge' ,
60+ style = self .proj_lines [projName ],
5661 arrowhead = self .proj_shapes [projName ],
5762 arrowsize = '%s' % (min (1 ,lweight )),
5863 penwidth = '%s' % (lweight ),
@@ -119,7 +124,7 @@ def handle_population(self, population_id, component, size=-1, component_obj=Non
119124 else :
120125 fcolor = '#ffffff'
121126
122- print ('Color %s -> %s -> %s' % (properties ['color' ], rgb , color ))
127+ # print('Color %s -> %s -> %s'%(properties['color'], rgb, color))
123128
124129 if properties and 'type' in properties :
125130 self .pop_types [population_id ] = properties ['type' ]
@@ -132,7 +137,6 @@ def handle_population(self, population_id, component, size=-1, component_obj=Non
132137 if self .level >= 4 :
133138
134139 from neuroml import SpikeSourcePoisson
135- from pyneuroml .pynml import convert_to_units
136140
137141 if component_obj and isinstance (component_obj ,SpikeSourcePoisson ):
138142 start = convert_to_units (component_obj .start , 'ms' )
@@ -177,10 +181,7 @@ def handle_location(self, id, population_id, component, x, y, z):
177181 def handle_projection (self , projName , prePop , postPop , synapse , hasWeights = False , hasDelays = False , type = "projection" , synapse_obj = None , pre_synapse_obj = None ):
178182
179183 shape = 'normal'
180- '''
181- if synapse_obj:
182- print synapse_obj.erev
183- shape = 'dot'''
184+ line = 'normal'
184185
185186 weight = 1.0
186187 self .proj_pre_pops [projName ] = prePop
@@ -190,15 +191,20 @@ def handle_projection(self, projName, prePop, postPop, synapse, hasWeights=False
190191 if 'I' in self .pop_types [prePop ]:
191192 shape = 'dot'
192193
194+ if type == 'electricalProjection' :
195+ shape = 'none'
196+ line = 'dashed'
197+
198+ if synapse_obj :
199+
200+ if hasattr (synapse_obj ,'erev' ) and convert_to_units (synapse_obj .erev ,'mV' )< self .CUTOFF_INH_SYN_MV :
201+ shape = 'dot'
202+
193203 if self .nl_network :
194- #print synapse
195- #print self.nl_network.synapses
196204 syn = self .nl_network .get_child (synapse ,'synapses' )
197205 if syn :
198- #print syn
199206 if syn .parameters :
200- #print syn.parameters
201- if 'e_rev' in syn .parameters and syn .parameters ['e_rev' ]< - 50 :
207+ if 'e_rev' in syn .parameters and syn .parameters ['e_rev' ]< self .CUTOFF_INH_SYN_MV :
202208 shape = 'dot'
203209
204210 proj = self .nl_network .get_child (projName ,'projections' )
@@ -210,13 +216,13 @@ def handle_projection(self, projName, prePop, postPop, synapse, hasWeights=False
210216 weight = abs (proj_weight )
211217 if proj .random_connectivity :
212218 weight *= proj .random_connectivity .probability
213- #print 'w: %s'%weight
214219
215220 self .max_weight = max (self .max_weight , weight )
216221 self .min_weight = min (self .min_weight , weight )
217222
218223 self .proj_weights [projName ] = weight
219224 self .proj_shapes [projName ] = shape
225+ self .proj_lines [projName ] = line
220226 self .proj_conns [projName ] = 0
221227
222228
@@ -237,9 +243,7 @@ def finalise_projection(self, projName, prePop, postPop, synapse=None, type="pro
237243
238244 print_v ("Projection finalising: " + projName + " from " + prePop + " to " + postPop + " completed" )
239245
240-
241246
242- '''
243247 #
244248 # Should be overridden to create input source array
245249 #
@@ -251,8 +255,54 @@ def handle_input_list(self, inputListId, population_id, component, size, input_c
251255 self .log .error ("Error! Need a size attribute in sites element to create spike source!" )
252256 return
253257
254- self.input_info[inputListId] = (population_id, component)
258+ if self .level >= 2 :
259+
260+ label = '<%s' % inputListId
261+ if self .level >= 3 :
262+ label += '<br/><i>%s input%s</i>' % ( size , '' if size == 1 else 's' )
263+ if self .level >= 4 :
264+
265+ from neuroml import PulseGenerator
266+ from neuroml import TransientPoissonFiringSynapse
267+ from neuroml import PoissonFiringSynapse
268+ from pyneuroml .pynml import convert_to_units
269+
270+ if input_comp_obj and isinstance (input_comp_obj ,PulseGenerator ):
271+ start = convert_to_units (input_comp_obj .delay , 'ms' )
272+ if start == int (start ): start = int (start )
273+ duration = convert_to_units (input_comp_obj .duration ,'ms' )
274+ if duration == int (duration ): duration = int (duration )
275+ amplitude = convert_to_units (input_comp_obj .amplitude ,'pA' )
276+ if amplitude == int (amplitude ): amplitude = int (amplitude )
277+
278+ label += '<br/>Pulse %s-%sms @ %spA' % (start ,start + duration , amplitude )
279+
280+ if input_comp_obj and isinstance (input_comp_obj ,PoissonFiringSynapse ):
281+
282+ average_rate = convert_to_units (input_comp_obj .average_rate ,'Hz' )
283+ if average_rate == int (average_rate ): average_rate = int (average_rate )
284+
285+ label += '<br/>Syn: %s @ %sHz' % (input_comp_obj .synapse , average_rate )
286+
287+ if input_comp_obj and isinstance (input_comp_obj ,TransientPoissonFiringSynapse ):
288+
289+ start = convert_to_units (input_comp_obj .delay , 'ms' )
290+ if start == int (start ): start = int (start )
291+ duration = convert_to_units (input_comp_obj .duration ,'ms' )
292+ if duration == int (duration ): duration = int (duration )
293+ average_rate = convert_to_units (input_comp_obj .average_rate ,'Hz' )
294+ if average_rate == int (average_rate ): average_rate = int (average_rate )
295+
296+ label += '<br/>Syn: %s %s-%sms @ %sHz' % (input_comp_obj .synapse ,start ,start + duration , average_rate )
297+
298+ label += '>'
299+
300+ self .f .attr ('node' , color = '#444444' , style = '' , fontcolor = '#444444' )
301+ self .f .node (inputListId , label = label )
302+ self .f .edge (inputListId , population_id , arrowhead = 'empty' )
303+
255304
305+ '''
256306 #
257307 # Should be overridden to to connect each input to the target cell
258308 #
0 commit comments