Skip to content

Commit bd7ffe0

Browse files
committed
rework interface to handle empty trees + single leaves
1 parent a081d37 commit bd7ffe0

File tree

3 files changed

+173
-38
lines changed

3 files changed

+173
-38
lines changed

trie/binary.go

Lines changed: 167 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,30 +40,90 @@ type (
4040

4141
type BinaryNode interface {
4242
Get([]byte, NodeResolverFn) ([]byte, error)
43-
Insert([]byte, []byte, NodeResolverFn) error
43+
Insert([]byte, []byte, NodeResolverFn) (BinaryNode, error)
4444
Commit() common.Hash
4545
Copy() BinaryNode
4646
Hash() common.Hash
4747
GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error)
48-
InsertValuesAtStem([]byte, [][]byte, NodeResolverFn) error
48+
InsertValuesAtStem([]byte, [][]byte, NodeResolverFn) (BinaryNode, error)
4949
CollectNodes([]byte, NodeFlushFn) error
5050

5151
toDot(parent, path string) string
5252
GetHeight() int
5353
}
5454

55-
var (
56-
errInsertingIntoHash = errors.New("cannot insert into hashed node")
57-
)
55+
type Empty struct{}
56+
57+
func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
58+
return nil, nil
59+
}
60+
61+
func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn) (BinaryNode, error) {
62+
var values [256][]byte
63+
values[key[31]] = value
64+
return &StemNode{
65+
Stem: append([]byte(nil), key[:31]...),
66+
Values: values[:],
67+
}, nil
68+
}
69+
70+
func (e Empty) Commit() common.Hash {
71+
return common.Hash{}
72+
}
73+
74+
func (e Empty) Copy() BinaryNode {
75+
return Empty{}
76+
}
77+
78+
func (e Empty) Hash() common.Hash {
79+
return common.Hash{}
80+
}
81+
82+
func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
83+
var values [256][]byte
84+
return values[:], nil
85+
}
86+
87+
func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn) (BinaryNode, error) {
88+
return &StemNode{
89+
Stem: append([]byte(nil), key[:31]...),
90+
Values: values,
91+
}, nil
92+
}
93+
94+
func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error {
95+
panic("not implemented") // TODO: Implement
96+
}
97+
98+
func (e Empty) toDot(parent string, path string) string {
99+
panic("not implemented") // TODO: Implement
100+
}
101+
102+
func (e Empty) GetHeight() int {
103+
return 0
104+
}
58105

59106
type HashedNode common.Hash
60107

61108
func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
62109
panic("not implemented") // TODO: Implement
63110
}
64111

65-
func (h HashedNode) Insert(_ []byte, _ []byte, _ NodeResolverFn) error {
66-
return errInsertingIntoHash
112+
func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn) (BinaryNode, error) {
113+
if resolver == nil {
114+
return h, errors.New("resolver is nil")
115+
}
116+
117+
resolved, err := resolver(h[:])
118+
if err != nil {
119+
return nil, fmt.Errorf("insert error: %w", err)
120+
}
121+
node, err := DeserializeNode(resolved, 0)
122+
if err != nil {
123+
return nil, fmt.Errorf("insert node deserialization error: %w", err)
124+
}
125+
126+
return node.Insert(key, value, resolver)
67127
}
68128

69129
func (h HashedNode) Commit() common.Hash {
@@ -82,8 +142,21 @@ func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error
82142
panic("not implemented") // TODO: Implement
83143
}
84144

