Skip to content

Commit 095967e

Browse files
Copilot and justinchuby committed
Simplify aten_bilinear implementation by leveraging N-dimensional MatMul support
Co-authored-by: justinchuby <[email protected]>
1 parent 22c4c0f commit 095967e

File tree

1 file changed

+38
-30
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+38
-30
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,51 +1205,59 @@ def aten_bilinear(
12051205
# bias shape: (out_features) - optional
12061206
# output shape: (..., out_features)
12071207

1208-
# Decompose bilinear into MatMul operations:
1209-
# 1. Create outer product of input1 and input2
1210-
# 2. Reshape to flatten feature dimensions
1211-
# 3. Use MatMul with reshaped weight
1212-
1213-
# Get shapes for reshaping
1214-
input1_shape = op.Shape(input1)
1215-
weight_shape = op.Shape(weight)
1208+
# Leveraging N-dimensional MatMul, we can compute this as:
1209+
# 1. weight @ input2.T -> [out_features, in1_features, ...batch_dims]
1210+
# 2. input1 @ result -> [...batch_dims, out_features]
12161211

12171212
# Get dimensions
1213+
weight_shape = op.Shape(weight)
12181214
out_features = op.Gather(weight_shape, 0, axis=0)
12191215
in1_features = op.Gather(weight_shape, 1, axis=0)
12201216
in2_features = op.Gather(weight_shape, 2, axis=0)
12211217

1222-
# Get batch dimensions (everything except the last dimension)
1223-
input1_rank = Rank(input1)
1224-
batch_dims = op.Slice(input1_shape, [0], [input1_rank - 1])
1225-
batch_size = op.ReduceProd(batch_dims, keepdims=False)
1218+
# Step 1: Reshape weight for matrix multiplication
1219+
# weight: [out_features, in1_features, in2_features] -> [out_features * in1_features, in2_features]
1220+
weight_2d = op.Reshape(weight, op.Concat([op.Mul(out_features, in1_features)], [in2_features], axis=0))
12261221

1227-
# Create outer product: input1[..., i] * input2[..., j] -> [..., i, j]
1228-
# Reshape inputs to [batch_size, features] for easier handling
1229-
input1_2d = op.Reshape(input1, op.Concat([batch_size], [in1_features], axis=0))
1222+
# Get input2 shape for transpose
1223+
input2_shape = op.Shape(input2)
1224+
input2_rank = Rank(input2)
1225+
batch_dims = op.Slice(input2_shape, [0], [input2_rank - 1])
1226+
1227+
# Reshape input2 to 2D: [...batch_dims, in2_features] -> [batch_size, in2_features]
1228+
batch_size = op.ReduceProd(batch_dims, keepdims=False)
12301229
input2_2d = op.Reshape(input2, op.Concat([batch_size], [in2_features], axis=0))
12311230

1232-
# Create outer product using unsqueeze and broadcasting
1233-
input1_expanded = op.Unsqueeze(input1_2d, axes=[2]) # [batch_size, in1_features, 1]
1234-
input2_expanded = op.Unsqueeze(input2_2d, axes=[1]) # [batch_size, 1, in2_features]
1231+
# Transpose input2_2d: [batch_size, in2_features] -> [in2_features, batch_size]
1232+
input2_t = op.Transpose(input2_2d, perm=[1, 0])
1233+
1234+
# First MatMul: weight_2d @ input2_t
1235+
# [out_features * in1_features, in2_features] @ [in2_features, batch_size]
1236+
# -> [out_features * in1_features, batch_size]
1237+
temp = op.MatMul(weight_2d, input2_t)
12351238

1236-
# Outer product via broadcasting multiplication
1237-
outer_product = op.Mul(input1_expanded, input2_expanded) # [batch_size, in1_features, in2_features]
1239+
# Reshape temp: [out_features * in1_features, batch_size] -> [out_features, in1_features, batch_size]
1240+
temp = op.Reshape(temp, op.Concat([out_features], [in1_features], [batch_size], axis=0))
12381241

1239-
# Flatten the feature dimensions
1240-
features_total = op.Mul(in1_features, in2_features)
1241-
outer_flat = op.Reshape(outer_product, op.Concat([batch_size], [features_total], axis=0))
1242+
# Transpose temp for second matmul: [out_features, in1_features, batch_size] -> [batch_size, in1_features, out_features]
1243+
temp_t = op.Transpose(temp, perm=[2, 1, 0])
1244+
1245+
# Step 2: Prepare input1 for second MatMul
1246+
# Reshape input1 to 2D: [...batch_dims, in1_features] -> [batch_size, in1_features]
1247+
input1_2d = op.Reshape(input1, op.Concat([batch_size], [in1_features], axis=0))
12421248

1243-
# Reshape weight to 2D: [out_features, in1_features * in2_features]
1244-
weight_2d = op.Reshape(weight, op.Concat([out_features], [features_total], axis=0))
1249+
# Expand input1 for batch matrix multiplication: [batch_size, in1_features] -> [batch_size, 1, in1_features]
1250+
input1_expanded = op.Unsqueeze(input1_2d, axes=[1])
12451251

1246-
# Transpose weight for MatMul: [in1_features * in2_features, out_features]
1247-
weight_t = op.Transpose(weight_2d, perm=[1, 0])
1252+
# Second MatMul: input1_expanded @ temp_t
1253+
# [batch_size, 1, in1_features] @ [batch_size, in1_features, out_features]
1254+
# -> [batch_size, 1, out_features]
1255+
result = op.MatMul(input1_expanded, temp_t)
12481256

1249-
# Matrix multiplication: [batch_size, out_features]
1250-
result = op.MatMul(outer_flat, weight_t)
1257+
# Remove singleton dimension: [batch_size, 1, out_features] -> [batch_size, out_features]
1258+
result = op.Squeeze(result, axes=[1])
12511259

1252-
# Reshape back to original batch dimensions + out_features
1260+
# Reshape back to original batch dimensions
12531261
output_shape = op.Concat(batch_dims, [out_features], axis=0)
12541262
result = op.Reshape(result, output_shape)
12551263

0 commit comments

Comments (0)