@@ -1,6 +1,6 @@
 # Taken from https://github.com/FluxML/model-zoo/pull/410
 using MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib, DataDeps, StatsBase,
-    OneHotArrays
+    OneHotArrays, JLD2
 using Reactant, Enzyme
 using Comonicon: @main
 
@@ -51,42 +51,60 @@ function GPT(;
         token_embedding=Embedding(n_vocab => n_embed),
         position_embedding=Embedding(sequence_length => n_embed),
         drop=Dropout(dropout_rate),
-        blocks=ntuple(n_layers) do i
+        blocks=Chain(ntuple(n_layers) do i
             return gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
-        end,
+        end...),
         ln=LayerNorm((n_embed, 1)),
         output_layer=Dense(n_embed => n_vocab)) do tokens
-        te = token_embedding(tokens)
-        pe = position_embedding(1:size(tokens, 1))
-        x = drop(te .+ pe)
-        for blk in blocks
-            x = blk(x)
-        end
-        x = ln(x)
-        x = output_layer(x)
-        @return x
+        x = drop(token_embedding(tokens) .+ position_embedding(1:size(tokens, 1)))
+        x = blocks(x)
+        @return output_layer(ln(x))
     end
 end
 
 # Use the model to generate some text.
-# function generate(model, seed, outlen)
-#     seqlen = context_length(model)
-#     if isempty(seed)
-#         seed = "_"
-#     end
-#     x = map(c -> findfirst(==(c), model.alphabet)::Int64, collect(seed))
-#     while length(x) < outlen
-#         tail = x[max(1, end-seqlen+1):end]
-#         tail = reshape(tail, length(tail), 1)
-#         y = model(tail |> device) |> cpu
-#         p = softmax(y[:,end,1])
-#         j = sample(1:length(model.alphabet), Weights(p))
-#         #j = argmax(p)
-#         #x = vcat(x, [j])
-#         push!(x, j)
-#     end
-#     String(map(j -> model.alphabet[j], x))
-# end
+function generate_text(
+        model, ps, st, seed; alphabet, output_length, sequence_length
+)
+    dev = get_device((ps, st))
+    @assert !(dev isa ReactantDevice) "Currently we don't support running inference \
+        with dynamically sized tensors."
+
+    seed = copy(seed)
+    seed_len = maximum(length, seed)
+    extra_letters = zeros(Int, length(seed))
+    for (i, s) in enumerate(seed)
+        if seed_len != length(s)
+            extra_letters[i] = seed_len - length(s)
+            seed[i] = "_"^extra_letters[i] * s
+        end
+    end
+    original_output_length = output_length
+    output_length += maximum(extra_letters)
+
+    st = Lux.testmode(st)
+
+    x = zeros(Int, output_length, length(seed))
+    for (i, s) in enumerate(seed), j in 1:seed_len
+        x[j, i] = findfirst(==(s[j]), alphabet)
+    end
+    for i in (seed_len + 1):output_length
+        tail = x[max(1, i - sequence_length + 1):(i - 1), :] |> dev
+        y = model(tail, ps, st)[1] |> cpu_device()
+        # Sample the next token independently for each sequence in the batch.
+        for j in axes(y, 3)
+            p = softmax(y[:, end, j])
+            x[i, j] = sample(1:length(alphabet), Weights(p))
+        end
+    end
+
+    res = [String(map(Base.Fix1(getindex, alphabet), x[:, i])) for i in axes(x, 2)]
+    for i in eachindex(res)
+        res[i] = res[i][(extra_letters[i] + 1):end][1:original_output_length]
+    end
+
+    return res
+end
 
 # Load data from input file, and partition into training and testing subsets.
 function get_nanogpt_data(; sequence_length, test_split)
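
For context, a minimal sketch of driving the new `generate_text` directly from a REPL; the tiny hyperparameters and toy alphabet here are illustrative only, not the script's defaults:

    # Illustrative toy setup: 4-symbol alphabet, untrained model, CPU execution.
    alphabet = ['_', 'a', 'b', 'c']
    model = GPT(; n_vocab=length(alphabet), n_embed=8, sequence_length=16, n_hidden=32,
        n_layers=1, dropout_rate=0.0f0, n_heads=2, qk_dim=4, v_dim=4)
    ps, st = Lux.setup(Random.default_rng(), model)
    texts = generate_text(model, ps, st, ["ab", "c"];
        alphabet, output_length=32, sequence_length=16)

Shorter seeds are left-padded with '_' so the batch is rectangular, and the padding is stripped again before the strings are returned; each entry of `texts` is one 32-character sample.
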
@@ -121,32 +139,62 @@ function get_nanogpt_data(; sequence_length, test_split)
     return alphabet, Array(trainX), Array(trainY), Array(testX), Array(testY)
 end
 
-@main function train_nanogpt(;
+@main function main(;
     n_embed::Int=64, n_hidden::Int=256, n_heads::Int=4, qk_dim::Int=16,
     v_dim::Int=16, n_layers::Int=6, sequence_length::Int=64, batchsize::Int=128,
     dropout_rate::Float32=0.0f0, test_split::Float64=0.1, lr::Float64=1e-2,
-    epochs::Int=20
+    epochs::Int=100,
+    # Inference-only options
+    inference::Bool=false, model_path::String="",
+    seed::Union{String, Vector{String}}=["_", "The", "Julia", "Lux.jl"],
+    output_length::Int=1024
 )
-    alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
-
-    @printf "[Info] Alphabet size: %d\n" length(alphabet)
-    @printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
-    @printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
-
     rng = Random.default_rng()
     Random.seed!(rng, 1234)
 
     dev = reactant_device()
     cdev = cpu_device()
 
+    if inference
+        @printf "[Info] Inference mode enabled.\n"
+
+        @assert !isempty(model_path) "Please provide a path to a model checkpoint."
+
+        @printf "[Info] Loading model from %s.\n" model_path
+        model_config = JLD2.load(model_path, "model_config")
+        model = GPT(; model_config...)
+        ps = JLD2.load(model_path, "parameters")
+        st = JLD2.load(model_path, "states")
+        alphabet = JLD2.load(model_path, "alphabet")
+        sequence_length = model_config.sequence_length
+
+        texts = generate_text(
+            model, ps, st, seed; alphabet, output_length, sequence_length
+        )
+
+        for (i, (text, s)) in enumerate(zip(texts, seed))
+            @printf "[Info] Seed [%d]: %s\n" i s
+            @printf "[Generated Text] %s\n\n" text
+        end
+
+        return
+    end
+
+    alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
+
+    @printf "[Info] Alphabet size: %d\n" length(alphabet)
+    @printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
+    @printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
+
     train_loader = DataLoader(
         (trainX, trainY); batchsize, shuffle=true, parallel=true
     ) |> dev
 
-    model = GPT(;
+    model_config = (;
         n_vocab=length(alphabet), n_embed, sequence_length, n_hidden,
         n_layers, dropout_rate, n_heads, qk_dim, v_dim
     )
+    model = GPT(; model_config...)
     ps, st = Lux.setup(rng, model) |> dev
     @printf "[Info] Number of parameters: %d\n" Lux.parameterlength(ps)
     @printf "[Info] Number of states: %d\n\n" Lux.statelength(st)
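
The inference branch above reads back exactly what the `jldsave` call at the bottom of this diff writes. A quick way to inspect a checkpoint from the REPL, assuming the default save location:

    using JLD2
    ckpt = JLD2.load(joinpath(@__DIR__, "nanogpt.jld2"))  # load everything into a Dict
    keys(ckpt)  # => "parameters", "states", "alphabet", "model_config"
    model = GPT(; ckpt["model_config"]...)  # rebuild the trained architecture
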
@@ -156,9 +204,12 @@
 
     @printf "[Info] Compiling Inference Model...\n"
     testX, testY = (testX, testY) |> dev
+    start_time = time()
     model_compiled = @compile model(testX, ps, Lux.testmode(st))
+    time_to_compile = time() - start_time
     best_test_loss = Inf
 
+    @printf "[Info] Time taken to compile inference model: %0.5fs\n" time_to_compile
    @printf "[Info] Starting Model Training...\n\n"
 
     loss_fn = CrossEntropyLoss(; logits=Val(true))
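
Reactant's `@compile` traces `model` once at the shapes of the given arguments and returns a compiled callable with the same calling convention, so the timer above captures the one-off XLA compilation cost rather than execution time. A sketch of how the compiled function is then used:

    # Same signature as the eager call; returns (logits, updated_states).
    y_test, _ = model_compiled(testX, ps, Lux.testmode(st))

This shape specialization is also why `generate_text` rejects a `ReactantDevice`: its sampling loop hands the model a context that grows at every step.
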
@@ -185,7 +236,15 @@
         )
         @printf "[Test] Epoch %3d\tTest Loss %.8e\n" epoch test_loss
 
-        # XXX: Also generate some text here...
+        # Generate some sample text to monitor training progress.
+        texts = generate_text(
+            model, ps |> cdev, st |> cdev, seed;
+            alphabet, output_length, sequence_length
+        )
+        for (i, (text, s)) in enumerate(zip(texts, seed))
+            @printf "[Info] Seed [%d]: %s\n" i s
+            @printf "[Generated Text] %s\n\n" text
+        end
 
         if test_loss < best_test_loss
             best_test_loss = test_loss
@@ -195,26 +254,9 @@
                 joinpath(@__DIR__, "nanogpt.jld2");
                 parameters=train_state.parameters |> cdev,
                 states=train_state.states |> cdev,
-                alphabet=alphabet
+                alphabet=alphabet,
+                model_config=model_config
             )
         end
     end
 end
-
-# # Load a model from a checkpoint (see `jldsave` above).
-# function load_model(filename)
-#     args = JLD2.load(filename, "args")
-#     alphabet = JLD2.load(filename, "alphabet")
-#     model = GPT(args, alphabet)
-#     model_state = JLD2.load(filename, "model_state")
-#     model = Flux.loadmodel!(model, model_state);
-#     return args, model
-# end
-
-# if true
-#     args, model = train()
-# else
-#     args, model = load_model("model-checkpoint.jld2") |> device
-# end
-
-# generate(model, "The", 50)
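
Because `@main` also defines `main` as an ordinary Julia function, the script can be driven from the REPL as well as the shell. A sketch, assuming a checkpoint from an earlier training run exists at the default save path:

    # Train with the defaults, then sample from the saved checkpoint.
    main()
    main(; inference=true, model_path=joinpath(@__DIR__, "nanogpt.jld2"),
        seed=["The"], output_length=256)
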