Skip to content

Commit 8d5fa15

Browse files
z80devcharles-coopercyberthirst
authored
fix[lang]: allow flag member access in pure functions (#4693)
flag members are compile-time constants and should be allowed in pure functions. when validating attribute access like `Action.BUY` in a pure function, the analysis was failing because `Action` (the value) is a `TYPE_T` expression which was rejected with an `InvalidReference` exception. fix by passing `is_callable=True` to `get_expr_info` when analyzing the `.value` member of `Attribute` nodes in pure functions. this allows the `.value` of the attribute (e.g., `Action`) to be parsed as a type expression without raising `InvalidReference`. note that this still respects the semantics of `is_callable=True` because we are only allowing the type expression for the value portion of the attribute. the overall attribute node (e.g., `Action.BUY`) still preserves the `is_callable` semantics for the whole expression. --------- Co-authored-by: Charles Cooper <[email protected]> Co-authored-by: cyberthirst <[email protected]>
1 parent ad8c765 commit 8d5fa15

File tree

5 files changed

+184
-2
lines changed

5 files changed

+184
-2
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
def test_flag_members_in_pure_functions(get_contract):
2+
"""Test that flag members can be used in pure functions since they are
3+
compile-time constants"""
4+
code = """
5+
flag Action:
6+
BUY
7+
SELL
8+
CANCEL
9+
10+
@pure
11+
@external
12+
def get_buy_action() -> Action:
13+
return Action.BUY
14+
15+
@pure
16+
@external
17+
def get_sell_action() -> Action:
18+
return Action.SELL
19+
20+
@pure
21+
@external
22+
def get_cancel_action() -> Action:
23+
return Action.CANCEL
24+
"""
25+
c = get_contract(code)
26+
assert c.get_buy_action() == 1 # 2^0
27+
assert c.get_sell_action() == 2 # 2^1
28+
assert c.get_cancel_action() == 4 # 2^2
29+
30+
31+
def test_flag_operations_in_pure_functions(get_contract):
32+
"""Test that flag operations work in pure functions"""
33+
code = """
34+
flag Permissions:
35+
READ
36+
WRITE
37+
EXECUTE
38+
39+
@pure
40+
@external
41+
def get_read_write() -> Permissions:
42+
return Permissions.READ | Permissions.WRITE
43+
44+
@pure
45+
@external
46+
def check_read_permission(perms: Permissions) -> bool:
47+
return Permissions.READ in perms
48+
49+
@pure
50+
@external
51+
def combine_all() -> Permissions:
52+
return Permissions.READ | Permissions.WRITE | Permissions.EXECUTE
53+
"""
54+
c = get_contract(code)
55+
assert c.get_read_write() == 3 # 1 | 2 = 3
56+
assert c.check_read_permission(1) is True # READ permission
57+
assert c.check_read_permission(2) is False # WRITE permission only
58+
assert c.combine_all() == 7 # 1 | 2 | 4 = 7
59+
60+
61+
def test_flag_conditionals_in_pure_functions(get_contract):
62+
"""Test flags in conditional expressions within pure functions"""
63+
code = """
64+
flag Status:
65+
ACTIVE
66+
INACTIVE
67+
PENDING
68+
69+
@pure
70+
@external
71+
def classify_status(status: Status) -> uint256:
72+
if status == Status.ACTIVE:
73+
return 100
74+
elif status == Status.PENDING:
75+
return 50
76+
else:
77+
return 0
78+
"""
79+
c = get_contract(code)
80+
assert c.classify_status(1) == 100 # ACTIVE
81+
assert c.classify_status(4) == 50 # PENDING
82+
assert c.classify_status(2) == 0 # INACTIVE
83+
84+
85+
def test_access_flag_from_another_module(get_contract, make_input_bundle):
86+
"""Test flag access even if the attribute comes from another module (eg lib1.flag.foo)"""
87+
code = """
88+
import lib1
89+
90+
@pure
91+
@external
92+
def foo() -> lib1.Action:
93+
return lib1.Action.BUY
94+
95+
"""
96+
lib1 = """
97+
flag Action:
98+
BUY
99+
SELL
100+
"""
101+
102+
input_bundle = make_input_bundle({"lib1.vy": lib1})
103+
c = get_contract(code, input_bundle=input_bundle)
104+
assert c.foo() == 1 # BUY
105+
106+
107+
def test_internal_pure_accessing_flag(get_contract, make_input_bundle):
108+
"""Test flag accesses in internal pure functions"""
109+
code = """
110+
import lib1
111+
112+
@pure
113+
def bar() -> lib1.Action:
114+
return lib1.Action.BUY
115+
116+
@pure
117+
@external
118+
def foo() -> lib1.Action:
119+
return self.bar()
120+
121+
"""
122+
lib1 = """
123+
flag Action:
124+
BUY
125+
SELL
126+
"""
127+
128+
input_bundle = make_input_bundle({"lib1.vy": lib1})
129+
c = get_contract(code, input_bundle=input_bundle)
130+
assert c.foo() == 1 # BUY
131+
132+
133+
def test_flag_access_in_loop(get_contract, make_input_bundle):
134+
"""Test flag accesses in a for loop"""
135+
code = """
136+
137+
flag Action:
138+
BUY
139+
SELL
140+
141+
@pure
142+
@external
143+
def foo() -> uint256:
144+
cnt: uint256 = 0
145+
for i: uint256 in range(10):
146+
cnt += convert(Action.SELL, uint256)
147+
return cnt
148+
"""
149+
c = get_contract(code)
150+
assert c.foo() == 10 * 2

tests/functional/codegen/types/test_flag.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import pytest
2+
3+
from vyper.exceptions import StateAccessViolation
4+
5+
16
def test_values_should_be_increasing_ints(get_contract):
27
code = """
38
flag Action:
@@ -303,3 +308,16 @@ def get_key(f: Foobar, i: uint256) -> uint256:
303308
"""
304309
c = get_contract(code)
305310
assert c.get_key(1, 777) == 777
311+
312+
313+
@pytest.mark.xfail(raises=StateAccessViolation)
314+
def test_flag_constant(get_contract):
315+
code = """
316+
flag F:
317+
FOO
318+
BAR
319+
320+
c: public(constant(F)) = F.FOO
321+
"""
322+
c = get_contract(code)
323+
assert c.c() == 1 # FOO

tests/functional/syntax/test_flag.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from vyper.exceptions import (
55
FlagDeclarationException,
66
InvalidOperation,
7+
InvalidReference,
78
NamespaceCollision,
89
StructureException,
910
TypeMismatch,
@@ -124,6 +125,17 @@ def foo():
124125
""",
125126
TypeMismatch,
126127
),
128+
(
129+
"""
130+
flag Status:
131+
ACTIVE
132+
133+
@external
134+
def test_assign_to_flag():
135+
Status.ACTIVE = 2
136+
""",
137+
InvalidReference,
138+
),
127139
]
128140

129141

vyper/semantics/analysis/local.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ def _validate_pure_access(node: vy_ast.Attribute | vy_ast.Name, typ: VyperType)
179179
raise StateAccessViolation(
180180
"not allowed to query environment variables in pure functions"
181181
)
182-
parent_info = get_expr_info(node.value)
182+
# allow type exprs in the value node, e.g. MyFlag.A
183+
parent_info = get_expr_info(node.value, is_callable=True)
183184
if isinstance(parent_info.typ, AddressT) and node.attr in AddressT._type_members:
184185
raise StateAccessViolation("not allowed to query address members in pure functions")
185186

vyper/semantics/analysis/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def get_expr_info(self, node: vy_ast.VyperNode, is_callable: bool = False) -> Ex
9090
# propagate the parent exprinfo members down into the new expr
9191
# note: Attribute(expr value, identifier attr)
9292

93-
info = self.get_expr_info(node.value, is_callable=is_callable)
93+
# allow the value node to be a type expr (e.g., MyFlag.A)
94+
info = self.get_expr_info(node.value, is_callable=True)
9495
attr = node.attr
9596

9697
t = info.typ.get_member(attr, node)

0 commit comments

Comments
 (0)