Skip to content

Commit 246b761

Browse files
vchuravygbaraldi
andauthored
Fix deferred_codegen registration (#711)
Co-authored-by: Gabriel Baraldi <[email protected]>
1 parent aa05fa3 commit 246b761

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

src/GPUCompiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ function __init__()
6666
global compile_cache = dir
6767

6868
Tracy.@register_tracepoints()
69+
register_deferred_codegen()
6970
end
7071

7172
end # module

src/driver.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ const deferred_codegen_jobs = Dict{Int, Any}()
129129

130130
# We make this function explicitly callable so that we can drive OrcJIT's
131131
# lazy compilation from, while also enabling recursive compilation.
132-
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
132+
# see `register_deferred_codegen`
133+
function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid}
133134
ptr
134135
end
135136

@@ -149,6 +150,33 @@ end
149150
end
150151
end
151152

153+
# Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"`
154+
# Called from __init__
155+
# On 1.11+ this is needed due to a Julia bug that drops the pointer when code-coverage is enabled.
156+
function register_deferred_codegen()
157+
@dispose jljit=JuliaOJIT() begin
158+
jd = JITDylib(jljit)
159+
160+
address = LLVM.API.LLVMOrcJITTargetAddress(
161+
reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},))))
162+
flags = LLVM.API.LLVMJITSymbolFlags(
163+
LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
164+
name = mangle(jljit, "deferred_codegen")
165+
symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags)
166+
map = if LLVM.version() >= v"15"
167+
LLVM.API.LLVMOrcCSymbolMapPair(name, symbol)
168+
else
169+
LLVM.API.LLVMJITCSymbolMapPair(name, symbol)
170+
end
171+
172+
mu = LLVM.absolute_symbols(Ref(map))
173+
LLVM.define(jd, mu)
174+
addr = lookup(jljit, "deferred_codegen")
175+
@assert addr != C_NULL "Failed to register deferred_codegen"
176+
end
177+
return nothing
178+
end
179+
152180
const __llvm_initialized = Ref(false)
153181

154182
@locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...)

test/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,6 @@ end
190190

191191
@testset "Mock Enzyme" begin
192192
Enzyme.deferred_codegen_id(typeof(identity), Tuple{Vector{Float64}})
193+
# Check that we can call this function from the CPU, to support deferred codegen for Enzyme.
194+
@test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3
193195
end

0 commit comments

Comments
 (0)