Commit 994a567
[python][utils] MemRef Manager

Adds a utility for manual memory management of memref buffers across Python and jitted MLIR modules. Explicit memory management becomes required when an MLIR function returns a newly allocated buffer, e.g., the result of a computation. This can become a complex task due to differences in memory models between Python and the MLIR runtime allocators. By default, the lifetime of returned MLIR buffers cannot be managed automatically by the Python environment.

The Python memref manager aims to address the following challenges:
- use the same runtime allocators as a jitted MLIR module for consistent memory management
- keep the abstraction lean by working on memref descriptors directly
- make buffers usable both from Python and from jitted MLIR modules

The current implementation assumes that memref allocation ops are lowered to the standard C functions 'malloc' and 'free', which are preloaded together with the Python process.
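
As a rough sketch of the intended usage (a minimal outline based on the example added in this commit; the jitted `fn_copy` function and packed-argument helper come from python/examples/mlir/memref_management.py below):

    import ctypes
    import lighthouse.utils as lh_utils

    # Allocate a 16x32 float buffer through the process-wide malloc,
    # receiving an MLIR ranked memref descriptor usable from Python.
    mem = lh_utils.MemRefManager()
    in_mem = mem.alloc(16, 32, ctype=ctypes.c_float)

    # The descriptor can be packed and handed to a jitted MLIR function;
    # a buffer returned by the module comes back as another descriptor
    # that the same manager can free later:
    #   fn_copy(lh_utils.memrefs_to_packed_args([out_mem, in_mem]))

    # Release the buffer once neither Python nor the MLIR module needs it.
    mem.dealloc(in_mem)
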
1 parent fdc9b53 commit 994a567

File tree

4 files changed, +222 -1 lines

.github/workflows/examples.yml

Lines changed: 5 additions & 1 deletion

@@ -25,4 +25,8 @@ jobs:
 
       - name: Run MLP From Module
         run: |-
-          uv run python/examples/ingress/torch/mlp_from_model.py
+          uv run python/examples/ingress/torch/mlp_from_model.py
+
+      - name: Run MemRef Management
+        run: |-
+          uv run python/examples/mlir/memref_management.py

python/examples/mlir/memref_management.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import torch
import ctypes

from mlir import ir
from mlir.dialects import func, memref
from mlir.runtime import np_to_memref
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

import lighthouse.utils as lh_utils


def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module:
    with ctx, ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            mem_type = ir.MemRefType.get(shape, ir.F32Type.get())

            # Return a new buffer initialized with input's data.
            @func.func(mem_type)
            def copy(input):
                new_buf = memref.alloc(mem_type, [], [])
                memref.copy(input, new_buf)
                return new_buf

            # Free given buffer.
            @func.func(mem_type)
            def module_dealloc(input):
                memref.dealloc(input)

        return module


def lower_to_llvm(operation: ir.Operation) -> None:
    with operation.context:
        pm = PassManager("builtin.module")
        pm.add("func.func(llvm-request-c-wrappers)")
        pm.add("convert-to-llvm")
        pm.add("reconcile-unrealized-casts")
        pm.add("cse")
        pm.add("canonicalize")
        pm.run(operation)


def main():
    # Validate basic functionality.
    print("Testing memref allocator...")
    mem = lh_utils.MemRefManager()
    # Check allocation.
    buf = mem.alloc(32, 8, 16, ctype=ctypes.c_float)
    assert buf.allocated != 0, "Invalid allocation"
    assert list(buf.shape) == [32, 8, 16], "Invalid shape"
    assert list(buf.strides) == [128, 16, 1], "Invalid strides"
    # Check deallocation.
    mem.dealloc(buf)
    assert buf.allocated == 0, "Failed deallocation"
    # Double free must not crash.
    mem.dealloc(buf)

    # Zero rank buffer.
    buf = mem.alloc(ctype=ctypes.c_float)
    mem.dealloc(buf)
    # Small buffer.
    buf = mem.alloc(8, ctype=ctypes.c_int8)
    mem.dealloc(buf)
    # Large buffer.
    buf = mem.alloc(1024, 1024, ctype=ctypes.c_int32)
    mem.dealloc(buf)

    # Validate functionality across Python-MLIR boundary.
    print("Testing JIT module memory management...")
    # Buffer shape for testing.
    shape = [16, 32]

    # Create and compile test module.
    ctx = ir.Context()
    kernel = create_mlir_module(ctx, shape)
    lower_to_llvm(kernel.operation)
    eng = ExecutionEngine(kernel, opt_level=3)
    eng.initialize()

    # Validate passing memrefs between Python and jitted module.
    print("...copy test...")
    fn_copy = eng.lookup("copy")

    # Alloc buffer in Python and initialize it.
    in_mem = mem.alloc(*shape, ctype=ctypes.c_float)
    in_np = np_to_memref.ranked_memref_to_numpy([in_mem])
    assert not in_np.flags.owndata, "Expected non-owning memref conversion"
    in_tensor = torch.from_numpy(in_np)
    torch.randn(in_tensor.shape, out=in_tensor)

    out_mem = np_to_memref.make_nd_memref_descriptor(in_tensor.dim(), ctypes.c_float)()
    out_mem.allocated = 0

    args = lh_utils.memrefs_to_packed_args([out_mem, in_mem])
    fn_copy(args)
    assert out_mem.allocated != 0, "Invalid buffer returned"

    out_tensor = torch.from_numpy(np_to_memref.ranked_memref_to_numpy([out_mem]))
    torch.testing.assert_close(out_tensor, in_tensor)

    mem.dealloc(out_mem)
    assert out_mem.allocated == 0, "Failed to dealloc returned buffer"
    mem.dealloc(in_mem)

    # Validate external allocation with deallocation from within jitted module.
    print("...dealloc test...")
    fn_mlir_dealloc = eng.lookup("module_dealloc")
    buf_mem = mem.alloc(*shape, ctype=ctypes.c_float)
    fn_mlir_dealloc(lh_utils.memrefs_to_packed_args([buf_mem]))

    print("SUCCESS")


if __name__ == "__main__":
    main()

python/lighthouse/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 """A collection of utility tools"""
 
+from .memref_manager import MemRefManager
+
 from .runtime_args import (
     get_packed_arg,
     memref_to_ctype,

python/lighthouse/utils/memref_manager.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
import ctypes

from itertools import accumulate
from functools import reduce
import operator

import mlir.runtime.np_to_memref as np_mem


class MemRefManager:
    """
    A utility class for manual management of MLIR memrefs.

    When used together with memref operations from within a jitted MLIR module,
    it is assumed that memref dialect allocations and deallocations are performed
    through the standard runtime `malloc` and `free` functions.

    Custom allocators are currently not supported. For more details, see:
    https://mlir.llvm.org/docs/TargetLLVMIR/#generic-alloction-and-deallocation-functions
    """

    def __init__(self) -> None:
        # Library name is left unspecified to allow for symbol search
        # in the global symbol table of the current process.
        # For more details, see:
        # https://github.com/python/cpython/issues/78773
        self.dll = ctypes.CDLL(name=None)
        self.fn_malloc = self.dll.malloc
        self.fn_malloc.argtypes = [ctypes.c_size_t]
        self.fn_malloc.restype = ctypes.c_void_p
        self.fn_free = self.dll.free
        self.fn_free.argtypes = [ctypes.c_void_p]
        self.fn_free.restype = None

    def alloc(self, *shape: int, ctype: ctypes._SimpleCData) -> ctypes.Structure:
        """
        Allocate an empty memory buffer.
        Returns an MLIR ranked memref descriptor.

        Args:
            shape: A sequence of integers defining the buffer's shape.
            ctype: The C type of the buffer's elements.
        """
        assert issubclass(ctype, ctypes._SimpleCData), "Expected a simple data ctype"
        size_bytes = reduce(operator.mul, shape, ctypes.sizeof(ctype))
        buf = self.fn_malloc(size_bytes)
        assert buf, "Failed to allocate memory"

        rank = len(shape)
        if rank == 0:
            desc = np_mem.make_zero_d_memref_descriptor(ctype)()
            desc.allocated = buf
            desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype))
            desc.offset = ctypes.c_longlong(0)
            return desc

        desc = np_mem.make_nd_memref_descriptor(rank, ctype)()
        desc.allocated = buf
        desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype))
        desc.offset = ctypes.c_longlong(0)
        shape_ctype_t = ctypes.c_longlong * rank
        desc.shape = shape_ctype_t(*shape)

        strides = list(accumulate(reversed(shape[1:]), func=operator.mul))
        strides.reverse()
        strides.append(1)
        desc.strides = shape_ctype_t(*strides)
        return desc

    def dealloc(self, memref_desc: ctypes.Structure) -> None:
        """
        Free the underlying memory buffer.

        Args:
            memref_desc: An MLIR memref descriptor.
        """
        # TODO: Expose upstream MemrefDescriptor classes for easier handling
        assert memref_desc.__class__.__name__ == "MemRefDescriptor" or isinstance(
            memref_desc, np_mem.UnrankedMemRefDescriptor
        ), "Invalid memref descriptor"

        if isinstance(memref_desc, np_mem.UnrankedMemRefDescriptor):
            # Unranked memref holds the underlying descriptor as an opaque pointer.
            # Cast the descriptor to a zero ranked memref with an arbitrary type to
            # access the base allocated memory pointer.
            ranked_desc_type = np_mem.make_zero_d_memref_descriptor(ctypes.c_char)
            ranked_desc = ctypes.cast(
                memref_desc.descriptor, ctypes.POINTER(ranked_desc_type)
            )
            memref_desc = ranked_desc[0]

        alloc_ptr = memref_desc.allocated
        if alloc_ptr == 0:
            return

        c_ptr = ctypes.cast(alloc_ptr, ctypes.c_void_p)
        self.fn_free(c_ptr)
        memref_desc.allocated = 0
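
As a quick illustration of the stride computation in `alloc` (a standalone check mirroring the assertion in the example script, not part of the commit):

    import operator
    from itertools import accumulate

    # Row-major (C-contiguous) element strides: suffix products of the
    # trailing dimensions, followed by a unit stride for the last axis.
    shape = (32, 8, 16)
    strides = list(accumulate(reversed(shape[1:]), func=operator.mul))
    strides.reverse()
    strides.append(1)
    assert strides == [128, 16, 1]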
