Skip to content

Commit 5859d3d

Browse files
committed
feat: add induced_subgraph functionality
1 parent a034753 commit 5859d3d

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

GNNGraphs/src/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ export rand_graph,
9797

9898
include("sampling.jl")
9999
export sample_neighbors
100+
export induced_subgraph
100101

101102
include("operators.jl")
102103
# Base.intersect

GNNGraphs/src/sampling.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,53 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1;
116116
end
117117
return gnew
118118
end
119+
120+
"""
121+
induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph
122+
123+
Generates a subgraph from the original graph using the provided `nodes`.
124+
The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph.
125+
If a node has no neighbors, an isolated node will be added to the subgraph.
126+
127+
# Arguments:
128+
- `graph::GNNGraph`: The original graph containing nodes, edges, and node features.
129+
- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph.
130+
131+
# Returns:
132+
A new `GNNGraph` containing the subgraph with the specified nodes and their features.
133+
"""
134+
function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
135+
if isempty(nodes)
136+
return GNNGraph() # Return empty graph if no nodes are provided
137+
end
138+
139+
node_map = Dict(node => i for (i, node) in enumerate(nodes))
140+
141+
# Collect edges to add
142+
source = Int[]
143+
target = Int[]
144+
backup_gnn = GNNGraph()
145+
for node in nodes
146+
neighbors = Graphs.neighbors(graph, node, dir = :in)
147+
if isempty(neighbors)
148+
backup_gnn = add_nodes(backup_gnn, 1)
149+
end
150+
for neighbor in neighbors
151+
if neighbor in keys(node_map)
152+
push!(source, node_map[node])
153+
push!(target, node_map[neighbor])
154+
end
155+
end
156+
end
157+
158+
# Extract features for the new nodes
159+
#new_features = graph.x[:, nodes]
160+
161+
if isempty(source) && isempty(target)
162+
#backup_gnn.ndata.x = new_features ### TODO fix & add edges data (probably push themto the new vector?)
163+
return backup_gnn # Return empty graph if no nodes are provided
164+
end
165+
166+
return GNNGraph(source, target)
167+
#, ndata = new_features) # Return the new GNNGraph with subgraph and features
168+
end

GNNGraphs/test/sampling.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,20 @@ if GRAPH_T == :coo
4545
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
4646
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
4747
end
48+
49+
@testset "induced_subgraph" begin
50+
# Create a simple GNNGraph with two nodes and one edge
51+
graph = GNNGraph() # Initialize graph
52+
add_nodes!(graph, 2) # Add 2 nodes
53+
add_edge!(graph, 1, 2) # Add an edge from node 1 to node 2
54+
graph.x = rand(10, 2) # Assign random features to both nodes (10 features per node)
55+
56+
# Induce subgraph on both nodes
57+
nodes = [1, 2]
58+
subgraph = induced_subgraph(graph, nodes)
59+
60+
@test num_nodes(subgraph) == 2 # Subgraph should have 2 nodes
61+
@test num_nodes(subgraph) == 1 # Subgraph should have 1 edge
62+
### TODO @test subgraph.ndata.x == graph.x[:, nodes] # Features should match the original graph
63+
end
4864
end

0 commit comments

Comments
 (0)