Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions bmtk/simulator/core/simulator_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,77 @@ def build_recurrent_edges(self, **opts):
def build_virtual_connections(self):
raise NotImplementedError()

def to_multigraph(self, source=None, target=None, edge_attrs=['nsyns'], **filter_attrs):
import networkx as nx

# Get node population(s) and node_id(s) that match "source" query. If no source filter
# then just leave blank and assume all nodes should be returned.
src_pops = [e._edge_pop.source_population for e in self._edge_populations]
src_node_ids = None
if source is not None:
src_node_set = self.get_node_set(source)
src_pops = src_node_set.population_names()
src_node_ids = src_node_set.node_ids

# Like above get node population(s)/node_id(s) for the given "target" query.
trg_pops = [e._edge_pop.target_population for e in self._edge_populations]
trg_node_ids = None
if target is not None:
trg_node_set = self.get_node_set(target)
trg_pops = trg_node_set.population_names()
trg_node_ids = trg_node_set.node_ids

# Iterative through all SONATA edge populations in network
edges_df = pd.DataFrame()
for epop in self._edge_populations:
# Skip if not valid source and target names.
if epop._edge_pop.source_population not in src_pops or epop._edge_pop.target_population not in trg_pops:
continue

# convert edges to pd dataframe
pop_df = epop._edge_pop.to_dataframe()
if 'nsyns' not in pop_df.columns:
pop_df['nsyns'] = 1

# Filter by source and/or target nodes, if applicable.
if src_node_ids:
pop_df = pop_df[pop_df['source_node_id'].isin(src_node_ids)]

if trg_node_ids:
pop_df = pop_df[pop_df['target_node_id'].isin(trg_node_ids)]

# Filter by edge attributes, if applicable.
for k, v in filter_attrs.items():
if pop_df is None or len(pop_df) == 0:
break

if k not in pop_df.columns:
pop_df = None
elif isinstance(v, (list, tuple)):
pop_df = pop_df[pop_df[k].isin(v)]
else:
pop_df = pop_df[pop_df[k] == v]

if pop_df is not None:
pop_df = pop_df[['source_population', 'source_node_id', 'target_population', 'target_node_id'] + edge_attrs]
edges_df = pd.concat([edges_df, pop_df], sort=False)

if edges_df is None or len(edges_df) == 0:
return nx.MultiDiGraph()

edges_df = edges_df.assign(
source_nodes=lambda df: list(zip(df['source_population'], df['source_node_id'].astype(int))),
target_nodes=lambda df: list(zip(df['target_population'], df['target_node_id'].astype(int))),
)

return nx.from_pandas_edgelist(
edges_df,
source='source_nodes',
target='target_nodes',
edge_attr=edge_attrs,
create_using=nx.MultiDiGraph()
)

@classmethod
def from_config(cls, conf, **properties):
"""Generates a graph structure from a json config file or dictionary.
Expand Down
8 changes: 7 additions & 1 deletion bmtk/utils/sonata/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,14 @@ def edge_types_table(self):
return self._types_table

def to_dataframe(self):
raise NotImplementedError()
ret_df = pd.DataFrame()
for grp_id in self.group_ids:
grp_df = self.get_group(grp_id).to_dataframe()

ret_df = pd.concat([ret_df, grp_df], sort=False)
ret_df['source_population'] = self.source_population
ret_df['target_population'] = self.target_population
return ret_df

def build_indicies(self):
indicies_grp = None
Expand Down
Loading