Skip to content

Commit 7fbac1c

Browse files
mhauru and yebai authored
Support for Julia v1.12 (#196)
* Various fixes for Julia v1.12. Work in progress. * Fix a typo * Fix opaque_closure for v1.12 * Work around Julia issue 59222 * Update Julia version in CI * Fixes to opaque_closure Copied over from chalk-lab/Mooncake.jl#714 * Optimise opaque closures on 1.12 Copied over from @serenity4's work in chalk-lab/Mooncake.jl#714 * Bump patch version to 0.9.6, turn NEWS.md into a HISTORY.md * Refactor to simplify * More refactoring * Remove/update out-of-date comments * Add warning on Julia v1.12.0 --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 0aa14c1 commit 7fbac1c

File tree

7 files changed

+224
-59
lines changed

7 files changed

+224
-59
lines changed

.github/workflows/Testing.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ jobs:
1111
strategy:
1212
matrix:
1313
version:
14-
- '1.10'
14+
- 'min'
1515
- '1'
16-
- 'pre'
16+
# TODO(mhauru) Reenable the below once there is a 'pre' version different from '1'.
17+
# - 'pre'
1718
os:
1819
- ubuntu-latest
1920
- windows-latest

HISTORY.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# 0.9.6
2+
3+
Add support for Julia v1.12.
4+
5+
# 0.9.0
6+
7+
From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where previously they were only deprecated. Additionally, the internals have been completely overhauled, and the public interface more precisely defined. See the docs for more info.
8+
9+
# 0.6.0
10+
11+
From v0.6.0, Libtask is implemented by recording all computation to a tape and copying that tape. Before that version, it was based on a tricky hack of the Julia internals. You can check the commit history of this repo to see the details.

NEWS.md

Lines changed: 0 additions & 8 deletions
This file was deleted.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.9.5"
6+
version = "0.9.6"
77

88
[deps]
99
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"

src/bbcode.jl

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -140,22 +140,44 @@ end
140140

141141
collect_stmts(bb::BBlock)::Vector{IDInstPair} = collect(zip(bb.inst_ids, bb.insts))
142142

143-
struct BBCode
144-
blocks::Vector{BBlock}
145-
argtypes::Vector{Any}
146-
sptypes::Vector{CC.VarState}
147-
linetable::Vector{Core.LineInfoNode}
148-
meta::Vector{Expr}
149-
end
143+
@static if VERSION >= v"1.12-"
144+
struct BBCode
145+
blocks::Vector{BBlock}
146+
argtypes::Vector{Any}
147+
sptypes::Vector{CC.VarState}
148+
debuginfo::CC.DebugInfoStream
149+
meta::Vector{Expr}
150+
valid_worlds::CC.WorldRange
151+
end
150152

151-
function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock})
152-
return BBCode(
153-
new_blocks,
154-
CC.copy(ir.argtypes),
155-
CC.copy(ir.sptypes),
156-
CC.copy(ir.linetable),
157-
CC.copy(ir.meta),
158-
)
153+
function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock})
154+
return BBCode(
155+
new_blocks,
156+
CC.copy(ir.argtypes),
157+
CC.copy(ir.sptypes),
158+
CC.copy(ir.debuginfo),
159+
CC.copy(ir.meta),
160+
ir.valid_worlds,
161+
)
162+
end
163+
else
164+
struct BBCode
165+
blocks::Vector{BBlock}
166+
argtypes::Vector{Any}
167+
sptypes::Vector{CC.VarState}
168+
linetable::Vector{Core.LineInfoNode}
169+
meta::Vector{Expr}
170+
end
171+
172+
function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock})
173+
return BBCode(
174+
new_blocks,
175+
CC.copy(ir.argtypes),
176+
CC.copy(ir.sptypes),
177+
CC.copy(ir.linetable),
178+
CC.copy(ir.meta),
179+
)
180+
end
159181
end
160182