85-
func (h HashedNode) InsertValuesAtStem(_ []byte, _ [][]byte, _ NodeResolverFn) error {
86-
return errInsertingIntoHash
145+
func (h HashedNode) InsertValuesAtStem(key []byte, values [][]byte, resolver NodeResolverFn) (BinaryNode, error) {
146+
if resolver == nil {
147+
return h, errors.New("resolver is nil")
148+
}
149+
150+
resolved, err := resolver(h[:])
151+
if err != nil {
152+
return nil, fmt.Errorf("insert error: %w", err)
153+
}
154+
node, err := DeserializeNode(resolved, 0)
155+
if err != nil {
156+
return nil, fmt.Errorf("insert node deserialization error: %w", err)
157+
}
158+
159+
return node.InsertValuesAtStem(key, values, resolver)
87160
}
88161

89162
func (h HashedNode) toDot(parent string, path string) string {
@@ -107,16 +180,50 @@ func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) {
107180
panic("this should not be called directly")
108181
}
109182

110-
func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn) error {
183+
func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn) (BinaryNode, error) {
111184
if !bytes.Equal(bt.Stem, key[:31]) {
112-
return errors.New("invalid insertion: stem mismatch")
185+
// look for the first bit that differs
186+
// TODO maintaining a depth field would save some work
187+
for depth := 0; depth < 31*8; depth++ {
188+
bitStem := bt.Stem[depth/8] >> (7 - (depth % 8)) & 1
189+
190+
new := &InternalNode{}
191+
var child, other *BinaryNode
192+
if bitStem == 0 {
193+
new.left = bt
194+
child = &new.left
195+
other = &new.right
196+
} else {
197+
new.right = bt
198+
child = &new.right
199+
other = &new.left
200+
}
201+
202+
bitKey := key[depth/8] >> (7 - (depth % 8)) & 1
203+
if bitKey == bitStem {
204+
var err error
205+
*child, err = (*child).Insert(key, value, nil)
206+
if err != nil {
207+
return new, fmt.Errorf("insert error: %w", err)
208+
}
209+
} else {
210+
var values [256][]byte
211+
values[key[31]] = value
212+
*other = &StemNode{
213+
Stem: append([]byte(nil), key[:31]...),
214+
Values: values[:],
215+
}
216+
}
217+
218+
return new, nil
219+
}
113220
}
114221
if len(value) != 32 {
115-
return errors.New("invalid insertion: value length")
222+
return bt, errors.New("invalid insertion: value length")
116223
}
117224

118225
bt.Values[key[31]] = value
119-
return nil
226+
return bt, nil
120227
}
121228

122229
func (bt *StemNode) Commit() common.Hash {
@@ -178,13 +285,18 @@ func (bt *StemNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error
178285
return bt.Values[:], nil
179286
}
180287

181-
func (bt *StemNode) InsertValuesAtStem(_ []byte, values [][]byte, _ NodeResolverFn) error {
288+
func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn) (BinaryNode, error) {
289+
if !bytes.Equal(bt.Stem, key[:31]) {
290+
return &InternalNode{}, nil
291+
}
292+
293+
// same stem, just merge the two value lists
182294
for i, v := range values {
183295
if v != nil {
184296
bt.Values[i] = v
185297
}
186298
}
187-
return nil
299+
return bt, nil
188300
}
189301

190302
func (bt *StemNode) toDot(parent, path string) string {
@@ -213,7 +325,7 @@ type InternalNode struct {
213325
}
214326

215327
func NewBinaryNode() BinaryNode {
216-
return &InternalNode{}
328+
return Empty{}
217329
}
218330

219331
func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
@@ -251,7 +363,7 @@ func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error)
251363
return values[key[31]], nil
252364
}
253365

254-
func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn) error {
366+
func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn) (BinaryNode, error) {
255367
var values [256][]byte
256368
values[key[31]] = value
257369
return bt.InsertValuesAtStem(key[:31], values[:], resolver)
@@ -288,24 +400,29 @@ func (bt *InternalNode) Hash() common.Hash {
288400
return common.BytesToHash(h.Sum(nil))
289401
}
290402

291-
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn) error {
403+
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn) (BinaryNode, error) {
292404
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
293-
var child *BinaryNode
405+
var (
406+
child *BinaryNode
407+
err error
408+
)
294409
if bit == 0 {
295410
child = &bt.left
296411
} else {
297412
child = &bt.right
298413
}
299414

300-
if *child == nil {
301-
*child = &StemNode{
302-
Stem: append([]byte(nil), stem[:31]...),
303-
Values: values,
304-
}
305-
return nil
306-
}
415+
// if *child == nil {
416+
// *child = &StemNode{
417+
// Stem: append([]byte(nil), stem[:31]...),
418+
// Values: values,
419+
// }
420+
// return bt, nil
421+
// }
307422
// XXX il faut vérifier si c'est un stemnode et aussi faire le resolve
308-
return (*child).InsertValuesAtStem(stem, values, resolver)
423+
424+
*child, err = (*child).InsertValuesAtStem(stem, values, resolver)
425+
return bt, err
309426
}
310427

