Skip to content

Commit e8bf6ed

Browse files
committed
refactor(go): Change the serialization API in Golang.
The existing `Serialize` and `Deserialize` functions are replaced by `WriteTo` and `ReadFrom`, which write to an `io.Writer` and read from an `io.Reader` respectively. This new API is more efficient because it doesn't need to make a copy of the compiled rules in memory. This also removes an issue that existed in `Serialize` when serialized rules are larger than 4GB. It turns out that `C.GoBytes` receives a length of type `C.int` which is a 32-bits integer, effectively limiting the serialized rules to less than 4GB.
1 parent d7db62b commit e8bf6ed

File tree

3 files changed

+88
-21
lines changed

3 files changed

+88
-21
lines changed

go/compiler_test.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package yara_x
22

33
import (
4+
"bytes"
45
"github.com/stretchr/testify/assert"
56
"testing"
67
)
@@ -50,12 +51,20 @@ func TestSerialization(t *testing.T) {
5051
r, err := Compile("rule test { condition: true }")
5152
assert.NoError(t, err)
5253

53-
b, _ := r.Serialize()
54-
r, _ = Deserialize(b)
54+
var buf bytes.Buffer
5555

56+
// Write rules into buffer
57+
n, err := r.WriteTo(&buf)
58+
59+
assert.NoError(t, err)
60+
assert.Len(t, buf.Bytes(), int(n))
61+
62+
// Read rules from buffer
63+
r, _ = ReadFrom(&buf)
64+
65+
// Make sure the rules work properly.
5666
s := NewScanner(r)
5767
scanResults, _ := s.Scan([]byte{})
58-
5968
assert.Len(t, scanResults.MatchingRules(), 1)
6069
}
6170

@@ -163,8 +172,8 @@ func TestRulesIter(t *testing.T) {
163172
}`)
164173
assert.NoError(t, err)
165174

166-
rules := c.Build()
167-
assert.Equal(t, 2, rules.Count())
175+
rules := c.Build()
176+
assert.Equal(t, 2, rules.Count())
168177

169178
slice := rules.Slice()
170179
assert.Len(t, slice, 2)
@@ -177,7 +186,7 @@ func TestRulesIter(t *testing.T) {
177186
assert.Len(t, slice[0].Metadata(), 0)
178187
assert.Len(t, slice[1].Metadata(), 1)
179188

180-
assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier())
189+
assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier())
181190
}
182191

183192
func TestImportsIter(t *testing.T) {
@@ -193,12 +202,12 @@ func TestImportsIter(t *testing.T) {
193202
}`)
194203
assert.NoError(t, err)
195204

196-
rules := c.Build()
197-
imports := rules.Imports()
205+
rules := c.Build()
206+
imports := rules.Imports()
198207

199-
assert.Len(t, imports, 2)
200-
assert.Equal(t, "pe", imports[0])
201-
assert.Equal(t, "elf", imports[1])
208+
assert.Len(t, imports, 2)
209+
assert.Equal(t, "pe", imports[0])
210+
assert.Equal(t, "elf", imports[1])
202211
}
203212