161183
# Makes use of the above outer constructor for `BBCode`.
@@ -352,20 +374,42 @@ function CC.IRCode(bb_code::BBCode)
352374
insts = _ids_to_line_numbers(bb_code)
353375
cfg = control_flow_graph(bb_code)
354376
insts = _lines_to_blocks(insts, cfg)
355-
return IRCode(
356-
CC.InstructionStream(
357-
map(x -> x.stmt, insts),
358-
map(x -> x.type, insts),
359-
map(x -> x.info, insts),
360-
map(x -> x.line, insts),
361-
map(x -> x.flag, insts),
362-
),
363-
cfg,
364-
CC.copy(bb_code.linetable),
365-
CC.copy(bb_code.argtypes),
366-
CC.copy(bb_code.meta),
367-
CC.copy(bb_code.sptypes),
368-
)
377+
@static if VERSION >= v"1.12-"
378+
# See e.g. here for how the NTuple{3,Int}s get flattened for InstructionStream:
379+
# https://github.com/JuliaLang/julia/blob/16a2bf0a3b106b03dda23b8c9478aab90ffda5e1/Compiler/src/ssair/ir.jl#L299
380+
lines = map(x -> x.line, insts)
381+
lines = collect(Iterators.flatten(lines))
382+
return IRCode(
383+
CC.InstructionStream(
384+
map(x -> x.stmt, insts),
385+
collect(Any, map(x -> x.type, insts)),
386+
collect(CC.CallInfo, map(x -> x.info, insts)),
387+
lines,
388+
map(x -> x.flag, insts),
389+
),
390+
cfg,
391+
CC.copy(bb_code.debuginfo),
392+
CC.copy(bb_code.argtypes),
393+
CC.copy(bb_code.meta),
394+
CC.copy(bb_code.sptypes),
395+
bb_code.valid_worlds,
396+
)
397+
else
398+
return IRCode(
399+
CC.InstructionStream(
400+
map(x -> x.stmt, insts),
401+
map(x -> x.type, insts),
402+
map(x -> x.info, insts),
403+
map(x -> x.line, insts),
404+
map(x -> x.flag, insts),
405+
),
406+
cfg,
407+
CC.copy(bb_code.linetable),
408+
CC.copy(bb_code.argtypes),
409+
CC.copy(bb_code.meta),
410+
CC.copy(bb_code.sptypes),
411+
)
412+
end
369413
end
370414

371415
function _lower_switch_statements(bb_code::BBCode)

src/copyable_task.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ function build_callable(sig::Type{<:Tuple})
9090
unoptimised_ir = IRCode(bb)
9191
optimised_ir = optimise_ir!(unoptimised_ir)
9292
mc_ret_type = callable_ret_type(sig, types)
93-
mc = misty_closure(mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true)
93+
mc = optimized_misty_closure(
94+
mc_ret_type, optimised_ir, refs...; isva=isva, do_compile=true
95+
)
9496
mc_cache[key] = mc
9597
return mc, refs[end]
9698
end
@@ -277,6 +279,13 @@ The above gives the broad outline of how `TapedTask`s are implemented. We refer
277279
readers to the code, which is extensively commented to explain implementation details.
278280
"""
279281
function TapedTask(taped_globals::Any, fargs...; kwargs...)
282+
@static if v"1.12.1" > VERSION >= v"1.12.0-"
283+
@warn """
284+
Libtask.jl does not work correctly on Julia v1.12.0 and may crash your Julia
285+
session. Please upgrade to at least v1.12.1. See
286+
https://github.com/JuliaLang/julia/issues/59222 for the bug in question.
287+
"""
288+
end
280289
all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...)
281290
seed_id!() # a BBCode thing.
282291
mc, count_ref = build_callable(typeof(all_args))
@@ -441,8 +450,10 @@ get_value(x) = x
441450
expression, otherwise `false`.
442451
"""
443452
function is_produce_stmt(x)::Bool
444-
if Meta.isexpr(x, :invoke) && length(x.args) == 3 && x.args[1] isa Core.MethodInstance
445-
return x.args[1].specTypes <: Tuple{typeof(produce),Any}
453+
if Meta.isexpr(x, :invoke) &&
454+
length(x.args) == 3 &&
455+
x.args[1] isa Union{Core.MethodInstance,Core.CodeInstance}
456+
return get_mi(x.args[1]).specTypes <: Tuple{typeof(produce),Any}
446457
elseif Meta.isexpr(x, :call) && length(x.args) == 2
447458
return get_value(x.args[1]) === produce
448459
else
@@ -465,7 +476,7 @@ function stmt_might_produce(x, ret_type::Type)::Bool
465476

