@@ -1205,51 +1205,59 @@ def aten_bilinear(
     # bias shape: (out_features) - optional
     # output shape: (..., out_features)

-    # Decompose bilinear into MatMul operations:
-    # 1. Create outer product of input1 and input2
-    # 2. Reshape to flatten feature dimensions
-    # 3. Use MatMul with reshaped weight
-
-    # Get shapes for reshaping
-    input1_shape = op.Shape(input1)
-    weight_shape = op.Shape(weight)
+    # Leveraging N-dimensional MatMul, we can compute this as:
+    # 1. weight @ input2.T -> [out_features, in1_features, ...batch_dims]
+    # 2. input1 @ result -> [...batch_dims, out_features]

     # Get dimensions
+    weight_shape = op.Shape(weight)
     out_features = op.Gather(weight_shape, 0, axis=0)
     in1_features = op.Gather(weight_shape, 1, axis=0)
     in2_features = op.Gather(weight_shape, 2, axis=0)

-    # Get batch dimensions (everything except the last dimension)
-    input1_rank = Rank(input1)
-    batch_dims = op.Slice(input1_shape, [0], [input1_rank - 1])
-    batch_size = op.ReduceProd(batch_dims, keepdims=False)
+    # Step 1: Reshape weight for matrix multiplication
+    # weight: [out_features, in1_features, in2_features] -> [out_features * in1_features, in2_features]
+    weight_2d = op.Reshape(weight, op.Concat([op.Mul(out_features, in1_features)], [in2_features], axis=0))

-    # Create outer product: input1[..., i] * input2[..., j] -> [..., i, j]
-    # Reshape inputs to [batch_size, features] for easier handling
-    input1_2d = op.Reshape(input1, op.Concat([batch_size], [in1_features], axis=0))
+    # Get input2 shape for transpose
+    input2_shape = op.Shape(input2)
+    input2_rank = Rank(input2)
+    batch_dims = op.Slice(input2_shape, [0], [input2_rank - 1])
+
+    # Reshape input2 to 2D: [...batch_dims, in2_features] -> [batch_size, in2_features]
+    batch_size = op.ReduceProd(batch_dims, keepdims=False)
     input2_2d = op.Reshape(input2, op.Concat([batch_size], [in2_features], axis=0))

-    # Create outer product using unsqueeze and broadcasting
-    input1_expanded = op.Unsqueeze(input1_2d, axes=[2])  # [batch_size, in1_features, 1]
-    input2_expanded = op.Unsqueeze(input2_2d, axes=[1])  # [batch_size, 1, in2_features]
+    # Transpose input2_2d: [batch_size, in2_features] -> [in2_features, batch_size]
+    input2_t = op.Transpose(input2_2d, perm=[1, 0])
+
+    # First MatMul: weight_2d @ input2_t
+    # [out_features * in1_features, in2_features] @ [in2_features, batch_size]
+    # -> [out_features * in1_features, batch_size]
+    temp = op.MatMul(weight_2d, input2_t)

-    # Outer product via broadcasting multiplication
-    outer_product = op.Mul(input1_expanded, input2_expanded)  # [batch_size, in1_features, in2_features]
+    # Reshape temp: [out_features * in1_features, batch_size] -> [out_features, in1_features, batch_size]
+    temp = op.Reshape(temp, op.Concat([out_features], [in1_features], [batch_size], axis=0))

-    # Flatten the feature dimensions
-    features_total = op.Mul(in1_features, in2_features)
-    outer_flat = op.Reshape(outer_product, op.Concat([batch_size], [features_total], axis=0))
+    # Transpose temp for second matmul: [out_features, in1_features, batch_size] -> [batch_size, in1_features, out_features]
+    temp_t = op.Transpose(temp, perm=[2, 1, 0])
+
+    # Step 2: Prepare input1 for second MatMul
+    # Reshape input1 to 2D: [...batch_dims, in1_features] -> [batch_size, in1_features]
+    input1_2d = op.Reshape(input1, op.Concat([batch_size], [in1_features], axis=0))

-    # Reshape weight to 2D: [out_features, in1_features * in2_features]
-    weight_2d = op.Reshape(weight, op.Concat([out_features], [features_total], axis=0))
+    # Expand input1 for batch matrix multiplication: [batch_size, in1_features] -> [batch_size, 1, in1_features]
+    input1_expanded = op.Unsqueeze(input1_2d, axes=[1])

-    # Transpose weight for MatMul: [in1_features * in2_features, out_features]
-    weight_t = op.Transpose(weight_2d, perm=[1, 0])
+    # Second MatMul: input1_expanded @ temp_t
+    # [batch_size, 1, in1_features] @ [batch_size, in1_features, out_features]
+    # -> [batch_size, 1, out_features]
+    result = op.MatMul(input1_expanded, temp_t)

-    # Matrix multiplication: [batch_size, out_features]
-    result = op.MatMul(outer_flat, weight_t)
+    # Remove singleton dimension: [batch_size, 1, out_features] -> [batch_size, out_features]
+    result = op.Squeeze(result, axes=[1])

-    # Reshape back to original batch dimensions + out_features
+    # Reshape back to original batch dimensions
     output_shape = op.Concat(batch_dims, [out_features], axis=0)
     result = op.Reshape(result, output_shape)

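To see why the two MatMul steps reproduce the bilinear form output[..., o] = sum_{i,j} input1[..., i] * weight[o, i, j] * input2[..., j], here is a minimal NumPy sketch (not part of the change; the shapes and variable names are illustrative) that mirrors the decomposition above and checks it against a direct einsum reference:

```python
import numpy as np

batch, in1, in2, out = 3, 4, 5, 6
x1 = np.random.randn(batch, in1)        # input1 flattened to [batch_size, in1_features]
x2 = np.random.randn(batch, in2)        # input2 flattened to [batch_size, in2_features]
w = np.random.randn(out, in1, in2)      # weight: [out_features, in1_features, in2_features]

# Reference: out[b, o] = sum_{i, j} x1[b, i] * w[o, i, j] * x2[b, j]
expected = np.einsum("bi,oij,bj->bo", x1, w, x2)

# Step 1: weight @ input2.T
w_2d = w.reshape(out * in1, in2)                # [out*in1, in2]
temp = w_2d @ x2.T                              # [out*in1, batch]
temp = temp.reshape(out, in1, batch)            # [out, in1, batch]
temp_t = temp.transpose(2, 1, 0)                # [batch, in1, out]

# Step 2: input1 @ temp
result = (x1[:, None, :] @ temp_t).squeeze(1)   # [batch, out]

assert np.allclose(result, expected)
```

Collapsing all leading batch dimensions into a single batch_size (via ReduceProd of the sliced shape) keeps both MatMuls low-rank regardless of how many batch dimensions the inputs carry; the final Reshape restores the original batch shape.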