204213
func TestWarnings(t *testing.T) {

go/main.go

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import "C"
3131

3232
import (
3333
"errors"
34+
"io"
35+
"reflect"
3436
"runtime"
3537
"runtime/cgo"
3638
"unsafe"
@@ -49,25 +51,30 @@ func Compile(src string, opts ...CompileOption) (*Rules, error) {
4951
return c.Build(), nil
5052
}
5153

52-
// Deserialize deserializes rules from a byte slice.
54+
// ReadFrom reads compiled rules from a reader.
5355
//
54-
// The counterpart is [Rules.Serialize]
55-
func Deserialize(data []byte) (*Rules, error) {
56+
// The counterpart is [Rules.WriteTo].
57+
func ReadFrom(r io.Reader) (*Rules, error) {
58+
data, err := io.ReadAll(r)
59+
if err != nil {
60+
return nil, err
61+
}
62+
5663
var ptr *C.uint8_t
5764
if len(data) > 0 {
5865
ptr = (*C.uint8_t)(unsafe.Pointer(&(data[0])))
5966
}
6067

61-
r := &Rules{cRules: nil}
68+
rules := &Rules{cRules: nil}
6269

6370
runtime.LockOSThread()
6471
defer runtime.UnlockOSThread()
6572

66-
if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &r.cRules) != C.SUCCESS {
73+
if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &rules.cRules) != C.SUCCESS {
6774
return nil, errors.New(C.GoString(C.yrx_last_error()))
6875
}
6976

70-
return r, nil
77+
return rules, nil
7178
}
7279

7380
// Rules represents a set of compiled YARA rules.
@@ -79,17 +86,60 @@ func (r *Rules) Scan(data []byte) (*ScanResults, error) {
7986
return scanner.Scan(data)
8087
}
8188

82-
// Serialize converts the compiled rules into a byte slice.
83-
func (r *Rules) Serialize() ([]byte, error) {
89+
// WriteTo writes the compiled rules into a writer.
90+
//
91+
// The counterpart is [ReadFrom].
92+
func (r *Rules) WriteTo(w io.Writer) (int64, error) {
8493
var buf *C.YRX_BUFFER
8594
runtime.LockOSThread()
8695
defer runtime.UnlockOSThread()
8796
if C.yrx_rules_serialize(r.cRules, &buf) != C.SUCCESS {
88-
return nil, errors.New(C.GoString(C.yrx_last_error()))
97+
return 0, errors.New(C.GoString(C.yrx_last_error()))
8998
}
9099
defer C.yrx_buffer_destroy(buf)
91100
runtime.KeepAlive(r)
92-
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length)), nil
101+
102+
// We are going to write into `w` in chunks of 64MB.
103+
const chunkSize = 1 << 26
104+
105+
// This is the slice that contains the next chunk that will be written.
106+
var chunk []byte
107+
108+
// Modify the `chunk` slice, making it point to the buffer returned
109+
// by yrx_rules_serialize. This allows us to access the buffer from
110+
// Go without copying the data. This is safe because the slice won't
111+
// be used after the buffer is destroyed.
112+
chunkHdr := (*reflect.SliceHeader)(unsafe.Pointer(&chunk))
113+
chunkHdr.Data = uintptr(unsafe.Pointer(buf.data))
114+
chunkHdr.Len = chunkSize
115+
chunkHdr.Cap = chunkSize
116+
117+
bufLen := C.ulong(buf.length)
118+
bytesWritten := int64(0)
119+
120+
for {
121+
// If the data to be written is shorted than `chunkSize`, set the length
122+
// of the `chunk` slice to this length.
123+
if bufLen < chunkSize {
124+
chunkHdr.Len = int(bufLen)
125+
chunkHdr.Cap = int(bufLen)
126+
}
127+
if n, err := w.Write(chunk); err == nil {
128+
bytesWritten += int64(n)
129+
} else {
130+
return 0, err
131+
}
132+
// If `bufLen` is still greater than `chunkSize`, there's more data to
133+
// write, if not, we are done.
134+
if bufLen > chunkSize {
135+
chunkHdr.Data += chunkSize
136+
bufLen -= chunkSize
137+
} else {
138+
break
139+
}
140+
}
141+
142+
return bytesWritten, nil
93143
}
94144

95145
// Destroy destroys the compiled YARA rules represented by [Rules].

go/scanner_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ func TestScanner1(t *testing.T) {
1919
assert.Equal(t, "t", matchingRules[0].Identifier())
2020
assert.Equal(t, "default", matchingRules[0].Namespace())
2121
assert.Len(t, matchingRules[0].Patterns(), 0)
22+
23+
scanResults, _ = s.Scan(nil)
24+
matchingRules = scanResults.MatchingRules()
25+
26+
assert.Len(t, matchingRules, 1)
27+
assert.Equal(t, "t", matchingRules[0].Identifier())
28+
assert.Equal(t, "default", matchingRules[0].Namespace())
29+
assert.Len(t, matchingRules[0].Patterns(), 0)
2230
}
2331

2432
func TestScanner2(t *testing.T) {

0 commit comments

Comments
 (0)