Skip to content

Commit b351346

Browse files
authored
Merge pull request #19 from pythonbpf/fix-expr
Refactor expr_pass
2 parents 253944a + c3db609 commit b351346

File tree

5 files changed

+149
-125
lines changed

5 files changed

+149
-125
lines changed

pythonbpf/binary_ops.py

Lines changed: 43 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,68 +8,59 @@
88

99
def recursive_dereferencer(var, builder):
1010
"""dereference until primitive type comes out"""
11-
if var.type == ir.PointerType(ir.PointerType(ir.IntType(64))):
11+
# TODO: Not worrying about stack overflow for now
12+
if isinstance(var.type, ir.PointerType):
1213
a = builder.load(var)
1314
return recursive_dereferencer(a, builder)
14-
elif var.type == ir.PointerType(ir.IntType(64)):
15-
a = builder.load(var)
16-
return recursive_dereferencer(a, builder)
17-
elif var.type == ir.IntType(64):
15+
elif isinstance(var.type, ir.IntType):
1816
return var
1917
else:
2018
raise TypeError(f"Unsupported type for dereferencing: {var.type}")
2119

2220

23-
def handle_binary_op(rval, module, builder, var_name, local_sym_tab, map_sym_tab, func):
24-
logger.info(f"module {module}")
25-
left = rval.left
26-
right = rval.right
27-
op = rval.op
28-
29-
# Handle left operand
30-
if isinstance(left, ast.Name):
31-
if left.id in local_sym_tab:
32-
left = recursive_dereferencer(local_sym_tab[left.id].var, builder)
33-
else:
34-
raise SyntaxError(f"Undefined variable: {left.id}")
35-
elif isinstance(left, ast.Constant):
36-
left = ir.Constant(ir.IntType(64), left.value)
37-
else:
38-
raise SyntaxError("Unsupported left operand type")
21+
def get_operand_value(operand, module, builder, local_sym_tab):
22+
"""Extract the value from an operand, handling variables and constants."""
23+
if isinstance(operand, ast.Name):
24+
if operand.id in local_sym_tab:
25+
return recursive_dereferencer(local_sym_tab[operand.id].var, builder)
26+
raise ValueError(f"Undefined variable: {operand.id}")
27+
elif isinstance(operand, ast.Constant):
28+
if isinstance(operand.value, int):
29+
return ir.Constant(ir.IntType(64), operand.value)
30+
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
31+
elif isinstance(operand, ast.BinOp):
32+
return handle_binary_op_impl(operand, module, builder, local_sym_tab)
33+
raise TypeError(f"Unsupported operand type: {type(operand)}")
3934

40-
if isinstance(right, ast.Name):
41-
if right.id in local_sym_tab:
42-
right = recursive_dereferencer(local_sym_tab[right.id].var, builder)
43-
else:
44-
raise SyntaxError(f"Undefined variable: {right.id}")
45-
elif isinstance(right, ast.Constant):
46-
right = ir.Constant(ir.IntType(64), right.value)
47-
else:
48-
raise SyntaxError("Unsupported right operand type")
4935

36+
def handle_binary_op_impl(rval, module, builder, local_sym_tab):
37+
op = rval.op
38+
left = get_operand_value(rval.left, module, builder, local_sym_tab)
39+
right = get_operand_value(rval.right, module, builder, local_sym_tab)
5040
logger.info(f"left is {left}, right is {right}, op is {op}")
5141

