@@ -22,6 +22,12 @@ pub(crate) enum EventContext {
22
22
Body ,
23
23
/// The current expression is a children of a field access expression.
24
24
FieldAccess ,
25
+ /// The current expression is one of the expressions that initializes a
26
+ /// variable in a `with` statement. For instance, the `<expr>` in:
27
+ /// ```text
28
+ /// with some_var = <expr> : ( .. )
29
+ /// ```
30
+ WithDeclaration ,
25
31
}
26
32
27
33
/// An iterator that conducts a Depth First Search (DFS) traversal of the IR
@@ -154,12 +160,12 @@ impl<'a> DFSWithScopeIter<'a> {
154
160
/// corresponding to the `for` statement.
155
161
///
156
162
/// Let's see a more complex example, consider the following YARA
157
- /// expression, where we are positioned at `<inner expr>`:
163
+ /// expression, where we are positioned at `<for body expr>`:
158
164
///
159
165
/// ```text
160
- /// with a = <expr> : (
166
+ /// with a = <init expr> : (
161
167
/// for any i in (0..10) : (
162
- /// <inner expr>
168
+ /// <for body expr>
163
169
/// )
164
170
/// )
165
171
/// ```
@@ -168,6 +174,9 @@ impl<'a> DFSWithScopeIter<'a> {
168
174
/// statement first and then the [`ExprId`] corresponding to the `for`
169
175
/// statement. The iterator processes the scopes starting from the outermost
170
176
/// scope and progresses inward.
177
+ ///
178
+ /// If we are positioned at `<init expr>`, the iterator returns the
179
+ /// [`ExprId`] that corresponds to the `with` statement.
171
180
pub fn scopes ( & self ) -> impl DoubleEndedIterator < Item = ExprId > + ' _ {
172
181
self . scopes . iter ( ) . cloned ( )
173
182
}
@@ -192,15 +201,21 @@ impl Iterator for DFSWithScopeIter<'_> {
192
201
}
193
202
let next = match self . dfs . next ( ) ? {
194
203
Event :: Enter ( ( expr_id, _, ctx) ) => {
195
- if matches ! ( ctx, EventContext :: Body ) {
204
+ if matches ! (
205
+ ctx,
206
+ EventContext :: Body | EventContext :: WithDeclaration
207
+ ) {
196
208
// If the current expression is the body of some other
197
209
// expression, the current expression must have a parent.
198
210
self . scopes . push ( self . ir . get_parent ( expr_id) . unwrap ( ) ) ;
199
211
}
200
212
Event :: Enter ( ( expr_id, ctx) )
201
213
}
202
214
Event :: Leave ( ( expr_id, _, ctx) ) => {
203
- if matches ! ( ctx, EventContext :: Body ) {
215
+ if matches ! (
216
+ ctx,
217
+ EventContext :: Body | EventContext :: WithDeclaration
218
+ ) {
204
219
// Don't remove the scope at top of the stack right away.
205
220
// If the user calls `scopes()` while processing the Leave
206
221
// event, we want the current context to be there. We just
@@ -383,7 +398,8 @@ pub(super) fn dfs_common(
383
398
Expr :: With ( with) => {
384
399
stack. push ( Event :: Enter ( ( with. body , EventContext :: Body ) ) ) ;
385
400
for ( _id, expr) in with. declarations . iter ( ) . rev ( ) {
386
- stack. push ( Event :: Enter ( ( * expr, EventContext :: None ) ) )
401
+ stack
402
+ . push ( Event :: Enter ( ( * expr, EventContext :: WithDeclaration ) ) )
387
403
}
388
404
}
389
405
}
@@ -612,17 +628,17 @@ mod tests {
612
628
613
629
assert ! ( matches!(
614
630
dfs. next( ) ,
615
- Some ( Event :: Enter ( ( expr_id, EventContext :: None ) ) ) if expr_id == const_1
631
+ Some ( Event :: Enter ( ( expr_id, EventContext :: WithDeclaration ) ) ) if expr_id == const_1
616
632
) ) ;
617
633
618
- assert_eq ! ( dfs. scopes( ) . collect:: <Vec <_>>( ) , vec![ ] ) ;
634
+ assert_eq ! ( dfs. scopes( ) . collect:: <Vec <_>>( ) , vec![ with ] ) ;
619
635
620
636
assert ! ( matches!(
621
637
dfs. next( ) ,
622
- Some ( Event :: Leave ( ( expr_id, EventContext :: None ) ) ) if expr_id == const_1
638
+ Some ( Event :: Leave ( ( expr_id, EventContext :: WithDeclaration ) ) ) if expr_id == const_1
623
639
) ) ;
624
640
625
- assert_eq ! ( dfs. scopes( ) . collect:: <Vec <_>>( ) , vec![ ] ) ;
641
+ assert_eq ! ( dfs. scopes( ) . collect:: <Vec <_>>( ) , vec![ with ] ) ;
626
642
627
643
assert ! ( matches!(
628
644
dfs. next( ) ,
@@ -642,5 +658,7 @@ mod tests {
642
658
dfs. next( ) ,
643
659
Some ( Event :: Leave ( ( expr_id, EventContext :: None ) ) ) if expr_id == with
644
660
) ) ;
661
+
662
+ assert_eq ! ( dfs. scopes( ) . collect:: <Vec <_>>( ) , vec![ ] ) ;
645
663
}
646
664
}
0 commit comments