Skip to content

Commit 4424b20

Browse files
authored
refactor internals (#2)
* renaming & surface interaction separation * refactor to use one traverse function for bvh * more gpu experiments * fix tests
1 parent f124a22 commit 4424b20

File tree

9 files changed

+324
-182
lines changed

9 files changed

+324
-182
lines changed

docs/examples.jl

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ begin
2121
viewdir = normalize(ax.scene.camera.view_direction[])
2222
end
2323

24+
hitpoints, centroid = RayCaster.get_centroid(bvh, viewdir)
25+
26+
2427
begin
2528
@time "hitpoints" hitpoints, centroid = RayCaster.get_centroid(bvh, viewdir)
2629
@time "illum" illum = RayCaster.get_illumination(bvh, viewdir)
@@ -33,10 +36,11 @@ begin
3336
f, ax, pl = mesh(world_mesh, color=:blue)
3437
per_face_vf = FaceView((viewfacts), [GLTriangleFace(i) for i in 1:N])
3538
viewfact_mesh = GeometryBasics.mesh(world_mesh, color=per_face_vf)
36-
pl = Makie.mesh(f[1, 2],
39+
pl = Makie.mesh(
40+
f[1, 2],
3741
viewfact_mesh, colormap=[:black, :red], axis=(; show_axis=false),
38-
shading=false, highclip=:red, lowclip=:black, colorscale=sqrt,)
39-
42+
shading=false, highclip=:red, lowclip=:black, colorscale=sqrt,
43+
)
4044
# Centroid
4145
cax, pl = Makie.mesh(f[2, 1], world_mesh, color=(:blue, 0.5), axis=(; show_axis=false), transparency=true)
4246

@@ -58,3 +62,113 @@ begin
5862

5963
f
6064
end
65+
66+
67+
using KernelAbstractions, Atomix
68+
69+
function random_scatter_kernel!(bvh, triangle, u, v, normal)
70+
point = RayCaster.random_triangle_point(triangle)
71+
o = point .+ (normal .* 0.01f0) # Offset so it doesn't self intersect
72+
dir = RayCaster.random_hemisphere_uniform(normal, u, v)
73+
ray = RayCaster.Ray(; o=o, d=dir)
74+
hit, prim, _ = RayCaster.intersect!(bvh, ray)
75+
return hit, prim
76+
end
77+
78+
import GeometryBasics as GB
79+
80+
@kernel function viewfact_ka_kernel!(result, bvh, primitives, rays_per_triangle)
81+
idx = @index(Global)
82+
prim_idx = ((UInt32(idx) - UInt32(1)) ÷ rays_per_triangle) + UInt32(1)
83+
if prim_idx <= length(primitives)
84+
triangle, u, v, normal = primitives[prim_idx]
85+
hit, prim = random_scatter_kernel!(bvh, triangle, u, v, normal)
86+
if hit && prim.material_idx !== triangle.material_idx
87+
# weigh by angle?
88+
Atomix.@atomic result[triangle.material_idx, prim.material_idx] += 1
89+
end
90+
end
91+
end
92+
93+
function view_factors!(result, bvh, prim_info, rays_per_triangle=10000)
94+
95+
backend = get_backend(result)
96+
workgroup = 256
97+
total_rays = length(bvh.primitives) * rays_per_triangle
98+
per_workgroup = total_rays ÷ workgroup
99+
final_rays = per_workgroup * workgroup
100+
per_triangle = final_rays ÷ length(bvh.primitives)
101+
102+
kernel = viewfact_ka_kernel!(backend, 256)
103+
kernel(result, bvh, prim_info, UInt32(per_triangle); ndrange = final_rays)
104+
return result
105+
end
106+
107+
result = zeros(UInt32, length(bvh.primitives), length(bvh.primitives))
108+
using AMDGPU
109+
prim_info = map(bvh.primitives) do triangle
110+
n = GB.orthogonal_vector(Vec3f, GB.Triangle(triangle.vertices...))
111+
normal = normalize(Vec3f(n))
112+
u, v = RayCaster.get_orthogonal_basis(normal)
113+
return triangle, u, v, normal
114+
end
115+
bvh_gpu = RayCaster.to_gpu(ROCArray, bvh)
116+
result_gpu = ROCArray(result)
117+
prim_info_gpu = ROCArray(prim_info)
118+
@time begin
119+
view_factors!(result_gpu, bvh_gpu, prim_info_gpu, 10000)
120+
KernelAbstractions.synchronize(get_backend(result_gpu))
121+
end;
122+
123+
124+
125+
@kernel function viewfact_ka_kernel2!(result, bvh, primitives, rays_per_triangle)
126+
idx = @index(Global)
127+
prim_idx = ((UInt32(idx) - UInt32(1)) ÷ rays_per_triangle) + UInt32(1)
128+
if prim_idx <= length(primitives)
129+
triangle, u, v, normal = primitives[prim_idx]
130+
hit, prim = random_scatter_kernel!(bvh, triangle, u, v, normal)
131+
if hit && prim.material_idx !== triangle.material_idx
132+
# weigh by angle?
133+
@inbounds result[idx] = UInt32(1)
134+
end
135+
end
136+
end
137+
138+
139+
function view_factors2!(result, bvh, prim_info, per_triangle)
140+
backend = get_backend(result)
141+
kernel = viewfact_ka_kernel2!(backend, 256)
142+
kernel(result, bvh, prim_info, UInt32(per_triangle); ndrange = length(result))
143+
return result
144+
end
145+
146+
147+
using AMDGPU
148+
workgroup = 256
149+
rays_per_triangle = 10000
150+
total_rays = length(bvh.primitives) * rays_per_triangle
151+
per_workgroup = total_rays ÷ workgroup
152+
final_rays = per_workgroup * workgroup
153+
per_triangle = final_rays ÷ length(bvh.primitives)
154+
result = zeros(UInt32, final_rays)
155+
156+
final_rays / 10^6
157+
158+
prim_info = map(bvh.primitives) do triangle
159+
n = GB.orthogonal_vector(Vec3f, GB.Triangle(triangle.vertices...))
160+
normal = normalize(Vec3f(n))
161+
u, v = RayCaster.get_orthogonal_basis(normal)
162+
return triangle, u, v, normal
163+
end
164+
165+
bvh_gpu = RayCaster.to_gpu(ROCArray, bvh)
166+
result_gpu = ROCArray(result)
167+
prim_info_gpu = ROCArray(prim_info)
168+
@time begin
169+
view_factors2!(result_gpu, bvh_gpu, prim_info_gpu, per_triangle)
170+
KernelAbstractions.synchronize(get_backend(result_gpu))
171+
end;
172+
173+
@time view_factors2!(result, bvh, prim_info, per_triangle)
174+
@code_warntype random_scatter_kernel!(bvh, prim_info[1]...)

src/bvh.jl

Lines changed: 119 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -239,100 +239,151 @@ end
239239
length(bvh.nodes) > Int32(0) ? bvh.nodes[1].bounds : Bounds3()
240240
end
241241

242-
@inline function intersect!(bvh::BVHAccel{P}, ray::AbstractRay) where {P}
243-
hit = false
244-
interaction = SurfaceInteraction()
242+
"""
243+
traverse_bvh(hit_callback::F, bvh::BVHAccel{P}, ray::AbstractRay) where {P, F<:Function}
244+
245+
Internal function that traverses the BVH to find ray-primitive intersections.
246+
Uses a callback pattern to handle different intersection behaviors.
247+
248+
Arguments:
249+
- `bvh`: The BVH acceleration structure
250+
- `ray`: The ray to test for intersections
251+
- `hit_callback`: Function called when primitive is tested. Signature:
252+
hit_callback(primitive, ray, prev_result) -> (continue_traversal::Bool, ray::AbstractRay, results::Any)
253+
254+
Returns:
255+
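- A tuple `(continue_search, ray, result)`, where `result` is the final result from the hit_callback; `continue_search` is `false` when the BVH is empty or the callback stopped the traversal early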
256+
"""
257+
@inline function traverse_bvh(hit_callback::F, bvh::BVHAccel{P}, ray::AbstractRay) where {P, F<:Function}
258+
# Early return if BVH is empty
259+
if length(bvh.nodes) == 0
260+
return false, ray, nothing
261+
end
262+
263+
# Prepare ray for traversal
245264
ray = check_direction(ray)
246265
inv_dir = 1f0 ./ ray.d
247266
dir_is_neg = is_dir_negative(ray.d)
248267

249-
to_visit_offset, current_node_i = Int32(1), Int32(1)
268+
# Initialize traversal stack
269+
to_visit_offset = Int32(1)
270+
current_node_idx = Int32(1)
250271
nodes_to_visit = zeros(MVector{64,Int32})
251272
primitives = bvh.primitives
252-
@_inbounds primitive = primitives[1]
253273
nodes = bvh.nodes
274+
275+
# State variables to hold callback results
276+
continue_search = true
277+
prim1 = primitives[1]
278+
result = hit_callback(prim1, ray, nothing)
279+
280+
# Traverse BVH
254281
@_inbounds while true
255-
ln = nodes[current_node_i]
256-
if intersect_p(ln.bounds, ray, inv_dir, dir_is_neg)
257-
if !ln.is_interior && ln.n_primitives > Int32(0)
258-
# Intersect ray with primitives in node.
259-
for i in Int32(0):ln.n_primitives - Int32(1)
260-
offset = ln.offset % Int32
261-
tmp_primitive = primitives[offset+i]
262-
tmp_hit, ray, tmp_interaction = intersect_p!(
263-
tmp_primitive, ray,
264-
)
265-
if tmp_hit
266-
hit = tmp_hit
267-
interaction = tmp_interaction
268-
primitive = tmp_primitive
282+
current_node = nodes[current_node_idx]
283+
284+
# Test ray against current node's bounding box
285+
if intersect_p(current_node.bounds, ray, inv_dir, dir_is_neg)
286+
if !current_node.is_interior && current_node.n_primitives > Int32(0)
287+
# Leaf node - test all primitives
288+
offset = current_node.offset % Int32
289+
290+
for i in Int32(0):(current_node.n_primitives - Int32(1))
291+
primitive = primitives[offset + i]
292+
293+
# Call the callback for this primitive
294+
continue_search, ray, result = hit_callback(primitive, ray, result)
295+
296+
# Early exit if callback requests it
297+
if !continue_search
298+
return false, ray, result
269299
end
270300
end
271-
to_visit_offset == Int32(1) && break
301+
302+
# Done with leaf, pop next node from stack
303+
if to_visit_offset == Int32(1)
304+
break
305+
end
272306
to_visit_offset -= Int32(1)
273-
current_node_i = nodes_to_visit[to_visit_offset]
307+
current_node_idx = nodes_to_visit[to_visit_offset]
274308
else
275-
if dir_is_neg[ln.split_axis] == Int32(2)
276-
nodes_to_visit[to_visit_offset] = current_node_i + Int32(1)
277-
current_node_i = ln.offset % Int32
309+
# Interior node - push children to stack
310+
if dir_is_neg[current_node.split_axis] == Int32(2)
311+
nodes_to_visit[to_visit_offset] = current_node_idx + Int32(1)
312+
current_node_idx = current_node.offset % Int32
278313
else
279-
nodes_to_visit[to_visit_offset] = ln.offset % Int32
280-
current_node_i += Int32(1)
314+
nodes_to_visit[to_visit_offset] = current_node.offset % Int32
315+
current_node_idx += Int32(1)
281316
end
282317
to_visit_offset += Int32(1)
283318
end
284319
else
285-
to_visit_offset == 1 && break
320+
# Miss - pop next node from stack
321+
if to_visit_offset == Int32(1)
322+
break
323+
end
286324
to_visit_offset -= Int32(1)
287-
current_node_i = nodes_to_visit[to_visit_offset]
325+
current_node_idx = nodes_to_visit[to_visit_offset]
288326
end
289327
end
290-
return hit, primitive, interaction
328+
329+
# Return final state
330+
return continue_search, ray, result
291331
end
292332

293-
@inline function intersect_p(bvh::BVHAccel, ray::AbstractRay)
333+
# Initialization
334+
closest_hit_callback(primitive, ray, ::Nothing) = (false, primitive, Point3f(0.0))
294335

295-
length(bvh.nodes) == Int32(0) && return false
336+
function closest_hit_callback(primitive, ray, prev_result::Tuple{Bool, P, Point3f}) where {P}
337+
# Test intersection and update if closer
338+
tmp_hit, ray, tmp_bary = intersect_p!(primitive, ray)
339+
# Always continue search to find closest
340+
return true, ray, ifelse(tmp_hit, (true, primitive, tmp_bary), prev_result)
341+
end
296342

297-
ray = check_direction(ray)
298-
inv_dir = 1f0 ./ ray.d
299-
dir_is_neg = is_dir_negative(ray.d)
343+
"""
344+
intersect!(bvh::BVHAccel{P}, ray::AbstractRay) where {P}
300345
301-
to_visit_offset, current_node_i = Int32(1), Int32(1)
302-
nodes_to_visit = zeros(MVector{64,Int32})
303-
primitives = bvh.primitives
304-
@_inbounds while true
305-
ln = bvh.nodes[current_node_i]
306-
if intersect_p(ln.bounds, ray, inv_dir, dir_is_neg)
307-
if !ln.is_interior && ln.n_primitives > Int32(0)
308-
for i in Int32(0):ln.n_primitives-Int32(1)
309-
offset = ln.offset % Int32
310-
intersect_p(
311-
primitives[offset + i], ray,
312-
) && return true
313-
end
314-
to_visit_offset == 1 && break
315-
to_visit_offset -= Int32(1)
316-
current_node_i = nodes_to_visit[to_visit_offset]
317-
else
318-
if dir_is_neg[ln.split_axis] == Int32(2)
319-
# @setindex 64 nodes_to_visit[to_visit_offset] = Int32(current_node_i + 1)
320-
nodes_to_visit[to_visit_offset] = current_node_i + Int32(1)
321-
current_node_i = ln.offset % Int32
322-
else
323-
# @setindex 64 nodes_to_visit[to_visit_offset] = Int32(ln.offset)
324-
nodes_to_visit[to_visit_offset] = ln.offset % Int32
325-
current_node_i += Int32(1)
326-
end
327-
to_visit_offset += Int32(1)
328-
end
329-
else
330-
to_visit_offset == Int32(1) && break
331-
to_visit_offset -= Int32(1)
332-
current_node_i = Int32(nodes_to_visit[to_visit_offset])
333-
end
346+
Find the closest intersection between a ray and the primitives stored in a BVH.
347+
348+
Returns:
349+
- `hit_found`: Boolean indicating if an intersection was found
350+
- `hit_primitive`: The primitive that was hit (if any)
351+
- `barycentric_coords`: Barycentric coordinates of the hit point
352+
"""
353+
@inline function intersect!(bvh::BVHAccel{P}, ray::AbstractRay) where {P}
354+
# Traverse BVH with closest-hit callback
355+
_, _, result = traverse_bvh(closest_hit_callback, bvh, ray)
356+
return result::Tuple{Bool, Triangle, Point3f}
357+
end
358+
359+
360+
# Initial state for the any-hit search: called once with `nothing` before traversal
# starts. Returns `false` ("no hit yet") so the state has the same type as the result
# produced by the per-primitive method below.
any_hit_callback(primitive, current_ray, ::Nothing) = false

"""
    any_hit_callback(primitive, current_ray, prev_hit)

Per-primitive callback for any-hit queries: test `current_ray` against `primitive` via
`intersect_p` and request that the traversal stop on the first hit.

`prev_hit` is the result of the previous callback invocation. It is a `Bool` because this
method returns a `Bool` as its result, which the traversal passes back in for the next
primitive. An empty tuple is still accepted for callers that used the former initial state.

# Returns
- `(continue_traversal, ray, hit)`: `(false, current_ray, true)` when the primitive is
  hit, `(true, current_ray, false)` otherwise.
"""
function any_hit_callback(primitive, current_ray, prev_hit::Union{Bool, Tuple{}})
    if intersect_p(primitive, current_ray)
        # Stop traversal on first hit
        return false, current_ray, true
    end
    # Continue search if no hit
    return true, current_ray, false
end
372+
373+
"""
374+
intersect_p(bvh::BVHAccel, ray::AbstractRay)
375+
376+
Test if a ray intersects any primitive in the BVH (without finding the closest hit).
377+
378+
Returns:
379+
- `hit_found`: Boolean indicating if any intersection was found
380+
"""
381+
@inline function intersect_p(bvh::BVHAccel, ray::AbstractRay)
    # Traverse BVH with any-hit callback; it stops the traversal on the first hit.
    continue_search, _, result = traverse_bvh(any_hit_callback, bvh, ray)
    # `traverse_bvh` also reports `continue_search == false` for an empty BVH, in which
    # case `result` is `nothing`. Only an explicit `true` result counts as a hit, so this
    # always returns a `Bool`.
    return !continue_search && result === true
end
337388

338389
function calculate_ray_grid_bounds(bounds::GeometryBasics.Rect, ray_direction::Vec3f)

src/kernel-abstractions.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,5 @@ end
1919
function to_gpu(ArrayType, bvh::RayCaster.BVHAccel; preserve=[])
2020
primitives = to_gpu(ArrayType, bvh.primitives; preserve=preserve)
2121
nodes = to_gpu(ArrayType, bvh.nodes; preserve=preserve)
22-
materials = to_gpu(ArrayType, to_gpu.((ArrayType,), bvh.materials; preserve=preserve); preserve=preserve)
23-
return RayCaster.BVHAccel(primitives, materials, bvh.max_node_primitives, nodes)
22+
return RayCaster.BVHAccel(primitives, bvh.max_node_primitives, nodes)
2423
end

src/kernels.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ function hits_from_grid(bvh, viewdir; grid_size=32)
1212
Threads.@threads for idx in CartesianIndices(ray_origins)
1313
o = ray_origins[idx]
1414
ray = RayCaster.Ray(; o=o, d=ray_direction)
15-
hit, prim, si = RayCaster.intersect!(bvh, ray)
16-
@inbounds result[idx] = RayHit(hit, si.core.p, prim.material_idx)
15+
hit, prim, bary = RayCaster.intersect!(bvh, ray)
16+
hitpoint = sum_mul(bary, prim.vertices)
17+
@inbounds result[idx] = RayHit(hit, hitpoint, prim.material_idx)
1718
end
1819
return result
1920
end
@@ -34,7 +35,7 @@ function view_factors!(result, bvh, rays_per_triangle=10000)
3435
point_on_triangle = random_triangle_point(triangle)
3536
o = point_on_triangle .+ (normal .* 0.01f0) # Offset so it doesn't self intersect
3637
ray = Ray(; o=o, d=random_hemisphere_uniform(normal, u, v))
37-
hit, prim, si = intersect!(bvh, ray)
38+
hit, prim, _ = intersect!(bvh, ray)
3839
if hit && prim.material_idx != triangle.material_idx
3940
# weigh by angle?
4041
result[triangle.material_idx, prim.material_idx] += 1

0 commit comments

Comments
 (0)