Skip to content

Commit 0103f95

Browse files
committed
feat: more progress
1 parent 0650b61 commit 0103f95

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

examples/NanoGPT/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1616
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
17-
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1817

1918
[compat]
2019
Comonicon = "1"
@@ -30,4 +29,3 @@ Printf = "1.10"
3029
Random = "1.10"
3130
Reactant = "0.2.5"
3231
Statistics = "1.10"
33-
StatsBase = "0.34.3"

examples/NanoGPT/main.jl

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
ENV["XLA_REACTANT_GPU_MEM_FRACTION"] = "0.98"
2+
13
using ConcreteStructs,
24
MLUtils,
35
Lux,
@@ -120,25 +122,35 @@ function (model::GPT2)(x, ps, st)
120122
return outputs, (; tok_emb=st_tok_emb, pos_emb=st_pos_emb, gpt_blocks=st_gpt_blocks)
121123
end
122124

125+
#=
123126
dev = reactant_device(; force=true)
124127
rng = Random.default_rng()
125128
126129
model = GPT2(;
127130
n_vocab=50304,
128-
embed_dim=768,
131+
embed_dim=1024,
129132
hidden_dim=3072,
130133
block_size=1024,
131134
n_layers=3,
132135
dropout_rate=0.0,
133-
num_heads=12,
136+
num_heads=16,
134137
)
135138
ps, st = Lux.setup(rng, model) |> dev;
136139
137-
x = rand(1:50304, 1024, 32) |> dev;
140+
x = rand(1:50304, 48, 32) |> dev;
138141
139142
@code_hlo model(x, ps, st)
140143
144+
sumabs2first(layer, x, ps, st) = sum(abs2, first(layer(x, ps, st)))
145+
146+
@code_hlo Enzyme.gradient(Reverse, sumabs2first, Const(model), x, ps, Const(st))
147+
=#
148+
141149
# Use the model to generate some text.
150+
# function weighted_sample(items::AbstractVector, weights::AbstractVector)
151+
152+
# end
153+
142154
function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
143155
dev = get_device((ps, st))
144156
@assert !(dev isa ReactantDevice) "Currently we don't support running inference of \

src/layers/embedding.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ function (e::Embedding)(x::Union{Number,AbstractVector}, ps, st::NamedTuple)
6363
end
6464
function (e::Embedding)(x::AbstractArray, ps, st::NamedTuple)
6565
@argcheck Utils.eltype(x) <: Integer
66-
y, stₙ = e(vec(x), ps, st)
66+
y, stₙ = e(Utils.vec(x), ps, st)
6767
return reshape(y, :, size(x)...), stₙ
6868
end
6969
function (e::Embedding)(x::NTuple{N,T}, ps, st::NamedTuple) where {N,T}

0 commit comments

Comments
 (0)