19 changes: 18 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
@@ -1189,6 +1189,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
     return op.CastLike(sampled, self)
 
 
+@torch_op("aten::bilinear", trace_only=True)
 def aten_bilinear(
     input1: TensorType,
     input2: TensorType,
@@ -1197,7 +1198,23 @@ def aten_bilinear(
 ) -> TensorType:
     """bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
 
-    raise NotImplementedError()
+    # Bilinear transformation: y = x1^T A x2 + b
+    # input1 shape: (..., in1_features)
+    # input2 shape: (..., in2_features)
+    # weight shape: (out_features, in1_features, in2_features)
+    # bias shape: (out_features,) - optional
+    # output shape: (..., out_features)
+
+    # Use Einsum to compute the bilinear transformation:
+    # "...i,oij,...j->...o" contracts
+    # input1[..., i] * weight[o, i, j] * input2[..., j] into output[..., o].
+    result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
+
+    # Add the optional bias if provided
+    if bias is not None:
+        result = op.Add(result, bias)
+
+    return result
 
 
 def aten_binary_cross_entropy_with_logits(
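For reference, the einsum equation used here mirrors `torch.nn.functional.bilinear` directly. A minimal standalone check of that equivalence, with arbitrarily chosen illustrative shapes (independent of the ONNX code path above):

```python
import torch

# Arbitrary illustrative shapes: (batch, in1_features, in2_features, out_features).
batch, in1, in2, out = 2, 3, 4, 5
x1 = torch.randn(batch, in1)
x2 = torch.randn(batch, in2)
weight = torch.randn(out, in1, in2)
bias = torch.randn(out)

# Same contraction the ONNX Einsum node performs:
# output[..., o] = sum_{i,j} x1[..., i] * weight[o, i, j] * x2[..., j], plus bias.
expected = torch.nn.functional.bilinear(x1, x2, weight, bias)
actual = torch.einsum("...i,oij,...j->...o", x1, weight, x2) + bias
torch.testing.assert_close(actual, expected)
```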
38 changes: 38 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -37,6 +37,37 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs)
     yield opinfo_core.SampleInput(item, dtype=dtype)
 
 
+def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for the bilinear operation."""
+    del op_info
+    del kwargs
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    # Test cases: (batch_size, in1_features, in2_features, out_features)
+    cases = [
+        (2, 3, 4, 5),  # Basic case
+        (1, 2, 2, 1),  # Minimal case
+        (3, 5, 7, 4),  # Different dimensions
+        (2, 1, 1, 3),  # Single input features
+    ]
+
+    for batch_size, in1_features, in2_features, out_features in cases:
+        input1 = make_arg((batch_size, in1_features))
+        input2 = make_arg((batch_size, in2_features))
+        weight = make_arg((out_features, in1_features, in2_features))
+        bias = make_arg((out_features,))
+
+        # Test with bias
+        yield opinfo_core.SampleInput(input1, args=(input2, weight, bias))
+
+        # Also test without bias for the batch_size == 2 cases to limit samples
+        if batch_size == 2:
+            yield opinfo_core.SampleInput(input1, args=(input2, weight, None))
+
+
 def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
 
@@ -2180,6 +2211,13 @@ def __init__(self):
 # To avoid name duplication, it is possible to rename the OpInfo and specify
 # the `op` field explicitly.
 OP_DB: List[opinfo_core.OpInfo] = [
+    opinfo_core.OpInfo(
+        "bilinear",
+        op=torch.nn.functional.bilinear,
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs_bilinear,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.bernoulli.p",
         aten_name="bernoulli.p",
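The generated samples can be exercised directly against the eager op. A minimal sketch, assuming the standard `SampleInput` layout from `torch.testing._internal.opinfo.core` (`sample.input` is `input1`; `sample.args` holds `input2`, `weight`, and the optional `bias`):

```python
import torch

# sample_inputs_bilinear ignores op_info and kwargs, so None is fine here.
for sample in sample_inputs_bilinear(None, "cpu", torch.float32, False):
    result = torch.nn.functional.bilinear(sample.input, *sample.args)
    # out_features is the leading dimension of the weight argument (args[1]).
    assert result.shape == (sample.input.shape[0], sample.args[1].shape[0])
```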
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -657,6 +657,7 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
+    TorchLibOpInfo("bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (1e-4, 1e-4)}),
     TorchLibOpInfo(
         # This string is a unique ID. In extra_opinfo.py, we
         # also define test data for this ID with
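The `tolerance` override loosens the float32 comparison between the traced ONNX result and eager PyTorch, which is reasonable since `Einsum` may accumulate sums in a different order. Assuming the tuple means `(rtol, atol)`, as with the other entries in this table, the relaxed check is roughly:

```python
import torch

# Hypothetical stand-ins for the eager PyTorch and ONNX Runtime outputs.
torch_result = torch.tensor([1.0, 2.0])
onnx_result = torch_result + 5e-5  # discrepancy within the loosened bounds

# Passes with the overridden (rtol, atol); the stricter float32 defaults
# (rtol=1.3e-6, atol=1e-5) would reject this difference.
torch.testing.assert_close(onnx_result, torch_result, rtol=1e-4, atol=1e-4)
```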