-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Description
I have a custom loss function for a GLM type model where I'm trying to fit an integer variable as Poisson distributed where the mean
using Lux
using Reactant
using Distributions
using SpecialFunctions
using Random
using LinearAlgebra
function poisson_loss(η,n)
lg = loggamma(n+1)
-n*η + lg + exp(η)
end
function lossfunc(model, ps, st, (x,y,L))
η,st_new = model(x, ps, st)
ll = sum(poisson_loss.(η,y))
ll2 = 0.01f0*η*L*η'
ll = ll + ll2
return ll, st_new, (;y_pred=η)
end
function test()
β = randn(Float32, 3)
X = randn(Float32, 3, 1000)
λ = exp.(X'*β)
y = reshape(rand.(Poisson.(λ)),1,1000)
@show size(y)
# Simple model
model = Lux.Dense(3,1)
dev = reactant_device()
d,n = size(X)
rng = Random.default_rng()
_ps,_st = Lux.setup(rng, model)
ps,st = dev((_ps, _st))
L = diagm(ones(Float32, 1000)) # dummy; this would be a graph laplacian
(xe,ye,Le) = dev.((X,y,L))
model_compiled = @compile model(xe, ps, Lux.testmode(st))
ll,st_new,_ = lossfunc(model_compiled, ps, st, (xe,ye,Le))
end
julia> test()
size(y) = (1, 1000)
ERROR: conversion to pointer not defined for ConcretePJRTArray{Float32, 2, 1}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] unsafe_convert(::Type{Ptr{Float32}}, a::ConcretePJRTArray{Float32, 2, 1})
@ Base ./pointer.jl:68
[3] gemm!(transA::Char, transB::Char, alpha::Float32, A::ConcretePJRTArray{…}, B::ConcretePJRTArray{…}, beta::Float32, C::ConcretePJRTArray{…})
@ LinearAlgebra.BLAS ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/blas.jl:1644
[4] gemm_wrapper!(C::ConcretePJRTArray{…}, tA::Char, tB::Char, A::ConcretePJRTArray{…}, B::ConcretePJRTArray{…}, _add::LinearAlgebra.MulAddMul{…})
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:657
[5] generic_matmatmul!
@ ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:381 [inlined]
[6] _mul!
@ ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
[7] mul!
@ ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
[8] mul!
@ ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
[9] *
@ ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:124 [inlined]
[10] _tri_matmul(A::ConcretePJRTArray{…}, B::ConcretePJRTArray{…}, C::Adjoint{…}, δ::Float32)
@ LinearAlgebra ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:1144
[11] *
@ ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:1199 [inlined]
[12] lossfunc(model::Reactant.Compiler.Thunk{…}, ps::@NamedTuple{…}, st::@NamedTuple{}, ::Tuple{…})
@ Main ~/.julia/dev/RecurrentNetworkModels/src/glmtest.jl:17
[13] test()
@ Main ~/.julia/dev/RecurrentNetworkModels/src/glmtest.jl:43
[14] top-level scope
@ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia>
Metadata
Metadata
Assignees
Labels
No labels