52-
if isinstance(op, ast.Add):
53-
builder.store(builder.add(left, right), local_sym_tab[var_name].var)
54-
elif isinstance(op, ast.Sub):
55-
builder.store(builder.sub(left, right), local_sym_tab[var_name].var)
56-
elif isinstance(op, ast.Mult):
57-
builder.store(builder.mul(left, right), local_sym_tab[var_name].var)
58-
elif isinstance(op, ast.Div):
59-
builder.store(builder.sdiv(left, right), local_sym_tab[var_name].var)
60-
elif isinstance(op, ast.Mod):
61-
builder.store(builder.srem(left, right), local_sym_tab[var_name].var)
62-
elif isinstance(op, ast.LShift):
63-
builder.store(builder.shl(left, right), local_sym_tab[var_name].var)
64-
elif isinstance(op, ast.RShift):
65-
builder.store(builder.lshr(left, right), local_sym_tab[var_name].var)
66-
elif isinstance(op, ast.BitOr):
67-
builder.store(builder.or_(left, right), local_sym_tab[var_name].var)
68-
elif isinstance(op, ast.BitXor):
69-
builder.store(builder.xor(left, right), local_sym_tab[var_name].var)
70-
elif isinstance(op, ast.BitAnd):
71-
builder.store(builder.and_(left, right), local_sym_tab[var_name].var)
72-
elif isinstance(op, ast.FloorDiv):
73-
builder.store(builder.udiv(left, right), local_sym_tab[var_name].var)
42+
# Map AST operation nodes to LLVM IR builder methods
43+
op_map = {
44+
ast.Add: builder.add,
45+
ast.Sub: builder.sub,
46+
ast.Mult: builder.mul,
47+
ast.Div: builder.sdiv,
48+
ast.Mod: builder.srem,
49+
ast.LShift: builder.shl,
50+
ast.RShift: builder.lshr,
51+
ast.BitOr: builder.or_,
52+
ast.BitXor: builder.xor,
53+
ast.BitAnd: builder.and_,
54+
ast.FloorDiv: builder.udiv,
55+
}
56+
57+
if type(op) in op_map:
58+
result = op_map[type(op)](left, right)
59+
return result
7460
else:
7561
raise SyntaxError("Unsupported binary operation")
62+
63+
64+
def handle_binary_op(rval, module, builder, var_name, local_sym_tab):
65+
result = handle_binary_op_impl(rval, module, builder, local_sym_tab)
66+
builder.store(result, local_sym_tab[var_name].var)

pythonbpf/expr_pass.py

Lines changed: 100 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,92 @@
22
from llvmlite import ir
33
from logging import Logger
44
import logging
5+
from typing import Dict
56

67
logger: Logger = logging.getLogger(__name__)
78

89

