14 | 14 | |
15 | 15 | import copy |
16 | 16 | |
17 | | -from tensorflow.experimental import dtensor |
18 | | -from tensorflow.experimental.dtensor import Layout |
19 | | -from tensorflow.keras.dtensor.experimental import LayoutMap |
20 | | - |
21 | 17 | from keras_nlp.api_export import keras_nlp_export |
22 | 18 | from keras_nlp.backend import keras |
23 | 19 | from keras_nlp.layers.modeling.position_embedding import PositionEmbedding |
@@ -191,71 +187,3 @@ def get_config(self): |
191 | 187 | @classproperty |
192 | 188 | def presets(cls): |
193 | 189 | return copy.deepcopy(backbone_presets) |
194 | | - |
195 | | - @classmethod |
196 | | - def create_layout_map(cls, mesh): |
197 | | - """Create a DTensor layout map for a GPT2Backbone. |
198 | | - |
199 | | - Given a DTensor mesh describing a list of devices, this method returns a |
200 | | - DTensor layout map for creating a `keras_nlp.models.GPT2Backbone` |
201 | | - instance. This mapping describes how to distribute all model weights |
202 | | - across multiple devices. For an overview of DTensor concepts, see |
203 | | - [this guide](https://www.tensorflow.org/guide/dtensor_overview). |
204 | | - |
205 | | - Args: |
206 | | - mesh: A 2D `tf.experimental.dtensor.Mesh` describing the arrangement |
207 | | - of devices for running distributed computation. The |
208 | | - first dimension in the mesh is expected to be for data parallel |
209 | | - distribution, and the second for model parallel distribution. |
210 | | - |
211 | | - Returns: |
212 | | - A `tf.keras.dtensor.experimental.LayoutMap` which contains the |
213 | | - proper layout to weights mapping for the model parallel setting. |
214 | | - |
215 | | - Examples: |
216 | | - ```python |
217 | | - keras.backend.experimental.enable_tf_random_generator() |
218 | | - keras.utils.set_random_seed(1337) |
219 | | - |
220 | | - # Update both dimensions below for a multi-device setting. |
221 | | - mesh = dtensor.create_mesh([("batch", 1), ("model", 1)]) |
222 | | - layout_map = keras_nlp.models.GPT2Backbone.create_layout_map(mesh) |
223 | | - |
224 | | - with layout_map.scope(): |
225 | | - model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en") |
226 | | - ``` |
227 | | - """ |
228 | | - # We assert the mesh is 2D, and assume the first mesh dim is for data |
229 | | - # parallel and the second dim is for model parallel. |
230 | | - mesh_shape = mesh.shape() |
231 | | - if len(mesh_shape) != 2: |
232 | | - raise ValueError( |
233 | | - f"Expect to create layout based on 2D mesh, received {mesh}" |
234 | | - ) |
235 | | - _, model_dim = mesh.dim_names |
236 | | - unshard_dim = dtensor.UNSHARDED |
237 | | - |
238 | | - layout_map = LayoutMap(mesh=mesh) |
239 | | - # Embedding sharding |
240 | | - layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh) |
241 | | - |
242 | | - # Transformer block sharding |
243 | | - layout_map[r".*_(query|key|value)_dense.kernel"] = Layout( |
244 | | - [unshard_dim, unshard_dim, model_dim], mesh |
245 | | - ) |
246 | | - layout_map[r".*_(query|key|value)_dense.bias"] = Layout( |
247 | | - [model_dim, unshard_dim], mesh |
248 | | - ) |
249 | | - layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout( |
250 | | - [unshard_dim, model_dim], mesh |
251 | | - ) |
252 | | - layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout( |
253 | | - [model_dim], mesh |
254 | | - ) |
255 | | - layout_map[r".*_feedforward_output_dense.kernel"] = Layout( |
256 | | - [model_dim, unshard_dim], mesh |
257 | | - ) |
258 | | - layout_map[r".*_feedforward_output_dense.bias"] = Layout( |
259 | | - [unshard_dim], mesh |
260 | | - ) |
261 | | - return layout_map |
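With the helper gone, a downstream user who still wants DTensor-based model parallelism can rebuild the same mapping in their own code. The sketch below re-creates the deleted sharding rules under a hypothetical `gpt2_layout_map` helper; it assumes TensorFlow's `tf.experimental.dtensor` and `tf.keras.dtensor.experimental.LayoutMap` APIs (the imports removed above) are still available in the installed TF/Keras version.

```python
# Minimal sketch, not provided by keras_nlp after this change: rebuild the
# deleted GPT2Backbone layout map in user code. Assumes TF's DTensor APIs
# are still importable.
from tensorflow.experimental import dtensor
from tensorflow.experimental.dtensor import Layout
from tensorflow.keras.dtensor.experimental import LayoutMap

import keras_nlp


def gpt2_layout_map(mesh):
    # Expect a 2D mesh: first dim for data parallel, second for model parallel.
    if len(mesh.shape()) != 2:
        raise ValueError(f"Expected a 2D mesh, received {mesh}")
    _, model_dim = mesh.dim_names
    unsharded = dtensor.UNSHARDED

    layout_map = LayoutMap(mesh=mesh)
    # Shard the embeddings over the model dimension.
    layout_map[r".*embeddings"] = Layout([unsharded, model_dim], mesh)
    # Shard the attention projections over the model dimension.
    layout_map[r".*_(query|key|value)_dense.kernel"] = Layout(
        [unsharded, unsharded, model_dim], mesh
    )
    layout_map[r".*_(query|key|value)_dense.bias"] = Layout(
        [model_dim, unsharded], mesh
    )
    # Shard the feedforward layers over the model dimension.
    layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout(
        [unsharded, model_dim], mesh
    )
    layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout(
        [model_dim], mesh
    )
    layout_map[r".*_feedforward_output_dense.kernel"] = Layout(
        [model_dim, unsharded], mesh
    )
    layout_map[r".*_feedforward_output_dense.bias"] = Layout(
        [unsharded], mesh
    )
    return layout_map


# Usage mirrors the removed docstring example; bump both mesh dimensions
# for a real multi-device setting.
mesh = dtensor.create_mesh([("batch", 1), ("model", 1)])
with gpt2_layout_map(mesh).scope():
    model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
```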