@@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
363363 module .forward (tu .rand (3 , 4 ))
364364
365365# ==============================================================================
366+
367+
366368class ElementwiseMulScalarModule (torch .nn .Module ):
367369 def __init__ (self ):
368370 super ().__init__ ()
@@ -378,7 +380,52 @@ def forward(self, x):
378380@register_test_case (module_factory = lambda : ElementwiseMulScalarModule ())
379381def ElementwiseMulScalarModule_basic (module , tu : TestUtils ):
380382 module .forward (tu .rand (3 , 4 ))
381-
383+
384+
385+
386+ class ElementwiseMulTensorFloatModule (torch .nn .Module ):
387+ def __init__ (self ):
388+ super ().__init__ ()
389+
390+ @export
391+ @annotate_args ([
392+ None ,
393+ ([- 1 ], torch .float32 , True ),
394+ ([- 1 ], torch .float64 , True ),
395+ ])
396+ def forward (self , a , b ):
397+ return torch .mul (a , b )
398+
399+
400+ @register_test_case (
401+ module_factory = lambda : ElementwiseMulTensorFloatModule ())
402+ def ElementwiseMulTensorFloatModule_basic (module , tu : TestUtils ):
403+ module .forward (
404+ tu .rand (4 ),
405+ tu .rand (4 ).type (torch .float64 ))
406+
407+ class ElementwiseMulTensorIntModule (torch .nn .Module ):
408+ def __init__ (self ):
409+ super ().__init__ ()
410+
411+ @export
412+ @annotate_args ([
413+ None ,
414+ ([- 1 ], torch .int32 , True ),
415+ ([- 1 ], torch .int64 , True ),
416+ ])
417+ def forward (self , a , b ):
418+ return torch .mul (a , b )
419+
420+
421+ @register_test_case (
422+ module_factory = lambda : ElementwiseMulTensorIntModule ())
423+ def ElementwiseMulTensorIntModule_basic (module , tu : TestUtils ):
424+ module .forward (
425+ torch .randint (10 , [4 ]).type (torch .int32 ),
426+ torch .randint (10 , [4 ]))
427+
428+
382429# ==============================================================================
383430class ElementwiseLogModule (torch .nn .Module ):
384431 def __init__ (self ):
@@ -553,7 +600,32 @@ def forward(self, x):
553600def ElementwiseDivScalarModule_basic (module , tu : TestUtils ):
554601 module .forward (tu .rand (3 , 4 ))
555602
603+
604+ class ElementwiseDivTensorFloatModule (torch .nn .Module ):
605+ def __init__ (self ):
606+ super ().__init__ ()
607+
608+ @export
609+ @annotate_args ([
610+ None ,
611+ ([- 1 ], torch .float32 , True ),
612+ ([- 1 ], torch .float64 , True ),
613+ ])
614+ def forward (self , a , b ):
615+ return torch .div (a , b )
616+
617+
618+ @register_test_case (
619+ module_factory = lambda : ElementwiseDivTensorFloatModule ())
620+ def ElementwiseDivTensorFloatModule_basic (module , tu : TestUtils ):
621+ module .forward (
622+ tu .rand (4 ),
623+ tu .rand (4 ).type (torch .float64 ))
624+
625+
556626# ==============================================================================
627+
628+
557629class ElementwiseAndIntegerModule (torch .nn .Module ):
558630 def __init__ (self ):
559631 super ().__init__ ()
@@ -573,3 +645,5 @@ def forward(self, x, y):
573645def ElementwiseAndIntegerModule_basic (module , tu : TestUtils ):
574646 module .forward (torch .randint (- 10 , 10 , (3 , 4 )).to (torch .int32 ),
575647 torch .randint (- 10 , 10 , (3 , 4 )))
648+
649+
0 commit comments