diff --git a/docs/JAX FP8 matmul tutorial.ipynb b/docs/JAX FP8 matmul tutorial.ipynb index e335079..f816f82 100644 --- a/docs/JAX FP8 matmul tutorial.ipynb +++ b/docs/JAX FP8 matmul tutorial.ipynb @@ -34,11 +34,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "51775bad-18ad-49b7-9371-930b3704a294", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Notebook JAX version: 0.4.31\n", + "Notebook JAX device: cuda:0\n" + ] + } + ], + "source": [ + "import jax\n", + "\n", + "print(f\"Notebook JAX version: {jax.__version__}\")\n", + "print(f\"Notebook JAX device: {jax.devices()[0]}\")" + ] }, { "cell_type": "markdown", @@ -52,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "fb62c752-f7ba-4714-8605-88e2afcff88f", "metadata": {}, "outputs": [ @@ -110,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "9be90f27-5520-45f6-a42d-b309572e6e91", "metadata": {}, "outputs": [ @@ -118,7 +132,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-10-01 14:39:16.245591: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" + "2024-10-02 08:31:01.744162: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.5.82). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" ] }, { @@ -159,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "7edfa758-bf4e-49fa-8c5d-5dc9c0c2c346", "metadata": {}, "outputs": [ @@ -206,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "72d805ea-89b6-457d-9558-ff31fdd23d35", "metadata": {}, "outputs": [ @@ -280,7 +294,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "1ed9d08e-b18a-4fe7-bcba-72b95ddf6e68", "metadata": {}, "outputs": [ @@ -325,7 +339,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "b9a608d7-6cf8-457b-8275-bdcacc9b06fe", "metadata": {}, "outputs": [ @@ -333,9 +347,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"230c40ffa1e1e3ba7f06e4a65ac9e2bd\"}\n", + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"880fbc3fe38d16fac872dc7542132e26\"}\n", "\n", - "ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", + "ENTRY %main.25 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", " %constant_1 = f32[] constant(1)\n", " %Arg_4.5.0 = f32[] parameter(4)\n", " %Arg_3.4.0 = f32[] parameter(3)\n", @@ -352,22 +366,24 @@ ], "source": [ "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "# \"Dequantization\" datatype (note: required to be BF16!)\n", + "dqt_dtype = jnp.bfloat16\n", "\n", "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", " # Dequantize x and y\n", - " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", - " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)\n", + " b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)\n", " \n", " # Do the matmul (NOTE: adding transpose to reduce on last axis).\n", - " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " d_dqt = jax.lax.dot(a_dqt, b_dqt.T)\n", " \n", " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", - " d_fp32 = d_fp32 * d_scale\n", + " d_dqt = d_dqt * d_scale.astype(dqt_dtype)\n", " # NOTE: clamping is NOT optional for proper pattern matching!\n", - " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))\n", " # (Re)Quantize the scaled matmul output.\n", - " return d_fp32.astype(jnp.float8_e4m3fn)\n", + " return d_dqt.astype(jnp.float8_e4m3fn)\n", "\n", "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", @@ -387,7 +403,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "id": "44f28bbb-d4c6-4170-a736-76d667d73f97", "metadata": {}, "outputs": [ @@ -395,9 +411,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"f1fb5db9dad54941d7d17e04fdbe9515\"}\n", + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"ba54f58f7ec56c7beda9299cd16bb7b2\"}\n", "\n", - "ENTRY %main.28 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", + "ENTRY %main.31 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", " %constant_1_0 = f32[] constant(1)\n", " %Arg_4.5.0 = f32[] parameter(4)\n", " %Arg_3.4.0 = f32[] parameter(3)\n", @@ -414,24 +430,26 @@ ], "source": [ "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "# \"Dequantization\" datatype (note: required to be BF16!)\n", + "dqt_dtype = jnp.bfloat16\n", "\n", "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", " # Dequantize x and y\n", - " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", - " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)\n", + " b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)\n", " \n", " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", - " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " d_dqt = jax.lax.dot(a_dqt, b_dqt.T)\n", " # ReLU non-linearity. Note: applied before scaling.\n", - " d_fp32 = jax.nn.relu(d_fp32)\n", + " d_dqt = jax.nn.relu(d_dqt)\n", " \n", " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", - " d_fp32 = d_fp32 * d_scale\n", + " d_dqt = d_dqt * d_scale.astype(dqt_dtype)\n", " # NOTE: clamping is NOT optional for proper pattern matching!\n", - " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))\n", " # (Re)Quantize the scaled matmul output.\n", - " return d_fp32.astype(jnp.float8_e4m3fn)\n", + " return d_dqt.astype(jnp.float8_e4m3fn)\n", "\n", "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", @@ -449,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "id": "2ca21eae-8b0c-454b-b670-1ef0d5935a5c", "metadata": {}, "outputs": [ @@ -504,7 +522,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 16, "id": "a65cf3be-c465-49ae-9e90-2ada54dba84a", "metadata": {}, "outputs": [ @@ -512,9 +530,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->(f8e4m3fn[32,128]{1,0}, f32[])}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs=\"5d38b8087de7ebb664888f640beb2017\"}\n", + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->(f8e4m3fn[32,128]{1,0}, f32[])}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs=\"206494040898ad9e7c872e73f922a9e5\"}\n", "\n", - "ENTRY %main.36 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[32,128], f32[]) {\n", + "ENTRY %main.40 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> (f8e4m3fn[32,128], f32[]) {\n", " %constant_1_0 = f32[] constant(1)\n", " %Arg_4.5.0 = f32[] parameter(4)\n", " %Arg_3.4.0 = f32[] parameter(3)\n", @@ -524,7 +542,7 @@ " %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n", " %get-tuple-element.1.0 = f32[] get-tuple-element((f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=1\n", " %get-tuple-element.4 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, f32[], s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=0\n", - " ROOT %tuple.35.0 = (f8e4m3fn[32,128]{1,0}, f32[]) tuple(f8e4m3fn[32,128]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.1.0)\n", + " ROOT %tuple.39.0 = (f8e4m3fn[32,128]{1,0}, f32[]) tuple(f8e4m3fn[32,128]{1,0} %get-tuple-element.4, f32[] %get-tuple-element.1.0)\n", "}\n", "\n", "\n" @@ -533,26 +551,28 @@ ], "source": [ "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "# \"Dequantization\" datatype (note: required to be BF16!)\n", + "dqt_dtype = jnp.bfloat16\n", "\n", "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", " # Dequantize x and y\n", - " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", - " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " a_dqt = a_fp8.astype(dqt_dtype) * a_scale.astype(dqt_dtype)\n", + " b_dqt = b_fp8.astype(dqt_dtype) * b_scale.astype(dqt_dtype)\n", " \n", " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", - " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " d_dqt = jax.lax.dot(a_dqt, b_dqt.T)\n", " # ReLU non-linearity. Note: needs to be before the scaling.\n", - " d_fp32 = jax.nn.relu(d_fp32)\n", + " d_dqt = jax.nn.relu(d_dqt)\n", " # Delayed rescaling: capture the raw output scaling for latter.\n", - " out_scale = jnp.max(jnp.abs(d_fp32))\n", + " out_scale = jnp.max(jnp.abs(d_dqt)).astype(jnp.float32)\n", "\n", " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", - " d_fp32 = d_fp32 * d_scale\n", + " d_dqt = d_dqt * d_scale.astype(dqt_dtype)\n", " # NOTE: clamping is NOT optional for proper pattern matching!\n", - " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " d_dqt = jax.lax.clamp(dqt_dtype(-e4m3_max), d_dqt, dqt_dtype(e4m3_max))\n", " # (Re)Quantize the scaled matmul output.\n", - " return d_fp32.astype(jnp.float8_e4m3fn), out_scale\n", + " return d_dqt.astype(jnp.float8_e4m3fn), out_scale\n", "\n", "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", @@ -570,7 +590,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "id": "20d4d088-6563-44c2-86a1-ab2c34fe4e8e", "metadata": {}, "outputs": [