@@ -207,6 +207,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
207207        ] =  {}
208208        self ._expr_args : dict [sympy .Expr , SymbolArgument ] =  {}
209209        self ._constexpr_args : dict [str , ConstExprArg ] =  {}
210+         self ._constexpr_host_defs : set [str ] =  set ()
210211        self ._tensor_properties : dict [
211212            tuple [type [TensorPropertyArg ], torch .Tensor , int ], TensorPropertyArg 
212213        ] =  {}
@@ -282,11 +283,7 @@ def block_size_var(self, block_id: int) -> str | None:
282283
283284            var_name  =  self .new_var (f"_BLOCK_SIZE_{ block_id }  )
284285            self .block_size_var_cache [key ] =  var_name 
285-             host_expr  =  HostFunction .current ().literal_expr (block_value )
286-             if  self .constexpr_arg (var_name , host_expr ):
287-                 self .codegen .host_statements .append (
288-                     statement_from_string (f"{ var_name } { host_expr }  )
289-                 )
286+             self .constexpr_arg_with_host_def (var_name , block_value )
290287
291288        return  self .block_size_var_cache [key ]
292289
@@ -484,14 +481,55 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484481            self ._expr_args [sym ] =  arg 
485482        return  self ._expr_args [sym ]
486483
487-     def  constexpr_arg (self , name : str , host_str :  str  |  None  =  None ) ->  bool :
484+     def  constexpr_arg (self , name : str , value :  object  |  None  =  None ) ->  bool :
488485        """Create a constexpr argument, returns True if created, False if already exists.""" 
489486        if  name  in  self ._constexpr_args :
490487            return  False 
491-         self ._constexpr_args [name ] =  rv  =  ConstExprArg (name , host_str  or  name )
488+         host_str  =  name  if  value  is  None  else  self ._format_constexpr_value (value )
489+         self ._constexpr_args [name ] =  rv  =  ConstExprArg (name , host_str )
492490        self .arguments .append (rv )
493491        return  True 
494492
493+     def  constexpr_arg_with_host_def (self , name : str , value : object ) ->  None :
494+         """Create a constexpr argument and add its host-side definition if needed.""" 
495+         created  =  self .constexpr_arg (name , value )
496+         host_expr  =  self ._constexpr_args [name ].host_str ()
497+         if  created  or  name  not  in self ._constexpr_host_defs :
498+             self .codegen .host_statements .append (
499+                 statement_from_string (f"{ name } { host_expr }  )
500+             )
501+         self ._constexpr_host_defs .add (name )
502+ 
503+     def  _format_constexpr_value (self , value : object ) ->  str :
504+         if  isinstance (value , str ):
505+             return  value 
506+         if  isinstance (value , (int , float , bool )):
507+             return  repr (value )
508+ 
509+         # Extract sympy expression from torch symbolic types 
510+         if  isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
511+             value  =  value ._sympy_ ()
512+ 
513+         # Handle sympy expressions (sanitize by replacing triton_helpers functions) 
514+         if  isinstance (value , sympy .Expr ):
515+             expr  =  cast (
516+                 "sympy.Expr" ,
517+                 value .replace (
518+                     lambda  node : isinstance (node , sympy .Function )
519+                     and  getattr (node .func , "__name__" , "" )
520+                     ==  "triton_helpers.div_floor_integer" ,
521+                     lambda  node : sympy .floor (node .args [0 ] /  node .args [1 ]),  # pyright: ignore[reportAttributeAccessIssue] 
522+                 ).replace (
523+                     lambda  node : isinstance (node , sympy .Function )
524+                     and  getattr (node .func , "__name__" , "" )
525+                     ==  "triton_helpers.remainder_integer" ,
526+                     lambda  node : sympy .Mod (node .args [0 ], node .args [1 ]),  # pyright: ignore[reportAttributeAccessIssue] 
527+                 ),
528+             )
529+             return  HostFunction .current ().sympy_expr (expr )
530+ 
531+         return  HostFunction .current ().literal_expr (value )
532+ 
495533    def  _tensor_property (
496534        self ,
497535        prop_cls : type [_P ],
@@ -556,7 +594,12 @@ def codegen_function_def(self) -> list[ast.stmt]:
556594        ]
557595
558596    def  codegen_function_call (self ) ->  ast .AST :
559-         args  =  [arg .host_str () for  arg  in  self .sorted_args ()]
597+         args  =  []
598+         for  arg  in  self .sorted_args ():
599+             if  isinstance (arg , ConstExprArg ) and  arg .name  in  self ._constexpr_host_defs :
600+                 args .append (arg .name )
601+             else :
602+                 args .append (arg .host_str ())
560603
561604        if  self .has_rng_ops ():
562605            # Pass the host-side seed buffer variable to the kernel 
0 commit comments