10+
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
11+
"""Handle ast.Name expressions."""
12+
if expr.id in local_sym_tab:
13+
var = local_sym_tab[expr.id].var
14+
val = builder.load(var)
15+
return val, local_sym_tab[expr.id].ir_type
16+
else:
17+
logger.info(f"Undefined variable {expr.id}")
18+
return None
19+
20+
21+
def _handle_constant_expr(expr: ast.Constant):
22+
"""Handle ast.Constant expressions."""
23+
if isinstance(expr.value, int):
24+
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
25+
elif isinstance(expr.value, bool):
26+
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
27+
else:
28+
logger.info("Unsupported constant type")
29+
return None
30+
31+
32+
def _handle_attribute_expr(
33+
expr: ast.Attribute,
34+
local_sym_tab: Dict,
35+
structs_sym_tab: Dict,
36+
builder: ir.IRBuilder,
37+
):
38+
"""Handle ast.Attribute expressions for struct field access."""
39+
if isinstance(expr.value, ast.Name):
40+
var_name = expr.value.id
41+
attr_name = expr.attr
42+
if var_name in local_sym_tab:
43+
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
44+
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
45+
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
46+
47+
metadata = structs_sym_tab[var_metadata]
48+
if attr_name in metadata.fields:
49+
gep = metadata.gep(builder, var_ptr, attr_name)
50+
val = builder.load(gep)
51+
field_type = metadata.field_type(attr_name)
52+
return val, field_type
53+
return None
54+
55+
56+
def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder):
57+
"""Handle deref function calls."""
58+
logger.info(f"Handling deref {ast.dump(expr)}")
59+
if len(expr.args) != 1:
60+
logger.info("deref takes exactly one argument")
61+
return None
62+
63+
arg = expr.args[0]
64+
if (
65+
isinstance(arg, ast.Call)
66+
and isinstance(arg.func, ast.Name)
67+
and arg.func.id == "deref"
68+
):
69+
logger.info("Multiple deref not supported")
70+
return None
71+
72+
if isinstance(arg, ast.Name):
73+
if arg.id in local_sym_tab:
74+
arg_ptr = local_sym_tab[arg.id].var
75+
else:
76+
logger.info(f"Undefined variable {arg.id}")
77+
return None
78+
else:
79+
logger.info("Unsupported argument type for deref")
80+
return None
81+
82+
if arg_ptr is None:
83+
logger.info("Failed to evaluate deref argument")
84+
return None
85+
86+
# Load the value from pointer
87+
val = builder.load(arg_ptr)
88+
return val, local_sym_tab[arg.id].ir_type
89+
90+
991
def eval_expr(
1092
func,
1193
module,
@@ -17,64 +99,28 @@ def eval_expr(
1799
):
18100
logger.info(f"Evaluating expression: {ast.dump(expr)}")
19101
if isinstance(expr, ast.Name):
20-
if expr.id in local_sym_tab:
21-
var = local_sym_tab[expr.id].var
22-
val = builder.load(var)
23-
return val, local_sym_tab[expr.id].ir_type # return value and type
24-
else:
25-
logger.info(f"Undefined variable {expr.id}")
26-
return None
102+
return _handle_name_expr(expr, local_sym_tab, builder)
27103
elif isinstance(expr, ast.Constant):
28-
if isinstance(expr.value, int):
29-
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
30-
elif isinstance(expr.value, bool):
31-
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
32-
else:
33-
logger.info("Unsupported constant type")
34-
return None
104+
return _handle_constant_expr(expr)
35105
elif isinstance(expr, ast.Call):
106+
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
107+
return _handle_deref_call(expr, local_sym_tab, builder)
108+
36109
# delayed import to avoid circular dependency
37110
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
38111

39-
if isinstance(expr.func, ast.Name):
40-
# check deref
41-
if expr.func.id == "deref":
42-
logger.info(f"Handling deref {ast.dump(expr)}")
43-
if len(expr.args) != 1:
44-
logger.info("deref takes exactly one argument")
45-
return None
46-
arg = expr.args[0]
47-
if (
48-
isinstance(arg, ast.Call)
49-
and isinstance(arg.func, ast.Name)
50-
and arg.func.id == "deref"
51-
):
52-
logger.info("Multiple deref not supported")
53-
return None
54-
if isinstance(arg, ast.Name):
55-
if arg.id in local_sym_tab:
56-
arg = local_sym_tab[arg.id].var
57-
else:
58-
logger.info(f"Undefined variable {arg.id}")
59-
return None
60-
if arg is None:
61-
logger.info("Failed to evaluate deref argument")
62-
return None
63-
# Since we are handling only name case, directly take type from sym tab
64-
val = builder.load(arg)
65-
return val, local_sym_tab[expr.args[0].id].ir_type
66-
67-
# check for helpers
68-
if HelperHandlerRegistry.has_handler(expr.func.id):
69-
return handle_helper_call(
70-
expr,
71-
module,
72-
builder,
73-
func,
74-
local_sym_tab,
75-
map_sym_tab,
76-
structs_sym_tab,
77-
)
112+
if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler(
113+
expr.func.id
114+
):
115+
return handle_helper_call(
116+
expr,
117+
module,
118+
builder,
119+
func,
120+
local_sym_tab,
121+
map_sym_tab,
122+
structs_sym_tab,
123+
)
78124
elif isinstance(expr.func, ast.Attribute):
79125
logger.info(f"Handling method call: {ast.dump(expr.func)}")
80126
if isinstance(expr.func.value, ast.Call) and isinstance(
@@ -106,19 +152,7 @@ def eval_expr(
106152
structs_sym_tab,
107153
)
108154
elif isinstance(expr, ast.Attribute):
109-
if isinstance(expr.value, ast.Name):
110-
var_name = expr.value.id
111-
attr_name = expr.attr
112-
if var_name in local_sym_tab:
113-
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
114-
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
115-
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
116-
metadata = structs_sym_tab[var_metadata]
117-
if attr_name in metadata.fields:
118-
gep = metadata.gep(builder, var_ptr, attr_name)
119-
val = builder.load(gep)
120-
field_type = metadata.field_type(attr_name)
121-
return val, field_type
155+
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
122156
logger.info("Unsupported expression evaluation")
123157
return None
124158

pythonbpf/functions_pass.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,7 @@ def handle_assign(
233233
else:
234234
logger.info("Unsupported assignment call function type")
235235
elif isinstance(rval, ast.BinOp):
236-
handle_binary_op(
237-
rval, module, builder, var_name, local_sym_tab, map_sym_tab, func
238-
)
236+
handle_binary_op(rval, module, builder, var_name, local_sym_tab)
239237
else:
240238
logger.info("Unsupported assignment value type")
241239

tests/failing_tests/binops.py renamed to tests/passing_tests/binops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44

55
@bpf
6-
@section("sometag1")
6+
@section("tracepoint/syscalls/sys_enter_sync")
77
def sometag(ctx: c_void_p) -> c_int64:
8-
a = 1 + 2 + 1
8+
a = 1 + 2 + 1 + 12 + 13
99
print(f"{a}")
1010
return c_int64(0)
1111

tests/failing_tests/binops1.py renamed to tests/passing_tests/binops1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44

55
@bpf
6-
@section("sometag1")
6+
@section("tracepoint/syscalls/sys_enter_sync")
77
def sometag(ctx: c_void_p) -> c_int64:
88
b = 1 + 2
99
a = 1 + b
10-
return c_int64(a)
10+
print(f"{a}")
11+
return c_int64(0)
1112

1213

1314
@bpf

0 commit comments

Comments
 (0)