@@ -1271,6 +1271,19 @@ def _where_input_wrangler(
12711271 ),
12721272 TorchLibOpInfo ("polar" , core_ops .aten_polar ),
12731273 TorchLibOpInfo ("pow" , core_ops .aten_pow ),
1274+ TorchLibOpInfo ("prod" , core_ops .aten_prod ).skip (
1275+ matcher = lambda sample : sample .kwargs .get ("dim" ) is not None
1276+ or sample .kwargs .get ("keepdim" ) is not None
1277+ or sample .kwargs .get ("dtype" ) != - 1 ,
1278+ reason = "this Aten overload only accept 1 inputs: self" ,
1279+ ),
1280+ TorchLibOpInfo ("prod_dim_int" , core_ops .aten_prod_dim_int ).skip (
1281+ matcher = lambda sample : (
1282+ sample .kwargs .get ("dim" ) is None and sample .kwargs .get ("keepdim" ) is None
1283+ )
1284+ or sample .kwargs .get ("dtype" ) != - 1 ,
1285+ reason = "this Aten overload can accept 3 inputs:(self, dim, keepdim)" ,
1286+ ),
12741287 TorchLibOpInfo ("nn.functional.prelu" , core_ops .aten_prelu ),
12751288 TorchLibOpInfo ("ops.aten.rand" , core_ops .aten_rand , nondeterministic = True ),
12761289 TorchLibOpInfo ("ops.aten.rand_like" , core_ops .aten_rand_like , nondeterministic = True ),
@@ -2203,6 +2216,7 @@ def _where_input_wrangler(
22032216 OPS_DB , "ops.aten._log_softmax" , ("ops.aten._log_softmax_half" ,)
22042217)
22052218ops_test_common .duplicate_opinfo (OPS_DB , "ops.aten._softmax" , ("ops.aten._softmax_half" ,))
2219+ ops_test_common .duplicate_opinfo (OPS_DB , "prod" , ("prod_dim_int" ,))
22062220ops_test_common .duplicate_opinfo (OPS_DB , "round" , ("round_decimals" ,))
22072221ops_test_common .duplicate_opinfo (OPS_DB , "squeeze" , ("squeeze_dim" ,))
22082222ops_test_common .duplicate_opinfo (OPS_DB , "view_as_complex" , ("view_as_complex_copy" ,))
0 commit comments