|
1 | | -# Type Trees in Enzyme |
| 1 | +# TypeTrees for Autodiff |
2 | 2 |
|
3 | | -This document describes type trees as used by Enzyme for automatic differentiation. |
| 3 | +## What are TypeTrees? |
| 4 | +TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.
4 | 5 |
|
5 | | -## What are Type Trees? |
6 | | - |
7 | | -Type trees in Enzyme are a way to represent the types of variables, including their activity (e.g., whether they are active, duplicated, or contain duplicated data) for automatic differentiation. They provide a structured way for Enzyme to understand how to handle different data types during the differentiation process. |
| 6 | +## Structure |
| 7 | +```rust |
| 8 | +TypeTree(Vec<Type>) |
8 | 9 |
|
9 | | -## Representing Rust Types as Type Trees |
| 10 | +Type { |
| 11 | + offset: isize, // byte offset (-1 = everywhere) |
| 12 | + size: usize, // size in bytes |
| 13 | + kind: Kind, // Float, Integer, Pointer, etc. |
| 14 | + child: TypeTree // nested structure |
| 15 | +} |
| 16 | +``` |
10 | 17 |
|
11 | | -Enzyme needs to understand the structure and properties of Rust types to perform automatic differentiation correctly. This is where type trees come in. They provide a detailed map of a type, including pointer indirections and the underlying concrete data types. |
| 18 | +## Example: `fn compute(x: &f32, data: &[f32]) -> f32` |
12 | 19 |
|
13 | | -The `-enzyme-rust-type` flag in Enzyme helps in interpreting types more accurately in the context of Rust's memory layout and type system. |
| 20 | +**Input 0: `x: &f32`** |
| 21 | +```rust |
| 22 | +TypeTree(vec![Type { |
| 23 | + offset: -1, size: 8, kind: Pointer, |
| 24 | + child: TypeTree(vec![Type { |
| 25 | + offset: -1, size: 4, kind: Float, |
| 26 | + child: TypeTree::new() |
| 27 | + }]) |
| 28 | +}]) |
| 29 | +``` |
14 | 30 |
|
15 | | -### Primitive Types |
| 31 | +**Input 1: `data: &[f32]`** |
| 32 | +```rust |
| 33 | +TypeTree(vec![Type { |
| 34 | + offset: -1, size: 8, kind: Pointer, |
| 35 | + child: TypeTree(vec![Type { |
| 36 | + offset: -1, size: 4, kind: Float, // -1 = all elements |
| 37 | + child: TypeTree::new() |
| 38 | + }]) |
| 39 | +}]) |
| 40 | +``` |
16 | 41 |
|
17 | | -#### Floating-Point Types (`f32`, `f64`) |
| 42 | +**Output: `f32`** |
| 43 | +```rust |
| 44 | +TypeTree(vec![Type { |
| 45 | + offset: -1, size: 4, kind: Float, |
| 46 | + child: TypeTree::new() |
| 47 | +}]) |
| 48 | +``` |
18 | 49 |
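The three trees in the example can be built with a small, self-contained model of these types. This is only a sketch: the `Kind`, `Type`, and `TypeTree` definitions below mirror the fields listed under "Structure" and are assumptions for illustration, not the compiler's actual definitions. The helpers `f32_tree` and `pointer_to` are hypothetical names, and the 8-byte pointer size assumes a 64-bit target.

```rust
/// Kinds of data a `Type` entry can describe (only the subset used here).
#[derive(Debug, Clone, PartialEq)]
pub enum Kind {
    Float,
    Integer,
    Pointer,
}

/// One entry of a TypeTree, with the fields listed under "Structure".
#[derive(Debug, Clone, PartialEq)]
pub struct Type {
    pub offset: isize, // byte offset (-1 = everywhere)
    pub size: usize,   // size in bytes
    pub kind: Kind,
    pub child: TypeTree, // nested structure; empty for leaves
}

/// A list of `Type` entries describing one value.
#[derive(Debug, Clone, PartialEq)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
    /// An empty tree, used as the child of leaf entries.
    pub fn new() -> Self {
        TypeTree(Vec::new())
    }
}

/// Hypothetical helper: the tree for a plain `f32` (also the "Output" tree).
pub fn f32_tree() -> TypeTree {
    TypeTree(vec![Type {
        offset: -1,
        size: 4,
        kind: Kind::Float,
        child: TypeTree::new(),
    }])
}

/// Hypothetical helper: wraps a pointee tree in a pointer entry.
/// Assumes 8-byte pointers (64-bit target).
pub fn pointer_to(pointee: TypeTree) -> TypeTree {
    TypeTree(vec![Type {
        offset: -1,
        size: 8,
        kind: Kind::Pointer,
        child: pointee,
    }])
}
```

With these definitions, `pointer_to(f32_tree())` reproduces the tree shown for both `x: &f32` and `data: &[f32]`, and `f32_tree()` reproduces the output tree.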
|
19 | | -Consider a Rust reference to a 32-bit floating-point number, `&f32`. |
| 50 | +## Why Needed? |
| 51 | +- Enzyme can't deduce complex type layouts from LLVM IR |
| 52 | +- Prevents slow memory pattern analysis |
| 53 | +- Enables correct derivative computation for nested structures |
| 54 | +- Tells Enzyme which bytes are differentiable vs metadata |
20 | 55 |
|
21 | | -In LLVM IR, this might be represented, for instance, as an `i8*` (a generic byte pointer) that is then `bitcast` to a `float*`. Consider the following LLVM IR function: |
| 56 | +## What Enzyme Does With This Information
22 | 57 |
|
| 58 | +Without TypeTrees: |
23 | 59 | ```llvm |
24 | | -define internal void @callee(i8* %x) { |
25 | | -start: |
26 | | - %x.dbg.spill = bitcast i8* %x to float* |
27 | | - ; ... |
28 | | - ret void |
| 60 | +; Enzyme sees generic LLVM IR: |
| 61 | +define float @distance(i8* %p1, i8* %p2) { |
| 62 | +; Has to guess what these pointers point to |
| 63 | +; Slow analysis of all memory operations |
| 64 | +; May miss optimization opportunities |
29 | 65 | } |
30 | 66 | ``` |
31 | 67 |
|
32 | | -When Enzyme analyzes this function (with appropriate flags like `-enzyme-rust-type`), it might produce the following type information for the argument `%x` and the result of the bitcast: |
33 | | - |
| 68 | +With TypeTrees: |
34 | 69 | ```llvm |
35 | | -i8* %x: {[-1]:Pointer, [-1,0]:Float@float} |
36 | | -%x.dbg.spill = bitcast i8* %x to float*: {[-1]:Pointer, [-1,0]:Float@float} |
| 70 | +define "enzyme_type"="{[]:Float@float}" float @distance( |
| 71 | + ptr "enzyme_type"="{[]:Pointer}" %p1, |
| 72 | + ptr "enzyme_type"="{[]:Pointer}" %p2 |
| 73 | +) { |
| 74 | +; Enzyme knows exact type layout |
| 75 | +; Can generate efficient derivative code directly |
| 76 | +} |
37 | 77 | ``` |
38 | 78 |
|
39 | | -**Understanding the Type Tree: `{[-1]:Pointer, [-1,0]:Float@float}`** |
40 | | - |
41 | | -This string is the type tree representation. Let's break it down: |
| 79 | +# TypeTrees - Offset and -1 Explained |
42 | 80 |
|
43 | | -* **`{ ... }`**: This encloses the set of type information for the variable. |
44 | | -* **`[-1]:Pointer`**: |
45 | | - * `[-1]` is an index or path. In this context, `-1` often refers to the base memory location or the immediate value pointed to. |
46 | | - * `Pointer` indicates that the variable `%x` itself is treated as a pointer. |
47 | | -* **`[-1,0]:Float@float`**: |
48 | | - * `[-1,0]` is a path. It means: start with the base item `[-1]` (the pointer), and then look at offset `0` from the memory location it points to. |
49 | | - * `Float` is the `CConcreteType` (from `enzyme_ffi.rs`, corresponding to `DT_Float`). It signifies that the data at this location is a floating-point number. |
50 | | - * `@float` is a subtype or specific variant of `Float`. In this case, it specifies a single-precision float (like Rust's `f32`). |
| 81 | +## Type Structure |
51 | 82 |
|
52 | | -A reference to an `f64` (e.g., `&f64`) is handled very similarly. The LLVM IR might cast to `double*`: |
53 | | -```llvm |
54 | | -define internal void @callee(i8* %x) { |
55 | | -start: |
56 | | - %x.dbg.spill = bitcast i8* %x to double* |
57 | | - ; ... |
58 | | - ret void |
| 83 | +```rust |
| 84 | +Type { |
| 85 | + offset: isize, // WHERE this type starts |
| 86 | + size: usize, // HOW BIG this type is |
| 87 | + kind: Kind, // WHAT KIND of data (Float, Integer, Pointer)
| 88 | + child: TypeTree // WHAT'S INSIDE (for pointers/containers) |
59 | 89 | } |
60 | 90 | ``` |
61 | 91 |
|
62 | | -And the type tree would be: |
| 92 | +## Offset Values |
63 | 93 |
|
64 | | -```llvm |
65 | | -i8* %x: {[-1]:Pointer, [-1,0]:Float@double} |
66 | | -``` |
67 | | -The key difference is `@double`, indicating a double-precision float. |
68 | | - |
69 | | -This level of detail allows Enzyme to know, for example, that if `x` is an active variable in differentiation, the floating-point value it points to needs to be handled according to AD rules for its specific precision. |
70 | | - |
71 | | -### Compound Types |
72 | | - |
73 | | -#### Structs |
74 | | - |
75 | | -Consider a Rust struct `T` with two `f32` fields (e.g., a reference `&T`): |
| 94 | +### Regular Offset (0, 4, 8, etc.) |
| 95 | +**Specific byte position within a structure** |
76 | 96 |
|
77 | 97 | ```rust |
78 | | -struct T { |
79 | | - x: f32, |
80 | | - y: f32, |
| 98 | +struct Point { |
| 99 | + x: f32, // offset 0, size 4 |
| 100 | + y: f32, // offset 4, size 4 |
| 101 | + id: i32, // offset 8, size 4 |
81 | 102 | } |
82 | | - |
83 | | -// And a function taking a reference to it: |
84 | | -// fn callee(t: &T) { /* ... */ } |
85 | 103 | ``` |
86 | 104 |
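The offsets in the comments on `Point` can be checked directly in Rust. The sketch below is an illustration, not part of the autodiff implementation. It assumes `#[repr(C)]`, because Rust's default representation does not guarantee field order, and it uses `std::mem::offset_of!`, which requires Rust 1.77 or newer. The function name `point_layout` is hypothetical.

```rust
use std::mem::{offset_of, size_of};

/// Same fields as the `Point` struct above; `repr(C)` is an assumption
/// added here so that the field order and offsets are guaranteed.
#[repr(C)]
pub struct Point {
    pub x: f32,
    pub y: f32,
    pub id: i32,
}

/// Returns `(offset, size)` in bytes for `x`, `y`, and `id`, in that order.
pub fn point_layout() -> [(usize, usize); 3] {
    [
        (offset_of!(Point, x), size_of::<f32>()),
        (offset_of!(Point, y), size_of::<f32>()),
        (offset_of!(Point, id), size_of::<i32>()),
    ]
}
```

Each `(offset, size)` pair corresponds to the `offset` and `size` fields of one `Type` entry in the tree for this struct.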
|
87 | | -In LLVM IR, a pointer to this struct might be initially represented as `i8*` and then cast to the specific struct type, like `{ float, float }*`: |
| 105 | +TypeTree for the `Point` data behind a `&Point` (internal representation):
| 106 | +```rust |
| 107 | +TypeTree(vec![ |
| 108 | + Type { offset: 0, size: 4, kind: Float, child: TypeTree::new() }, // x at byte 0
| 109 | + Type { offset: 4, size: 4, kind: Float, child: TypeTree::new() }, // y at byte 4
| 110 | + Type { offset: 8, size: 4, kind: Integer, child: TypeTree::new() } // id at byte 8
| 111 | +]) |
| 112 | +``` |
88 | 113 |
|
| 114 | +Generates LLVM: |
89 | 115 | ```llvm |
90 | | -define internal void @callee(i8* %t) { |
91 | | -start: |
92 | | - %t.dbg.spill = bitcast i8* %t to { float, float }* |
93 | | - ; ... |
94 | | - ret void |
95 | | -} |
| 116 | +"enzyme_type"="{[]:Float@float}" |
96 | 117 | ``` |
97 | 118 |
|
98 | | -The Enzyme type analysis output for `%t` would be: |
| 119 | +### Offset -1 (Special: "Everywhere") |
| 120 | +**Means "this pattern repeats for ALL elements"** |
99 | 121 |
|
100 | | -```llvm |
101 | | -i8* %t: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float} |
| 122 | +#### Example 1: Array `[f32; 100]` |
| 123 | +```rust |
| 124 | +TypeTree(vec![Type { |
| 125 | + offset: -1, // ALL positions |
| 126 | + size: 4, // each f32 is 4 bytes |
| 127 | + kind: Float, child: TypeTree::new(), // every element is a float
| 128 | +}]) |
102 | 129 | ``` |
103 | 130 |
|
104 | | -**Understanding the Struct Type Tree: `{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}`** |
| 131 | +This single entry replaces 100 separate `Type` entries at offsets `0, 4, 8, 12, ..., 396`.
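To make the saving concrete, the sketch below lists the explicit offsets that one `offset: -1` entry stands in for. The helper `expand_everywhere` is a hypothetical name for illustration only. It does not reflect how Enzyme or the compiler handles `-1` internally.

```rust
/// Hypothetical helper: the byte offsets that a single `offset: -1` entry
/// covers for an array of `len` elements of `elem_size` bytes each.
pub fn expand_everywhere(elem_size: usize, len: usize) -> Vec<usize> {
    // Element i starts at byte i * elem_size.
    (0..len).map(|i| i * elem_size).collect()
}
```

For `[f32; 100]`, `expand_everywhere(4, 100)` yields 100 offsets running from `0` to `396` in steps of 4, which is the list the `-1` entry avoids spelling out.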
105 | 132 |
|
106 | | -* **`[-1]:Pointer`**: As before, this indicates that `%t` is a pointer. |
107 | | -* **`[-1,0]:Float@float`**: |
108 | | - * This describes the first field of the struct (`x`). |
109 | | - * `[-1,0]` means: from the memory location pointed to by `%t` (`-1`), at offset `0` bytes. |
110 | | - * `Float@float` indicates this field is an `f32`. |
111 | | -* **`[-1,4]:Float@float`**: |
112 | | - * This describes the second field of the struct (`y`). |
113 | | - * `[-1,4]` means: from the memory location pointed to by `%t` (`-1`), at offset `4` bytes. |
114 | | - * `Float@float` indicates this field is also an `f32`. |
| 133 | +#### Example 2: Slice `&[i32]` |
| 134 | +```rust |
| 135 | +// Pointer to slice data |
| 136 | +TypeTree(vec![Type { |
| 137 | + offset: -1, size: 8, kind: Pointer, |
| 138 | + child: TypeTree(vec![Type { |
| 139 | + offset: -1, // ALL slice elements |
| 140 | + size: 4, // each i32 is 4 bytes |
| 141 | + kind: Integer, child: TypeTree::new()
| 142 | + }]) |
| 143 | +}]) |
| 144 | +``` |
115 | 145 |
|
116 | | -The offset `4` comes from the size of the first field (`f32` is 4 bytes). If the first field were, for example, an `f64` (8 bytes), the second field might be at offset `[-1,8]`. Enzyme uses these offsets to pinpoint the exact memory location of each field within the struct. |
| 146 | +#### Example 3: Mixed Structure |
| 147 | +```rust |
| 148 | +struct Container { |
| 149 | + header: i64, // offset 0 |
| 150 | + data: [f32; 1000], // offset 8, but elements use -1 |
| 151 | +} |
| 152 | +``` |
117 | 153 |
|
118 | | -This detailed mapping is crucial for Enzyme to correctly track the activity of individual struct fields during automatic differentiation. |
| 154 | +```rust |
| 155 | +TypeTree(vec![ |
| 156 | + Type { offset: 0, size: 8, kind: Integer, child: TypeTree::new() }, // header
| 157 | + Type { offset: 8, size: 4000, kind: Pointer,
| 158 | + child: TypeTree(vec![Type {
| 159 | + offset: -1, size: 4, kind: Float, child: TypeTree::new() // ALL array elements
| 160 | + }]) |
| 161 | + } |
| 162 | +]) |
| 163 | +``` |