|
| 1 | +ENV["XLA_REACTANT_GPU_MEM_FRACTION"] = "0.98" |
| 2 | + |
1 | 3 | using ConcreteStructs, |
2 | 4 | MLUtils, |
3 | 5 | Lux, |
@@ -120,25 +122,35 @@ function (model::GPT2)(x, ps, st) |
120 | 122 | return outputs, (; tok_emb=st_tok_emb, pos_emb=st_pos_emb, gpt_blocks=st_gpt_blocks) |
121 | 123 | end |
122 | 124 |
|
| 125 | +#= |
123 | 126 | dev = reactant_device(; force=true) |
124 | 127 | rng = Random.default_rng() |
125 | 128 |
|
126 | 129 | model = GPT2(; |
127 | 130 | n_vocab=50304, |
128 | | - embed_dim=768, |
| 131 | + embed_dim=1024, |
129 | 132 | hidden_dim=3072, |
130 | 133 | block_size=1024, |
131 | 134 | n_layers=3, |
132 | 135 | dropout_rate=0.0, |
133 | | - num_heads=12, |
| 136 | + num_heads=16, |
134 | 137 | ) |
135 | 138 | ps, st = Lux.setup(rng, model) |> dev; |
136 | 139 |
|
137 | | -x = rand(1:50304, 1024, 32) |> dev; |
| 140 | +x = rand(1:50304, 48, 32) |> dev; |
138 | 141 |
|
139 | 142 | @code_hlo model(x, ps, st) |
140 | 143 |
|
| 144 | +sumabs2first(layer, x, ps, st) = sum(abs2, first(layer(x, ps, st))) |
| 145 | +
|
| 146 | +@code_hlo Enzyme.gradient(Reverse, sumabs2first, Const(model), x, ps, Const(st)) |
| 147 | +=# |
| 148 | + |
141 | 149 | # Use the model to generate some text. |
| 150 | +# function weighted_sample(items::AbstractVector, weights::AbstractVector) |
| 151 | + |
| 152 | +# end |
| 153 | + |
142 | 154 | function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length) |
143 | 155 | dev = get_device((ps, st)) |
144 | 156 | @assert !(dev isa ReactantDevice) "Currently we don't support running inference of \ |
|
0 commit comments