@@ -23,6 +23,7 @@ import (
2323
2424 "github.com/ethereum/go-ethereum/common"
2525 "github.com/ethereum/go-ethereum/log"
26+ "github.com/ethereum/go-ethereum/trie/trienode"
2627 "golang.org/x/sync/errgroup"
2728)
2829
@@ -275,50 +276,74 @@ func (l *lookup) addLayer(diff *diffLayer) {
275276 wg .Add (1 )
276277 go func () {
277278 defer wg .Done ()
279+ l .addStorageNodes (state , diff .nodes .storageNodes )
280+ }()
278281
279- // Use concurrent workers for storage nodes updates, one per shard
280- var storageWg sync.WaitGroup
281- storageWg .Add (storageNodesShardCount )
282-
283- workChannels := make ([]chan string , storageNodesShardCount )
284- for i := 0 ; i < storageNodesShardCount ; i ++ {
285- workChannels [i ] = make (chan string , 10 ) // Buffer to avoid blocking
286- }
282+ wg .Wait ()
283+ }
287284
288- // Start all workers, each handling its own shard
289- for shardIndex := 0 ; shardIndex < storageNodesShardCount ; shardIndex ++ {
290- go func (shardIdx int ) {
291- defer storageWg .Done ()
292-
293- shard := l .storageNodes [shardIdx ]
294- for key := range workChannels [shardIdx ] {
295- list , exists := shard [key ]
296- if ! exists {
297- list = make ([]common.Hash , 0 , 16 ) // TODO(rjl493456442) use sync pool
298- }
299- list = append (list , state )
300- shard [key ] = list
301- }
302- }(shardIndex )
303- }
285+ func (l * lookup ) addStorageNodes (state common.Hash , nodes map [common.Hash ]map [string ]* trienode.Node ) {
286+ count := 0
287+ for _ , slots := range nodes {
288+ count += len (slots )
289+ }
304290
305- // Distribute work to workers based on shard index
306- for accountHash , slots := range diff .nodes .storageNodes {
291+ // If the number of storage nodes is small, use a single-threaded approach
292+ if count <= 1000 {
293+ for accountHash , slots := range nodes {
307294 accountHex := accountHash .Hex ()
308295 for path := range slots {
296+ key := accountHex + path
309297 shardIndex := getStorageShardIndex (path )
310- workChannels [shardIndex ] <- accountHex + path
298+ list , exists := l.storageNodes [shardIndex ][key ]
299+ if ! exists {
300+ list = make ([]common.Hash , 0 , 16 )
301+ }
302+ list = append (list , state )
303+ l.storageNodes [shardIndex ][key ] = list
311304 }
312305 }
306+ return
307+ }
308+
309+ // Use concurrent workers for storage nodes updates, one per shard
310+ var wg sync.WaitGroup
311+ wg .Add (storageNodesShardCount )
313312
314- // Close all channels to signal workers to finish
315- for i := 0 ; i < storageNodesShardCount ; i ++ {
316- close (workChannels [i ])
313+ workChannels := make ([]chan string , storageNodesShardCount )
314+ for i := 0 ; i < storageNodesShardCount ; i ++ {
315+ workChannels [i ] = make (chan string , 10 ) // Buffer to avoid blocking
316+ }
317+
318+ // Start all workers, each handling its own shard
319+ for shardIndex := 0 ; shardIndex < storageNodesShardCount ; shardIndex ++ {
320+ go func (shardIdx int ) {
321+ defer wg .Done ()
322+
323+ shard := l .storageNodes [shardIdx ]
324+ for key := range workChannels [shardIdx ] {
325+ list , exists := shard [key ]
326+ if ! exists {
327+ list = make ([]common.Hash , 0 , 16 ) // TODO(rjl493456442) use sync pool
328+ }
329+ list = append (list , state )
330+ shard [key ] = list
331+ }
332+ }(shardIndex )
333+ }
334+
335+ for accountHash , slots := range nodes {
336+ accountHex := accountHash .Hex ()
337+ for path := range slots {
338+ shardIndex := getStorageShardIndex (path )
339+ workChannels [shardIndex ] <- accountHex + path
317340 }
341+ }
318342
319- // Wait for all storage workers to complete
320- storageWg .Wait ()
321- }()
343+ // Close all channels to signal workers to finish
344+ for i := 0 ; i < storageNodesShardCount ; i ++ {
345+ close (workChannels [i ])
346+ }
322347
323348 wg .Wait ()
324349}
@@ -408,7 +433,19 @@ func (l *lookup) removeLayer(diff *diffLayer) error {
408433 })
409434
410435 eg .Go (func () error {
411- for accountHash , slots := range diff .nodes .storageNodes {
436+ return l .removeStorageNodes (state , diff .nodes .storageNodes )
437+ })
438+ return eg .Wait ()
439+ }
440+
441+ func (l * lookup ) removeStorageNodes (state common.Hash , nodes map [common.Hash ]map [string ]* trienode.Node ) error {
442+ count := 0
443+ for _ , slots := range nodes {
444+ count += len (slots )
445+ }
446+
447+ if count <= 1000 {
448+ for accountHash , slots := range nodes {
412449 accountHex := accountHash .Hex ()
413450 for path := range slots {
414451 // Construct the combined key and find the correct shard
@@ -426,6 +463,50 @@ func (l *lookup) removeLayer(diff *diffLayer) error {
426463 }
427464 }
428465 return nil
429- })
466+ }
467+
468+ // Use concurrent workers for storage nodes removal, one per shard
469+ var eg errgroup.Group
470+
471+ // Create work channels for each shard
472+ workChannels := make ([]chan string , storageNodesShardCount )
473+
474+ for i := 0 ; i < storageNodesShardCount ; i ++ {
475+ workChannels [i ] = make (chan string , 10 ) // Buffer to avoid blocking
476+ }
477+
478+ // Start all workers, each handling its own shard
479+ for shardIndex := 0 ; shardIndex < storageNodesShardCount ; shardIndex ++ {
480+ shardIdx := shardIndex // Capture the variable
481+ eg .Go (func () error {
482+ shard := l .storageNodes [shardIdx ]
483+ for key := range workChannels [shardIdx ] {
484+ found , list := removeFromList (shard [key ], state )
485+ if ! found {
486+ return fmt .Errorf ("storage lookup is not found, key: %s, state: %x" , key , state )
487+ }
488+ if len (list ) != 0 {
489+ shard [key ] = list
490+ } else {
491+ delete (shard , key )
492+ }
493+ }
494+ return nil
495+ })
496+ }
497+
498+ for accountHash , slots := range nodes {
499+ accountHex := accountHash .Hex ()
500+ for path := range slots {
501+ key := accountHex + path
502+ shardIndex := getStorageShardIndex (path )
503+ workChannels [shardIndex ] <- key
504+ }
505+ }
506+
507+ for i := 0 ; i < storageNodesShardCount ; i ++ {
508+ close (workChannels [i ])
509+ }
510+
430511 return eg .Wait ()
431512}
0 commit comments