466477
# Statement will terminate in the usual fashion, so _do_ bother recursing.
467478
is_produce_stmt(x) && return true
468-
Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes)
479+
Meta.isexpr(x, :invoke) && return might_produce(get_mi(x.args[1]).specTypes)
469480
if Meta.isexpr(x, :call)
470481
# This is a hack -- it's perfectly possible for `DataType` calls to produce in general.
471482
f = get_function(x.args[1])
@@ -1029,7 +1040,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
10291040

10301041
# Derive TapedTask for this statement.
10311042
(callable, callable_args) = if Meta.isexpr(stmt, :invoke)
1032-
sig = stmt.args[1].specTypes
1043+
sig = get_mi(stmt.args[1]).specTypes
10331044
v = Any[Any]
10341045
(LazyCallable{sig,callable_ret_type(sig, v)}(), stmt.args[2:end])
10351046
elseif Meta.isexpr(stmt, :call)
@@ -1144,7 +1155,13 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}}
11441155
new_argtypes = vcat(typeof(refs), copy(ir.argtypes))
11451156

11461157
# Return BBCode and the `Ref`s.
1147-
new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta)
1158+
@static if VERSION >= v"1.12-"
1159+
new_ir = BBCode(
1160+
new_bblocks, new_argtypes, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds
1161+
)
1162+
else
1163+
new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta)
1164+
end
11481165
return new_ir, refs, possible_produce_types
11491166
end
11501167

src/utils.jl

Lines changed: 113 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
    get_mi(ci::Core.CodeInstance)
    get_mi(mi::Core.MethodInstance)

Return the `Core.MethodInstance` associated with the argument.

A `MethodInstance` is returned unchanged. For a `CodeInstance`, `CC.get_ci_mi` is used
when the compiler module defines it; otherwise the `def` field of `ci` is read directly.
The choice between the two is made once, at macro-expansion time, via `@static`.
"""
function get_mi(ci::Core.CodeInstance)
    @static if isdefined(CC, :get_ci_mi)
        return CC.get_ci_mi(ci)
    else
        return ci.def
    end
end
get_mi(mi::Core.MethodInstance) = mi
15

26
"""
37
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}
@@ -68,7 +72,11 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
6872

6973
ir = CC.compact!(ir)
7074
# CC.verify_ir(ir, true, false, CC.optimizer_lattice(local_interp))
71-
CC.verify_linetable(ir.linetable, true)
75+
@static if VERSION >= v"1.12-"
76+
CC.verify_linetable(ir.debuginfo, div(length(ir.debuginfo.codelocs), 3), true)
77+
else
78+
CC.verify_linetable(ir.linetable, true)
79+
end
7280
if show_ir
7381
println("Post-optimization")
7482
display(ir)
@@ -96,13 +104,27 @@ end
96104
# Run type inference and constant propagation on the ir. Credit to @oxinabox:
97105
# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54
98106
function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance)
99-
method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=#
100-
min_world = world = get_inference_world(interp)
101-
max_world = Base.get_world_counter()
102-
irsv = CC.IRInterpretationState(
103-
interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world
104-
)
105-
rt = CC._ir_abstract_constant_propagation(interp, irsv)
107+
@static if VERSION >= v"1.12-"
108+
nargs = length(ir.argtypes) - 1
109+
# TODO(mhauru) How should we figure out isva? I don't think it's in ir or mi.
110+
isva = false
111+
propagate_inbounds = true
112+
spec_info = CC.SpecInfo(nargs, isva, propagate_inbounds, nothing)
113+
min_world = world = get_inference_world(interp)
114+
max_world = Base.get_world_counter()
115+
irsv = CC.IRInterpretationState(
116+
interp, spec_info, ir, mi, ir.argtypes, world, min_world, max_world
117+
)
118+
rt = CC.ir_abstract_constant_propagation(interp, irsv)
119+
else
120+
method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=#
121+
min_world = world = get_inference_world(interp)
122+
max_world = Base.get_world_counter()
123+
irsv = CC.IRInterpretationState(
124+
interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world
125+
)
126+
rt = CC._ir_abstract_constant_propagation(interp, irsv)
127+
end
106128
return ir
107129
end
108130

