Skip to content

Commit f4504bc

Browse files
committed
feat: finish the implementation
1 parent 4b33068 commit f4504bc

File tree

3 files changed

+172
-59
lines changed

3 files changed

+172
-59
lines changed

examples/NanoGPT/Project.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,19 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
16+
17+
[compat]
18+
Comonicon = "1"
19+
DataDeps = "0.7"
20+
Enzyme = "0.13.14"
21+
JLD2 = "0.5"
22+
Lux = "1.2.3"
23+
MLUtils = "0.4"
24+
NNlib = "0.9.24"
25+
OneHotArrays = "0.2.5"
26+
Optimisers = "0.4.1"
27+
Printf = "1.10"
28+
Random = "1.10"
29+
Reactant = "0.2.5"
30+
Statistics = "1.10"
31+
StatsBase = "0.34.3"

examples/NanoGPT/README.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# NanoGPT using Lux & Reactant
2+
3+
## Requirements
4+
5+
* Install [Julia](https://julialang.org/)
6+
* In the Julia REPL, instantiate the `Project.toml` in the parent directory
7+
8+
## Training
9+
10+
To train a model, run `main.jl` with the necessary parameters.
11+
12+
```bash
13+
julia --startup=no --project=examples/NanoGPT --threads=auto examples/NanoGPT/main.jl
14+
```
15+
16+
## Inference
17+
18+
To run inference on a trained model, run `main.jl` with the necessary parameters.
19+
20+
```bash
21+
julia --startup=no --project=examples/NanoGPT --threads=auto examples/NanoGPT/main.jl \
22+
--inference \
23+
--model-path=<path to model checkpoint>
24+
```
25+
26+
## Usage
27+
28+
```bash
29+
main
30+
31+
Usage
32+
33+
main [options] [flags]
34+
35+
Options
36+
37+
--n-embed <64::Int>
38+
--n-hidden <256::Int>
39+
--n-heads <4::Int>
40+
--qk-dim <16::Int>
41+
--v-dim <16::Int>
42+
--n-layers <6::Int>
43+
--sequence-length <64::Int>
44+
--batchsize <128::Int>
45+
--dropout-rate <0.0::Float32>
46+
--test-split <0.1::Float64>
47+
--lr <0.01::Float64>
48+
--epochs <100::Int>
49+
--model-path <::String>
50+
--seed <::Union{String, Vector{String}}>
51+
--output-length <1024::Int>
52+
53+
Flags
54+
55+
--inference
56+
-h, --help Print this help message.
57+
--version Print version.
58+
```

examples/NanoGPT/main.jl

Lines changed: 98 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Taken from https://github.com/FluxML/model-zoo/pull/410
22
using MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib, DataDeps, StatsBase,
3-
OneHotArrays
3+
OneHotArrays, JLD2
44
using Reactant, Enzyme
55
using Comonicon: @main
66

@@ -51,42 +51,57 @@ function GPT(;
5151
token_embedding=Embedding(n_vocab => n_embed),
5252
position_embedding=Embedding(sequence_length => n_embed),
5353
drop=Dropout(dropout_rate),
54-
blocks=ntuple(n_layers) do i
54+
blocks=Chain(ntuple(n_layers) do i
5555
return gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
56-
end,
56+
end...),
5757
ln=LayerNorm((n_embed, 1)),
5858
output_layer=Dense(n_embed => n_vocab)) do tokens
59-
te = token_embedding(tokens)
60-
pe = position_embedding(1:size(tokens, 1))
61-
x = drop(te .+ pe)
62-
for blk in blocks
63-
x = blk(x)
64-
end
65-
x = ln(x)
66-
x = output_layer(x)
67-
@return x
59+
x = drop(token_embedding(tokens) .+ position_embedding(1:size(tokens, 1)))
60+
x = blocks(x)
61+
@return output_layer(ln(x))
6862
end
6963
end
7064

7165
# Use the model to generate some text.
72-
# function generate(model, seed, outlen)
73-
# seqlen = context_length(model)
74-
# if isempty(seed)
75-
# seed = "_"
76-
# end
77-
# x = map(c -> findfirst(==(c), model.alphabet)::Int64, collect(seed))
78-
# while length(x) < outlen
79-
# tail = x[max(1, end-seqlen+1):end]
80-
# tail = reshape(tail, length(tail), 1)
81-
# y = model(tail |> device) |> cpu
82-
# p = softmax(y[:,end,1])
83-
# j = sample(1:length(model.alphabet), Weights(p))
84-
# #j = argmax(p)
85-
# #x = vcat(x, [j])
86-
# push!(x, j)
87-
# end
88-
# String(map(j -> model.alphabet[j], x))
89-
# end
66+
# Use the trained model to autoregressively generate text from each seed string.
#
# `seed` is a vector of prompt strings. Shorter prompts are left-padded with '_'
# so all prompts can be batched together; the output for each seed is truncated
# back to exactly `output_length` characters (padding removed).
function generate_text(
    model, ps, st, seed; alphabet, output_length, sequence_length
)
    dev = get_device((ps, st))
    @assert !(dev isa ReactantDevice) "Currently we don't support running inference of \
                                       dynamically sized tensors."

    # Left-pad shorter seeds with '_' so every column of the batch has the same
    # number of prompt characters.
    seed = copy(seed)
    seed_len = maximum(length, seed)
    extra_letters = zeros(Int, length(seed))
    for (i, s) in enumerate(seed)
        if seed_len != length(s)
            extra_letters[i] = seed_len - length(s)
            seed[i] = "_"^extra_letters[i] * s
        end
    end
    # Generate extra tokens to compensate for the padding so each seed still
    # yields `output_length` real characters after the padding is stripped.
    original_output_length = output_length
    output_length += maximum(extra_letters)

    st = Lux.testmode(st)

    # x[j, i] holds the alphabet index of character j of sequence i.
    x = zeros(Int, output_length, length(seed))
    for (i, s) in enumerate(seed), j in 1:seed_len
        idx = findfirst(==(s[j]), alphabet)
        idx === nothing &&
            throw(ArgumentError("seed character '$(s[j])' is not in the alphabet"))
        x[j, i] = idx
    end
    for i in (seed_len + 1):output_length
        # Condition on at most the last `sequence_length - 1` generated tokens.
        tail = x[max(1, i - sequence_length + 1):(i - 1), :] |> dev
        y = model(tail, ps, st)[1] |> cpu_device()
        # Sample the next token independently for every sequence in the batch.
        # (Previously only the first sequence's logits were used and the same
        # token was written to every column.)
        for j in axes(x, 2)
            p = softmax(y[:, end, j])
            x[i, j] = sample(1:length(alphabet), Weights(p))
        end
    end

    # Decode each column, skipping that seed's padding and truncating to the
    # requested length. Indexing the character vector (instead of byte-indexing
    # a `String`) stays correct for multi-byte alphabet characters.
    res = Vector{String}(undef, length(seed))
    for i in eachindex(res)
        first_char = extra_letters[i] + 1
        last_char = first_char + original_output_length - 1
        res[i] = String([alphabet[x[k, i]] for k in first_char:last_char])
    end

    return res
end
90105

91106
# Load data from input file, and partition into training and testing subsets.
92107
function get_nanogpt_data(; sequence_length, test_split)
@@ -121,32 +136,62 @@ function get_nanogpt_data(; sequence_length, test_split)
121136
return alphabet, Array(trainX), Array(trainY), Array(testX), Array(testY)
122137
end
123138

124-
@main function train_nanogpt(;
139+
@main function main(;
125140
n_embed::Int=64, n_hidden::Int=256, n_heads::Int=4, qk_dim::Int=16,
126141
v_dim::Int=16, n_layers::Int=6, sequence_length::Int=64, batchsize::Int=128,
127142
dropout_rate::Float32=0.0f0, test_split::Float64=0.1, lr::Float64=1e-2,
128-
epochs::Int=20
143+
epochs::Int=100,
144+
# Only inference options
145+
inference::Bool=false, model_path::String="",
146+
seed::Union{String, Vector{String}}=["_", "The", "Julia", "Lux.jl"],
147+
output_length::Int=1024
129148
)
130-
alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
131-
132-
@printf "[Info] Alphabet size: %d\n" length(alphabet)
133-
@printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
134-
@printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
135-
136149
rng = Random.default_rng()
137150
Random.seed!(rng, 1234)
138151

139152
dev = reactant_device()
140153
cdev = cpu_device()
141154

155+
if inference
156+
@printf "[Info] Inference mode enabled.\n"
157+
158+
@assert !isempty(model_path) "Please provide a path to a model checkpoint."
159+
160+
@printf "[Info] Loading model from %s.\n" model_path
161+
model_config = JLD2.load(model_path, "model_config")
162+
model = GPT(; model_config...)
163+
ps = JLD2.load(model_path, "parameters")
164+
st = JLD2.load(model_path, "states")
165+
alphabet = JLD2.load(model_path, "alphabet")
166+
sequence_length = model_config.sequence_length
167+
168+
texts = generate_text(
169+
model, ps, st, seed; alphabet, output_length, sequence_length
170+
)
171+
172+
for (i, (text, s)) in enumerate(zip(texts, seed))
173+
@printf "[Info] Seed [%d]: %s\n" i s
174+
@printf "[Generated Text] %s\n\n" text
175+
end
176+
177+
return
178+
end
179+
180+
alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
181+
182+
@printf "[Info] Alphabet size: %d\n" length(alphabet)
183+
@printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
184+
@printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
185+
142186
train_loader = DataLoader(
143187
(trainX, trainY); batchsize, shuffle=true, parallel=true
144188
) |> dev
145189

146-
model = GPT(;
190+
model_config = (;
147191
n_vocab=length(alphabet), n_embed, sequence_length, n_hidden,
148192
n_layers, dropout_rate, n_heads, qk_dim, v_dim
149193
)
194+
model = GPT(; model_config...)
150195
ps, st = Lux.setup(rng, model) |> dev
151196
@printf "[Info] Number of parameters: %d\n" Lux.parameterlength(ps)
152197
@printf "[Info] Number of states: %d\n\n" Lux.statelength(st)
@@ -156,9 +201,12 @@ end
156201

157202
@printf "[Info] Compiling Inference Model...\n"
158203
testX, testY = (testX, testY) |> dev
204+
start_time = time()
159205
model_compiled = @compile model(testX, ps, Lux.testmode(st))
206+
time_to_compile = time() - start_time
160207
best_test_loss = Inf
161208

209+
@printf "[Info] Time taken to compile inference model: %0.5fs\n" time_to_compile
162210
@printf "[Info] Starting Model Training...\n\n"
163211

164212
loss_fn = CrossEntropyLoss(; logits=Val(true))
@@ -185,7 +233,15 @@ end
185233
)
186234
@printf "[Test] Epoch %3d\tTest Loss %.8e\n" epoch test_loss
187235

188-
# XXX: Also generate some text here...
236+
# Generate some text here...
237+
texts = generate_text(
238+
model, ps |> cdev, st |> cdev, seed;
239+
alphabet, output_length, sequence_length
240+
)
241+
for (i, (text, s)) in enumerate(zip(texts, seed))
242+
@printf "[Info] Seed [%d]: %s\n" i s
243+
@printf "[Generated Text] %s\n\n" text
244+
end
189245

190246
if test_loss < best_test_loss
191247
best_test_loss = test_loss
@@ -195,26 +251,9 @@ end
195251
joinpath(@__DIR__, "nanogpt.jld2");
196252
parameters=train_state.parameters |> cdev,
197253
states=train_state.states |> cdev,
198-
alphabet=alphabet
254+
alphabet=alphabet,
255+
model_config=model_config
199256
)
200257
end
201258
end
202259
end
203-
204-
# # Load a model from a checkpoint (see `jldsave` above).
205-
# function load_model(filename)
206-
# args = JLD2.load(filename, "args")
207-
# alphabet = JLD2.load(filename, "alphabet")
208-
# model = GPT(args, alphabet)
209-
# model_state = JLD2.load(filename, "model_state")
210-
# model = Flux.loadmodel!(model, model_state);
211-
# return args, model
212-
# end
213-
214-
# if true
215-
# args, model = train()
216-
# else
217-
# args, model = load_model("model-checkpoint.jld2") |> device
218-
# end
219-
220-
# generate(model, "The", 50)

0 commit comments

Comments
 (0)