311428
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
@@ -540,7 +657,6 @@ var zero [32]byte
540657

541658
func (t *VerkleTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
542659
var (
543-
err error
544660
basicData [32]byte
545661
values = make([][]byte, verkle.NodeWidth)
546662
stem = t.pointCache.GetTreeKeyBasicDataCached(addr[:])
@@ -563,21 +679,27 @@ func (t *VerkleTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
563679

564680
switch root := t.root.(type) {
565681
case *InternalNode:
566-
err = root.InsertValuesAtStem(stem, values, t.FlatdbNodeResolver)
682+
r, err := root.InsertValuesAtStem(stem, values, t.FlatdbNodeResolver)
683+
if err != nil {
684+
return fmt.Errorf("UpdateAccount (%x) error: %v", addr, err)
685+
}
686+
t.root = r
567687
default:
568688
return errInvalidRootType
569689
}
570-
if err != nil {
571-
return fmt.Errorf("UpdateAccount (%x) error: %v", addr, err)
572-
}
573690

574691
return nil
575692
}
576693

577694
func (trie *VerkleTrie) UpdateStem(key []byte, values [][]byte) error {
578695
switch root := trie.root.(type) {
579696
case *InternalNode:
580-
return root.InsertValuesAtStem(key, values, trie.FlatdbNodeResolver)
697+
r, err := root.InsertValuesAtStem(key, values, trie.FlatdbNodeResolver)
698+
if err != nil {
699+
return fmt.Errorf("UpdateStem (%x) error: %v", key, err)
700+
}
701+
trie.root = r
702+
return nil
581703
default:
582704
panic("invalid root type")
583705
}
@@ -595,7 +717,12 @@ func (trie *VerkleTrie) UpdateStorage(address common.Address, key, value []byte)
595717
} else {
596718
copy(v[32-len(value):], value[:])
597719
}
598-
return trie.root.Insert(k, v[:], trie.FlatdbNodeResolver)
720+
root, err := trie.root.Insert(k, v[:], trie.FlatdbNodeResolver)
721+
if err != nil {
722+
return fmt.Errorf("UpdateStorage (%x) error: %v", address, err)
723+
}
724+
trie.root = root
725+
return nil
599726
}
600727

601728
func (t *VerkleTrie) DeleteAccount(addr common.Address) error {
@@ -608,7 +735,12 @@ func (trie *VerkleTrie) DeleteStorage(addr common.Address, key []byte) error {
608735
pointEval := trie.pointCache.GetTreeKeyHeader(addr[:])
609736
k := utils.GetTreeKeyStorageSlotWithEvaluatedAddress(pointEval, key)
610737
var zero [32]byte
611-
return trie.root.Insert(k, zero[:], trie.FlatdbNodeResolver)
738+
root, err := trie.root.Insert(k, zero[:], trie.FlatdbNodeResolver)
739+
if err != nil {
740+
return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err)
741+
}
742+
trie.root = root
743+
return nil
612744
}
613745

614746
// Hash returns the root hash of the trie. It does not write to the database and

trie/binary_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ var (
3434

3535
func TestSingleEntry(t *testing.T) {
3636
tree := NewBinaryNode()
37-
tree.Insert(zeroKey[:], oneKey[:], nil)
38-
if tree.GetHeight() == 1 {
37+
tree, err := tree.Insert(zeroKey[:], oneKey[:], nil)
38+
if err != nil {
39+
t.Fatal(err)
40+
}
41+
if tree.GetHeight() != 1 {
3942
t.Fatal("invalid depth")
4043
}
4144
expected := common.HexToHash("694545468677064fd833cddc8455762fe6b21c6cabe2fc172529e0f573181cd5")

trie/transition.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ func (t *TransitionTrie) IsVerkle() bool {
175175
return true
176176
}
177177

178-
func (t *TransitionTrie) UpdateStem(key []byte, values [][]byte) error {
178+
func (t *TransitionTrie) UpdateStem(key []byte, values [][]byte) (BinaryNode, error) {
179179
trie := t.overlay
180180
switch root := trie.root.(type) {
181181
case *InternalNode:

0 commit comments

Comments
 (0)