@@ -168,19 +190,85 @@ function opaque_closure(
168190
)
169191
# This implementation is copied over directly from `Core.OpaqueClosure`.
170192
ir = CC.copy(ir)
171-
nargs = length(ir.argtypes) - 1
172-
sig = Base.Experimental.compute_oc_signature(ir, nargs, isva)
193+
@static if VERSION >= v"1.12-"
194+
# On v1.12 OpaqueClosure expects the first arg to be the environment.
195+
ir.argtypes[1] = typeof(env)
196+
end
197+
nargtypes = length(ir.argtypes)
198+
nargs = nargtypes - 1
199+
@static if VERSION >= v"1.12-"
200+
sig = CC.compute_oc_signature(ir, nargs, isva)
201+
else
202+
sig = Base.Experimental.compute_oc_signature(ir, nargs, isva)
203+
end
173204
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
174-
src.slotnames = fill(:none, nargs + 1)
175-
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
205+
src.slotnames = [Symbol(:_, i) for i in 1:nargtypes]
206+
src.slotflags = fill(zero(UInt8), nargtypes)
176207
src.slottypes = copy(ir.argtypes)
177-
src.rettype = ret_type
208+
@static if VERSION > v"1.12-"
209+
ir.debuginfo.def === nothing &&
210+
(ir.debuginfo.def = :var"generated IR for OpaqueClosure")
211+
src.min_world = ir.valid_worlds.min_world
212+
src.max_world = ir.valid_worlds.max_world
213+
src.isva = isva
214+
src.nargs = nargtypes
215+
end
178216
src = CC.ir_to_codeinf!(src, ir)
217+
src.rettype = ret_type
179218
return Base.Experimental.generate_opaque_closure(
180219
sig, Union{}, ret_type, src, nargs, isva, env...; do_compile
181220
)::Core.OpaqueClosure{sig,ret_type}
182221
end
183222

223+
"""
    optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)

Build an opaque closure from `ir` and then try to optimise it.

The closure is first constructed with `opaque_closure`, forwarding `rtype`, `env` and
`kwargs`. Its cached code then has its world bounds fixed by
`set_world_bounds_for_optimization!`, and the result of `optimize_opaque_closure` is
returned. That result is the original closure if there was nothing to optimise, and a
closure rebuilt from the re-inferred IR otherwise.
"""
function optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)
    oc = opaque_closure(rtype, ir, env...; kwargs...)
    # Must happen before re-inference: it lets the optimiser make assumptions about
    # binding access for the world in which `oc` was created.
    set_world_bounds_for_optimization!(oc)
    return optimize_opaque_closure(oc, rtype, env...; kwargs...)
