11use  std:: iter:: once; 
22
3- use  ide_db:: { 
4-     syntax_helpers:: node_ext:: { is_pattern_cond,  single_let} , 
5-     ty_filter:: TryEnum , 
6- } ; 
3+ use  either:: Either ; 
4+ use  hir:: { Semantics ,  TypeInfo } ; 
5+ use  ide_db:: { RootDatabase ,  ty_filter:: TryEnum } ; 
76use  syntax:: { 
87    AstNode , 
9-     SyntaxKind :: { FN ,  FOR_EXPR ,  LOOP_EXPR ,  WHILE_EXPR ,  WHITESPACE } , 
10-     T , 
8+     SyntaxKind :: { CLOSURE_EXPR ,   FN ,  FOR_EXPR ,  LOOP_EXPR ,  WHILE_EXPR ,  WHITESPACE } , 
9+     SyntaxNode ,   T , 
1110    ast:: { 
1211        self , 
1312        edit:: { AstNodeEdit ,  IndentLevel } , 
@@ -44,12 +43,9 @@ use crate::{
4443// } 
4544// ``` 
4645pub ( crate )  fn  convert_to_guarded_return ( acc :  & mut  Assists ,  ctx :  & AssistContext < ' _ > )  -> Option < ( ) >  { 
47-     if  let  Some ( let_stmt)  = ctx. find_node_at_offset ( )  { 
48-         let_stmt_to_guarded_return ( let_stmt,  acc,  ctx) 
49-     }  else  if  let  Some ( if_expr)  = ctx. find_node_at_offset ( )  { 
50-         if_expr_to_guarded_return ( if_expr,  acc,  ctx) 
51-     }  else  { 
52-         None 
46+     match  ctx. find_node_at_offset :: < Either < ast:: LetStmt ,  ast:: IfExpr > > ( ) ? { 
47+         Either :: Left ( let_stmt)  => let_stmt_to_guarded_return ( let_stmt,  acc,  ctx) , 
48+         Either :: Right ( if_expr)  => if_expr_to_guarded_return ( if_expr,  acc,  ctx) , 
5349    } 
5450} 
5551
@@ -73,13 +69,7 @@ fn if_expr_to_guarded_return(
7369        return  None ; 
7470    } 
7571
76-     // Check if there is an IfLet that we can handle. 
77-     let  ( if_let_pat,  cond_expr)  = if  is_pattern_cond ( cond. clone ( ) )  { 
78-         let  let_ = single_let ( cond) ?; 
79-         ( Some ( let_. pat ( ) ?) ,  let_. expr ( ) ?) 
80-     }  else  { 
81-         ( None ,  cond) 
82-     } ; 
72+     let  let_chains = flat_let_chain ( cond) ; 
8373
8474    let  then_block = if_expr. then_branch ( ) ?; 
8575    let  then_block = then_block. stmt_list ( ) ?; 
@@ -106,11 +96,7 @@ fn if_expr_to_guarded_return(
10696
10797    let  parent_container = parent_block. syntax ( ) . parent ( ) ?; 
10898
109-     let  early_expression:  ast:: Expr  = match  parent_container. kind ( )  { 
110-         WHILE_EXPR  | LOOP_EXPR  | FOR_EXPR  => make:: expr_continue ( None ) , 
111-         FN  => make:: expr_return ( None ) , 
112-         _ => return  None , 
113-     } ; 
99+     let  early_expression:  ast:: Expr  = early_expression ( parent_container,  & ctx. sema ) ?; 
114100
115101    then_block. syntax ( ) . first_child_or_token ( ) . map ( |t| t. kind ( )  == T ! [ '{' ] ) ?; 
116102
@@ -132,32 +118,42 @@ fn if_expr_to_guarded_return(
132118        target, 
133119        |edit| { 
134120            let  if_indent_level = IndentLevel :: from_node ( if_expr. syntax ( ) ) ; 
135-             let  replacement = match  if_let_pat { 
136-                 None  => { 
137-                     // If. 
138-                     let  new_expr = { 
139-                         let  then_branch =
140-                             make:: block_expr ( once ( make:: expr_stmt ( early_expression) . into ( ) ) ,  None ) ; 
141-                         let  cond = invert_boolean_expression_legacy ( cond_expr) ; 
142-                         make:: expr_if ( cond,  then_branch,  None ) . indent ( if_indent_level) 
143-                     } ; 
144-                     new_expr. syntax ( ) . clone ( ) 
145-                 } 
146-                 Some ( pat)  => { 
121+             let  replacement = let_chains. into_iter ( ) . map ( |expr| { 
122+                 if  let  ast:: Expr :: LetExpr ( let_expr)  = & expr
123+                     && let  ( Some ( pat) ,  Some ( expr) )  = ( let_expr. pat ( ) ,  let_expr. expr ( ) ) 
124+                 { 
147125                    // If-let. 
148126                    let  let_else_stmt = make:: let_else_stmt ( 
149127                        pat, 
150128                        None , 
151-                         cond_expr , 
152-                         ast:: make:: tail_only_block_expr ( early_expression) , 
129+                         expr , 
130+                         ast:: make:: tail_only_block_expr ( early_expression. clone ( ) ) , 
153131                    ) ; 
154132                    let  let_else_stmt = let_else_stmt. indent ( if_indent_level) ; 
155133                    let_else_stmt. syntax ( ) . clone ( ) 
134+                 }  else  { 
135+                     // If. 
136+                     let  new_expr = { 
137+                         let  then_branch = make:: block_expr ( 
138+                             once ( make:: expr_stmt ( early_expression. clone ( ) ) . into ( ) ) , 
139+                             None , 
140+                         ) ; 
141+                         let  cond = invert_boolean_expression_legacy ( expr) ; 
142+                         make:: expr_if ( cond,  then_branch,  None ) . indent ( if_indent_level) 
143+                     } ; 
144+                     new_expr. syntax ( ) . clone ( ) 
156145                } 
157-             } ; 
146+             } ) ; 
158147
148+             let  newline = & format ! ( "\n {if_indent_level}" ) ; 
159149            let  then_statements = replacement
160-                 . children_with_tokens ( ) 
150+                 . enumerate ( ) 
151+                 . flat_map ( |( i,  node) | { 
152+                     ( i != 0 ) 
153+                         . then ( || make:: tokens:: whitespace ( newline) . into ( ) ) 
154+                         . into_iter ( ) 
155+                         . chain ( node. children_with_tokens ( ) ) 
156+                 } ) 
161157                . chain ( 
162158                    then_block_items
163159                        . syntax ( ) 
@@ -201,11 +197,7 @@ fn let_stmt_to_guarded_return(
201197            let_stmt. syntax ( ) . parent ( ) ?. ancestors ( ) . find_map ( ast:: BlockExpr :: cast) ?; 
202198        let  parent_container = parent_block. syntax ( ) . parent ( ) ?; 
203199
204-         match  parent_container. kind ( )  { 
205-             WHILE_EXPR  | LOOP_EXPR  | FOR_EXPR  => make:: expr_continue ( None ) , 
206-             FN  => make:: expr_return ( None ) , 
207-             _ => return  None , 
208-         } 
200+         early_expression ( parent_container,  & ctx. sema ) ?
209201    } ; 
210202
211203    acc. add ( 
@@ -232,6 +224,54 @@ fn let_stmt_to_guarded_return(
232224    ) 
233225} 
234226
227+ fn  early_expression ( 
228+     parent_container :  SyntaxNode , 
229+     sema :  & Semantics < ' _ ,  RootDatabase > , 
230+ )  -> Option < ast:: Expr >  { 
231+     let  return_none_expr = || { 
232+         let  none_expr = make:: expr_path ( make:: ext:: ident_path ( "None" ) ) ; 
233+         make:: expr_return ( Some ( none_expr) ) 
234+     } ; 
235+     if  let  Some ( fn_)  = ast:: Fn :: cast ( parent_container. clone ( ) ) 
236+         && let  Some ( fn_def)  = sema. to_def ( & fn_) 
237+         && let  Some ( TryEnum :: Option )  = TryEnum :: from_ty ( sema,  & fn_def. ret_type ( sema. db ) ) 
238+     { 
239+         return  Some ( return_none_expr ( ) ) ; 
240+     } 
241+     if  let  Some ( body)  = ast:: ClosureExpr :: cast ( parent_container. clone ( ) ) . and_then ( |it| it. body ( ) ) 
242+         && let  Some ( ret_ty)  = sema. type_of_expr ( & body) . map ( TypeInfo :: original) 
243+         && let  Some ( TryEnum :: Option )  = TryEnum :: from_ty ( sema,  & ret_ty) 
244+     { 
245+         return  Some ( return_none_expr ( ) ) ; 
246+     } 
247+ 
248+     Some ( match  parent_container. kind ( )  { 
249+         WHILE_EXPR  | LOOP_EXPR  | FOR_EXPR  => make:: expr_continue ( None ) , 
250+         FN  | CLOSURE_EXPR  => make:: expr_return ( None ) , 
251+         _ => return  None , 
252+     } ) 
253+ } 
254+ 
255+ fn  flat_let_chain ( mut  expr :  ast:: Expr )  -> Vec < ast:: Expr >  { 
256+     let  mut  chains = vec ! [ ] ; 
257+ 
258+     while  let  ast:: Expr :: BinExpr ( bin_expr)  = & expr
259+         && bin_expr. op_kind ( )  == Some ( ast:: BinaryOp :: LogicOp ( ast:: LogicOp :: And ) ) 
260+         && let  ( Some ( lhs) ,  Some ( rhs) )  = ( bin_expr. lhs ( ) ,  bin_expr. rhs ( ) ) 
261+     { 
262+         if  let  Some ( last)  = chains. pop_if ( |last| !matches ! ( last,  ast:: Expr :: LetExpr ( _) ) )  { 
263+             chains. push ( make:: expr_bin_op ( rhs,  ast:: BinaryOp :: LogicOp ( ast:: LogicOp :: And ) ,  last) ) ; 
264+         }  else  { 
265+             chains. push ( rhs) ; 
266+         } 
267+         expr = lhs; 
268+     } 
269+ 
270+     chains. push ( expr) ; 
271+     chains. reverse ( ) ; 
272+     chains
273+ } 
274+ 
235275#[ cfg( test) ]  
236276mod  tests { 
237277    use  crate :: tests:: { check_assist,  check_assist_not_applicable} ; 
@@ -268,6 +308,71 @@ fn main() {
268308        ) ; 
269309    } 
270310
311+     #[ test]  
312+     fn  convert_inside_fn_return_option ( )  { 
313+         check_assist ( 
314+             convert_to_guarded_return, 
315+             r#" 
316+ //- minicore: option 
317+ fn ret_option() -> Option<()> { 
318+     bar(); 
319+     if$0 true { 
320+         foo(); 
321+ 
322+         // comment 
323+         bar(); 
324+     } 
325+ } 
326+ "# , 
327+             r#" 
328+ fn ret_option() -> Option<()> { 
329+     bar(); 
330+     if false { 
331+         return None; 
332+     } 
333+     foo(); 
334+ 
335+     // comment 
336+     bar(); 
337+ } 
338+ "# , 
339+         ) ; 
340+     } 
341+ 
342+     #[ test]  
343+     fn  convert_inside_closure ( )  { 
344+         check_assist ( 
345+             convert_to_guarded_return, 
346+             r#" 
347+ fn main() { 
348+     let _f = || { 
349+         bar(); 
350+         if$0 true { 
351+             foo(); 
352+ 
353+             // comment 
354+             bar(); 
355+         } 
356+     } 
357+ } 
358+ "# , 
359+             r#" 
360+ fn main() { 
361+     let _f = || { 
362+         bar(); 
363+         if false { 
364+             return; 
365+         } 
366+         foo(); 
367+ 
368+         // comment 
369+         bar(); 
370+     } 
371+ } 
372+ "# , 
373+         ) ; 
374+     } 
375+ 
271376    #[ test]  
272377    fn  convert_let_inside_fn ( )  { 
273378        check_assist ( 
@@ -316,6 +421,82 @@ fn main() {
316421        ) ; 
317422    } 
318423
424+     #[ test]  
425+     fn  convert_if_let_result_inside_let ( )  { 
426+         check_assist ( 
427+             convert_to_guarded_return, 
428+             r#" 
429+ fn main() { 
430+     let _x = loop { 
431+         if$0 let Ok(x) = Err(92) { 
432+             foo(x); 
433+         } 
434+     }; 
435+ } 
436+ "# , 
437+             r#" 
438+ fn main() { 
439+     let _x = loop { 
440+         let Ok(x) = Err(92) else { continue }; 
441+         foo(x); 
442+     }; 
443+ } 
444+ "# , 
445+         ) ; 
446+     } 
447+ 
448+     #[ test]  
449+     fn  convert_if_let_chain_result ( )  { 
450+         check_assist ( 
451+             convert_to_guarded_return, 
452+             r#" 
453+ fn main() { 
454+     if$0 let Ok(x) = Err(92) 
455+         && x < 30 
456+         && let Some(y) = Some(8) 
457+     { 
458+         foo(x, y); 
459+     } 
460+ } 
461+ "# , 
462+             r#" 
463+ fn main() { 
464+     let Ok(x) = Err(92) else { return }; 
465+     if x >= 30 { 
466+         return; 
467+     } 
468+     let Some(y) = Some(8) else { return }; 
469+     foo(x, y); 
470+ } 
471+ "# , 
472+         ) ; 
473+ 
474+         check_assist ( 
475+             convert_to_guarded_return, 
476+             r#" 
477+ fn main() { 
478+     if$0 let Ok(x) = Err(92) 
479+         && x < 30 
480+         && y < 20 
481+         && let Some(y) = Some(8) 
482+     { 
483+         foo(x, y); 
484+     } 
485+ } 
486+ "# , 
487+             r#" 
488+ fn main() { 
489+     let Ok(x) = Err(92) else { return }; 
490+     if !(x < 30 && y < 20) { 
491+         return; 
492+     } 
493+     let Some(y) = Some(8) else { return }; 
494+     foo(x, y); 
495+ } 
496+ "# , 
497+         ) ; 
498+     } 
499+ 
319500    #[ test]  
320501    fn  convert_let_ok_inside_fn ( )  { 
321502        check_assist ( 
@@ -560,6 +741,32 @@ fn main() {
560741        ) ; 
561742    } 
562743
744+     #[ test]  
745+     fn  convert_let_stmt_inside_fn_return_option ( )  { 
746+         check_assist ( 
747+             convert_to_guarded_return, 
748+             r#" 
749+ //- minicore: option 
750+ fn foo() -> Option<i32> { 
751+     None 
752+ } 
753+ 
754+ fn ret_option() -> Option<i32> { 
755+     let x$0 = foo(); 
756+ } 
757+ "# , 
758+             r#" 
759+ fn foo() -> Option<i32> { 
760+     None 
761+ } 
762+ 
763+ fn ret_option() -> Option<i32> { 
764+     let Some(x) = foo() else { return None }; 
765+ } 
766+ "# , 
767+         ) ; 
768+     } 
769+ 
563770    #[ test]  
564771    fn  convert_let_stmt_inside_loop ( )  { 
565772        check_assist ( 
0 commit comments