|
| 1 | +const MTLLOG_SUBSYSTEM = "com.juliagpu.metal.jl" |
| 2 | +const MTLLOG_CATEGRORY = "mtlprintf" |
| 3 | + |
| 4 | +const __METAL_OS_LOG_TYPE_DEBUG__ = Int32(2) |
| 5 | +const __METAL_OS_LOG_TYPE_INFO__ = Int32(1) |
| 6 | +const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0) |
| 7 | +const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16) |
| 8 | +const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17) |
| 9 | + |
| 10 | +const ALLOW_DOUBLE_META = "allowdouble" |
| 11 | + |
| 12 | +export @mtlprintf |
| 13 | + |
| 14 | +@generated function promote_c_argument(arg) |
| 15 | + # > When a function with a variable-length argument list is called, the variable |
| 16 | + # > arguments are passed using C's old ``default argument promotions.'' These say that |
| 17 | + # > types char and short int are automatically promoted to int, and type float is |
| 18 | + # > automatically promoted to double. Therefore, varargs functions will never receive |
| 19 | + # > arguments of type char, short int, or float. |
| 20 | + |
| 21 | + if arg == Cchar || arg == Cshort |
| 22 | + return :(Cint(arg)) |
| 23 | + else |
| 24 | + return :(arg) |
| 25 | + end |
| 26 | +end |
| 27 | + |
| 28 | +@generated function tag_doubles(arg) |
| 29 | + @dispose ctx=Context() begin |
| 30 | + ret = arg == Cfloat ? Cdouble : arg |
| 31 | + T_arg = convert(LLVMType, arg) |
| 32 | + T_ret = convert(LLVMType, ret) |
| 33 | + |
| 34 | + f, ft = create_function(T_ret, [T_arg]) |
| 35 | + |
| 36 | + @dispose builder=IRBuilder() begin |
| 37 | + entry = BasicBlock(f, "entry") |
| 38 | + position!(builder, entry) |
| 39 | + |
| 40 | + p1 = parameters(f)[1] |
| 41 | + |
| 42 | + if arg == Cfloat |
| 43 | + res = fpext!(builder, p1, LLVM.DoubleType()) |
| 44 | + metadata(res)["ir_check_ignore"] = MDNode([]) |
| 45 | + ret!(builder, res) |
| 46 | + else |
| 47 | + ret!(builder, p1) |
| 48 | + end |
| 49 | + end |
| 50 | + |
| 51 | + call_function(f, ret, Tuple{arg}, :arg) |
| 52 | + end |
| 53 | +end |
| 54 | + |
| 55 | + |
| 56 | +""" |
| 57 | + @mtlprintf("%Fmt", args...) |
| 58 | +
|
| 59 | +Print a formatted string in device context on the host standard output. |
| 60 | +""" |
| 61 | +macro mtlprintf(fmt::String, args...) |
| 62 | + fmt_val = Val(Symbol(fmt)) |
| 63 | + |
| 64 | + return :(_mtlprintf($fmt_val, $(map(arg -> :(tag_doubles(promote_c_argument($arg))), esc.(args))...))) |
| 65 | +end |
| 66 | + |
| 67 | +@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt} |
| 68 | + @dispose ctx=Context() begin |
| 69 | + arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)] |
| 70 | + arg_types = [argspec...] |
| 71 | + |
| 72 | + T_void = LLVM.VoidType() |
| 73 | + T_int32 = LLVM.Int32Type() |
| 74 | + T_int64 = LLVM.Int64Type() |
| 75 | + T_pint8 = LLVM.PointerType(LLVM.Int8Type()) |
| 76 | + T_pint8a2 = LLVM.PointerType(LLVM.Int8Type(), 2) |
| 77 | + |
| 78 | + # create functions |
| 79 | + param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types] |
| 80 | + llvm_f, llvm_ft = create_function(T_void, LLVMType[]; vararg=true) |
| 81 | + mod = LLVM.parent(llvm_f) |
| 82 | + |
| 83 | + # generate IR |
| 84 | + @dispose builder=IRBuilder() begin |
| 85 | + entry = BasicBlock(llvm_f, "entry") |
| 86 | + position!(builder, entry) |
| 87 | + |
| 88 | + str = globalstring_ptr!(builder, String(fmt), addrspace=2) |
| 89 | + |
| 90 | + # compute argsize |
| 91 | + argtypes = LLVM.StructType(param_types) |
| 92 | + dl = datalayout(mod) |
| 93 | + arg_size = LLVM.ConstantInt(T_int64, sizeof(dl, argtypes)) |
| 94 | + |
| 95 | + alloc = alloca!(builder, T_pint8) |
| 96 | + buffer = bitcast!(builder, alloc, T_pint8) |
| 97 | + alloc_size = LLVM.ConstantInt(T_int64, sizeof(dl, T_pint8)) |
| 98 | + |
| 99 | + lifetime_start_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8]) |
| 100 | + lifetime_start = LLVM.Function(mod, "llvm.lifetime.start.p0i8", lifetime_start_fty) |
| 101 | + call!(builder, lifetime_start_fty, lifetime_start, [alloc_size, buffer]) |
| 102 | + |
| 103 | + va_start_fty = LLVM.FunctionType(T_void, [T_pint8]) |
| 104 | + va_start = LLVM.Function(mod, "llvm.va_start", va_start_fty) |
| 105 | + call!(builder, va_start_fty, va_start, [buffer]) |
| 106 | + |
| 107 | + # invoke @air.os_log and return |
| 108 | + subsystem_str = globalstring_ptr!(builder, MTLLOG_SUBSYSTEM, addrspace=2) |
| 109 | + category_str = globalstring_ptr!(builder, MTLLOG_CATEGRORY, addrspace=2) |
| 110 | + log_type = LLVM.ConstantInt(T_int32, __METAL_OS_LOG_TYPE_DEBUG__) |
| 111 | + os_log_fty = LLVM.FunctionType(T_void, [T_pint8a2, T_pint8a2, T_int32, T_pint8a2, T_pint8, T_int64]) |
| 112 | + os_log = LLVM.Function(mod, "air.os_log", os_log_fty) |
| 113 | + |
| 114 | + arg_ptr = load!(builder, T_pint8, alloc) |
| 115 | + |
| 116 | + call!(builder, os_log_fty, os_log, [subsystem_str, category_str, log_type, str, arg_ptr, arg_size]) |
| 117 | + |
| 118 | + va_end_fty = LLVM.FunctionType(T_void, [T_pint8]) |
| 119 | + va_end = LLVM.Function(mod, "llvm.va_end", va_end_fty) |
| 120 | + call!(builder, va_end_fty, va_end, [buffer]) |
| 121 | + |
| 122 | + lifetime_end_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8]) |
| 123 | + lifetime_end = LLVM.Function(mod, "llvm.lifetime.end.p0i8", lifetime_end_fty) |
| 124 | + call!(builder, lifetime_end_fty, lifetime_end, [alloc_size, buffer]) |
| 125 | + |
| 126 | + ret!(builder) |
| 127 | + end |
| 128 | + |
| 129 | + call_function(llvm_f, Nothing, Tuple{arg_types...}, arg_exprs...) |
| 130 | + end |
| 131 | +end |
| 132 | + |
| 133 | + |
| 134 | +## print-like functionality |
| 135 | + |
| 136 | +export @mtlprint, @mtlprintln |
| 137 | + |
| 138 | +# simple conversions, defining an expression and the resulting argument type. nothing fancy, |
| 139 | +# `@mtlprint` pretty directly maps to `@mtlprintf`; we should just support `write(::IO)`. |
| 140 | +const mtlprint_conversions = [ |
| 141 | + Float32 => (x->:(Float64($x)), Float64), |
| 142 | + Ptr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}), |
| 143 | + LLVMPtr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}), |
| 144 | + Bool => (x->:(Int32($x)), Int32), |
| 145 | +] |
| 146 | + |
| 147 | +# format specifiers |
| 148 | +const mtlprint_specifiers = Dict( |
| 149 | + # integers |
| 150 | + Int16 => "%hd", |
| 151 | + Int32 => "%d", |
| 152 | + Int64 => "%ld", |
| 153 | + UInt16 => "%hu", |
| 154 | + UInt32 => "%u", |
| 155 | + UInt64 => "%lu", |
| 156 | + |
| 157 | + # floating-point |
| 158 | + Float32 => "%f", |
| 159 | + |
| 160 | + # other |
| 161 | + Cchar => "%c", |
| 162 | + Ptr{Cvoid} => "%p", |
| 163 | + Cstring => "%s", |
| 164 | +) |
| 165 | + |
| 166 | +@inline @generated function _mtlprint(parts...) |
| 167 | + fmt = "" |
| 168 | + args = Expr[] |
| 169 | + |
| 170 | + for i in 1:length(parts) |
| 171 | + part = :(parts[$i]) |
| 172 | + T = parts[i] |
| 173 | + |
| 174 | + # put literals directly in the format string |
| 175 | + if T <: Val |
| 176 | + fmt *= string(T.parameters[1]) |
| 177 | + continue |
| 178 | + end |
| 179 | + |
| 180 | + # try to convert arguments if they are not supported directly |
| 181 | + if !haskey(mtlprint_specifiers, T) |
| 182 | + for (Tmatch, rule) in mtlprint_conversions |
| 183 | + if T <: Tmatch |
| 184 | + part = rule[1](part) |
| 185 | + T = rule[2] |
| 186 | + break |
| 187 | + end |
| 188 | + end |
| 189 | + end |
| 190 | + |
| 191 | + # render the argument |
| 192 | + if haskey(mtlprint_specifiers, T) |
| 193 | + fmt *= mtlprint_specifiers[T] |
| 194 | + push!(args, part) |
| 195 | + elseif T <: Tuple |
| 196 | + fmt *= "(" |
| 197 | + for (j, U) in enumerate(T.parameters) |
| 198 | + if haskey(mtlprint_specifiers, U) |
| 199 | + fmt *= mtlprint_specifiers[U] |
| 200 | + push!(args, :($part[$j])) |
| 201 | + if j < length(T.parameters) |
| 202 | + fmt *= ", " |
| 203 | + elseif length(T.parameters) == 1 |
| 204 | + fmt *= "," |
| 205 | + end |
| 206 | + else |
| 207 | + @error("@mtlprint does not support values of type $U") |
| 208 | + end |
| 209 | + end |
| 210 | + fmt *= ")" |
| 211 | + elseif T <: String |
| 212 | + @error("@mtlprint does not support non-literal strings") |
| 213 | + elseif T <: Type |
| 214 | + fmt *= string(T.parameters[1]) |
| 215 | + else |
| 216 | + @warn("@mtlprint does not support values of type $T") |
| 217 | + fmt *= "$(T)(...)" |
| 218 | + end |
| 219 | + end |
| 220 | + |
| 221 | + quote |
| 222 | + @mtlprintf($fmt, $(args...)) |
| 223 | + end |
| 224 | +end |
| 225 | + |
| 226 | +""" |
| 227 | + @mtlprint(xs...) |
| 228 | + @mtlprintln(xs...) |
| 229 | +
|
| 230 | +Print a textual representation of values `xs` to standard output from the GPU. The |
| 231 | +functionality builds on `@mtlprintf`, and is intended as a more use friendly alternative of |
| 232 | +that API. However, that also means there's only limited support for argument types, handling |
| 233 | +16/32/64 signed and unsigned integers, 32 and 64-bit floating point numbers, `Cchar`s and |
| 234 | +pointers. For more complex output, use `@mtlprintf` directly. |
| 235 | +
|
| 236 | +Limited string interpolation is also possible: |
| 237 | +
|
| 238 | +```julia |
| 239 | + @mtlprint("Hello, World ", 42, "\\n") |
| 240 | + @mtlprint "Hello, World \$(42)\\n" |
| 241 | +``` |
| 242 | +""" |
| 243 | +macro mtlprint(parts...) |
| 244 | + args = Union{Val,Expr,Symbol}[] |
| 245 | + |
| 246 | + parts = [parts...] |
| 247 | + while true |
| 248 | + isempty(parts) && break |
| 249 | + |
| 250 | + part = popfirst!(parts) |
| 251 | + |
| 252 | + # handle string interpolation |
| 253 | + if isa(part, Expr) && part.head == :string |
| 254 | + parts = vcat(part.args, parts) |
| 255 | + continue |
| 256 | + end |
| 257 | + |
| 258 | + # expose literals to the generator by using Val types |
| 259 | + if isbits(part) # literal numbers, etc |
| 260 | + push!(args, Val(part)) |
| 261 | + elseif isa(part, QuoteNode) # literal symbols |
| 262 | + push!(args, Val(part.value)) |
| 263 | + elseif isa(part, String) # literal strings need to be interned |
| 264 | + push!(args, Val(Symbol(part))) |
| 265 | + else # actual values that will be passed to printf |
| 266 | + push!(args, part) |
| 267 | + end |
| 268 | + end |
| 269 | + |
| 270 | + quote |
| 271 | + _mtlprint($(map(esc, args)...)) |
| 272 | + end |
| 273 | +end |
| 274 | + |
| 275 | +@doc (@doc @mtlprint) -> |
| 276 | +macro mtlprintln(parts...) |
| 277 | + esc(quote |
| 278 | + Metal.@mtlprint($(parts...), "\n") |
| 279 | + end) |
| 280 | +end |
| 281 | + |
| 282 | +export @mtlshow |
| 283 | + |
| 284 | +""" |
| 285 | + @mtlshow(ex) |
| 286 | +
|
| 287 | +GPU analog of `Base.@show`. It comes with the same type restrictions as [`@mtlprintf`](@ref). |
| 288 | +
|
| 289 | +```julia |
| 290 | +@mtlshow thread_position_in_grid_1d() |
| 291 | +``` |
| 292 | +""" |
| 293 | +macro mtlshow(exs...) |
| 294 | + blk = Expr(:block) |
| 295 | + for ex in exs |
| 296 | + push!(blk.args, :(Metal.@mtlprintln($(sprint(Base.show_unquoted,ex)*" = "), |
| 297 | + begin local value = $(esc(ex)) end))) |
| 298 | + end |
| 299 | + isempty(exs) || push!(blk.args, :value) |
| 300 | + blk |
| 301 | +end |
0 commit comments