end
end
230+
231+
"""
    optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)

Re-infer and inline the code cached for `oc`, then rebuild the closure from the result.

The code instance is read from `oc.source.specializations.cache` and passed, together
with `oc.world`, to `reinfer_and_inline`. If that returns `nothing`, `oc` itself is
returned unchanged; otherwise a new closure is created from the returned IR with
`opaque_closure`, forwarding `rtype`, `env` and `kwargs`.
"""
function optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)
    code_instance = oc.source.specializations.cache
    new_ir = reinfer_and_inline(code_instance, UInt(oc.world))
    if new_ir === nothing
        return oc # nothing to optimize
    end
    return opaque_closure(rtype, new_ir, env...; kwargs...)
end
239+
240+
"""
    set_world_bounds_for_optimization!(oc::Core.OpaqueClosure)

Set both `min_world` and `max_world` of the inferred source cached for `oc` to `oc.world`.

This allows optimization to make assumptions about binding access, enabling inlining and
other optimizations.

The inferred source is read from `oc.source.specializations.cache.inferred`. If it is
`nothing`, this function does nothing. Always returns `nothing`.
"""
function set_world_bounds_for_optimization!(oc::Core.OpaqueClosure)
    ci = oc.source.specializations.cache
    inferred = ci.inferred
    inferred === nothing && return nothing
    inferred.min_world = oc.world
    inferred.max_world = oc.world
    # Return `nothing` on every path rather than leaking the value of the last assignment.
    return nothing
end
248+
249+
"""
    reinfer_and_inline(ci::Core.CodeInstance, world::UInt)

Re-run abstract constant propagation and the SSA inlining pass on the IR of `ci`.

An `IRInterpretationState` is built for `ci` with a `NativeInterpreter` at `world`, using
the parameters of the method instance's `specTypes` as argument types. If no state can
be built, `nothing` is returned. Otherwise every statement is flagged with
`IR_FLAG_REFINED`, except `:loopinfo`, `:pop_exception` and `:copyast` expressions,
`GotoIfNot` and `GotoNode`. Constant propagation is then run and the inlined IR is
returned.
"""
function reinfer_and_inline(ci::Core.CodeInstance, world::UInt)
    interp = CC.NativeInterpreter(world)
    mi = get_mi(ci)
    argtypes = collect(Any, mi.specTypes.parameters)
    irsv = CC.IRInterpretationState(interp, ci, mi, argtypes, world)
    irsv === nothing && return nothing

    for instruction in irsv.ir.stmts
        inst = instruction[:inst]
        # These statement kinds are left untouched; everything else is marked as refined
        # so that constant propagation revisits it.
        is_exempt =
            Meta.isexpr(inst, (:loopinfo, :pop_exception, :copyast)) ||
            inst isa Union{CC.GotoIfNot,CC.GotoNode}
        if !is_exempt
            instruction[:flag] |= CC.IR_FLAG_REFINED
        end
    end

    CC.ir_abstract_constant_propagation(interp, irsv)
    inlining_state = CC.InliningState(interp)
    return CC.ssa_inlining_pass!(irsv.ir, inlining_state, CC.propagate_inbounds(irsv))
end
271+
184272
"""
185273
misty_closure(
186274
ret_type::Type,
@@ -202,3 +290,15 @@ function misty_closure(
202290
)
203291
return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir))
204292
end
293+
294+
"""
    optimized_misty_closure(
        ret_type::Type,
        ir::IRCode,
        env...;
        isva::Bool=false,
        do_compile::Bool=true,
    )

Construct a `MistyClosure` whose opaque closure is built by `optimized_opaque_closure`
rather than `opaque_closure`.

`ret_type`, `ir`, `env`, `isva` and `do_compile` are forwarded to
`optimized_opaque_closure`. The IR stored alongside the closure is the `ir` passed in
here, not any IR produced while building the closure.
"""
function optimized_misty_closure(
    ret_type::Type,
    ir::IRCode,
    @nospecialize env...;
    isva::Bool=false,
    do_compile::Bool=true,
)
    oc = optimized_opaque_closure(ret_type, ir, env...; isva, do_compile)
    return MistyClosure(oc, Ref(ir))
end
end

0 commit comments

Comments
 (0)