Skip to content
18 changes: 18 additions & 0 deletions triedb/pathdb/layertree.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,21 @@ func (tree *layerTree) lookupStorage(accountHash common.Hash, slotHash common.Ha
}
return l, nil
}

// lookupNode resolves the layer from which the trie node identified by
// (accountHash, path) must be read when viewed from the given state root.
//
// An error wrapping errSnapshotStale is returned if the lookup reports no
// usable layer for the state, and a plain error is returned if the resolved
// layer is not tracked by the tree.
func (tree *layerTree) lookupNode(accountHash common.Hash, path string, state common.Hash) (layer, error) {
	// The read lock keeps the layer set stable for the whole resolution.
	tree.lock.RLock()
	defer tree.lock.RUnlock()

	var empty common.Hash
	root := tree.lookup.nodeTip(accountHash, path, state, tree.base.root)
	if root == empty {
		return nil, fmt.Errorf("[%#x] %w", state, errSnapshotStale)
	}
	if target := tree.layers[root]; target != nil {
		return target, nil
	}
	return nil, fmt.Errorf("triedb layer [%#x] missing", root)
}
271 changes: 271 additions & 0 deletions triedb/pathdb/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ import (
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/trie/trienode"
"golang.org/x/sync/errgroup"
)

// trienodeShardCount is the number of shards the trie node lookup maps are
// split into. getNodeShardIndex reduces the first byte of a node path modulo
// this value to select the shard.
const trienodeShardCount = 16

// storageKey returns a key for uniquely identifying the storage slot.
func storageKey(accountHash common.Hash, slotHash common.Hash) [64]byte {
var key [64]byte
Expand All @@ -33,6 +38,23 @@ func storageKey(accountHash common.Hash, slotHash common.Hash) [64]byte {
return key
}

// trienodeKey uses a fixed-size byte array instead of string to avoid string
// allocations. The layout is:
//
//	[0:32]  account hash
//	[32]    path length in bytes
//	[33:97] path, zero-padded on the right
//
// The explicit length byte is what keeps keys unique: with zero padding alone,
// a path ending in 0x00 bytes would be indistinguishable from the same path
// without them (e.g. "" and "\x00").
type trienodeKey [97]byte

// makeTrienodeKey returns a key for uniquely identifying the trie node with
// the given path inside the storage trie of the given account.
//
// Paths are expected to be at most 64 bytes long; any excess bytes are not
// copied into the key and the length byte holds len(path) modulo 256.
func makeTrienodeKey(accountHash common.Hash, path string) trienodeKey {
	var key trienodeKey
	copy(key[:32], accountHash[:])
	key[32] = byte(len(path))
	copy(key[33:], path)
	return key
}

// shardTask identifies a single storage trie node (owning account hash plus
// node path) queued for processing in one lookup shard. Tasks are grouped per
// shard up front so that each shard map is then handled by one goroutine only.
type shardTask struct {
// accountHash is the hash of the account owning the storage trie.
accountHash common.Hash
// path is the path of the node within that storage trie.
path string
}

// lookup is an internal structure used to efficiently determine the layer in
// which a state entry resides.
type lookup struct {
Expand All @@ -48,11 +70,34 @@ type lookup struct {
// where the slot was modified, with the order from oldest to newest.
storages map[[64]byte][]common.Hash

// accountNodes represents the mutation history for specific account
// trie nodes, distributed across trienodeShardCount shards for efficiency.
// The key is the trie path of the node, and the value is a slice
// of **diff layer** IDs indicating where the node was modified,
// with the order from oldest to newest.
accountNodes [trienodeShardCount]map[string][]common.Hash

// storageNodes represents the mutation history for specific storage
// trie nodes, distributed across trienodeShardCount shards for efficiency.
// The key is the account address hash and the trie path of the node,
// the value is a slice of **diff layer** IDs indicating where the
// node was modified, with the order from oldest to newest.
storageNodes [trienodeShardCount]map[trienodeKey][]common.Hash

// descendant is the callback indicating whether the layer with
// given root is a descendant of the one specified by `ancestor`.
descendant func(state common.Hash, ancestor common.Hash) bool
}

// getNodeShardIndex maps a trie node path to the index of the shard that
// tracks it. The empty path always lands in shard 0; every other path is
// placed according to its first byte, reduced modulo trienodeShardCount.
func getNodeShardIndex(path string) int {
	if path == "" {
		return 0
	}
	first := path[0]
	return int(first) % trienodeShardCount
}

// newLookup initializes the lookup structure.
func newLookup(head layer, descendant func(state common.Hash, ancestor common.Hash) bool) *lookup {
var (
Expand All @@ -68,6 +113,12 @@ func newLookup(head layer, descendant func(state common.Hash, ancestor common.Ha
storages: make(map[[64]byte][]common.Hash),
descendant: descendant,
}
// Initialize every account and storage trie node shard
for i := 0; i < trienodeShardCount; i++ {
l.storageNodes[i] = make(map[trienodeKey][]common.Hash)
l.accountNodes[i] = make(map[string][]common.Hash)
}

// Apply the diff layers from bottom to top
for i := len(layers) - 1; i >= 0; i-- {
switch diff := layers[i].(type) {
Expand Down Expand Up @@ -161,6 +212,45 @@ func (l *lookup) storageTip(accountHash common.Hash, slotHash common.Hash, state
return common.Hash{}
}

// nodeTip walks the mutation history recorded for the trie node identified by
// (accountHash, path), newest entry first, and returns the first layer root
// that is either the requested stateID itself or an ancestor of it. A zero
// accountHash selects the account trie history; any other value selects the
// storage trie history of that account.
//
// When no recorded entry qualifies, there are two possibilities:
//
//   - the node was not modified between the disk layer and the requested
//     state, in which case the disk layer root (base) is returned;
//   - the requested state does not descend from the disk layer at all, i.e.
//     it is stale, in which case the zero hash is returned.
func (l *lookup) nodeTip(accountHash common.Hash, path string, stateID common.Hash, base common.Hash) common.Hash {
	// Sharding depends on the path only, for both account and storage nodes.
	shard := getNodeShardIndex(path)

	var history []common.Hash
	switch {
	case accountHash == (common.Hash{}):
		history = l.accountNodes[shard][path]
	default:
		history = l.storageNodes[shard][makeTrienodeKey(accountHash, path)]
	}
	// Scan from the newest mutation to the oldest. Entries that are ahead of
	// the requested state, or that live on another branch, are skipped.
	for n := len(history); n > 0; n-- {
		candidate := history[n-1]
		if candidate == stateID || l.descendant(stateID, candidate) {
			return candidate
		}
	}
	// Nothing in the diff layers applies. If the requested state is not built
	// on top of the current disk layer either, it is stale.
	if base != stateID && !l.descendant(stateID, base) {
		return common.Hash{}
	}
	return base
}

// addLayer traverses the state data retained in the specified diff layer and
// integrates it into the lookup set.
//
Expand All @@ -170,6 +260,7 @@ func (l *lookup) storageTip(accountHash common.Hash, slotHash common.Hash, state
func (l *lookup) addLayer(diff *diffLayer) {
defer func(now time.Time) {
lookupAddLayerTimer.UpdateSince(now)
log.Debug("PathDB lookup add layer", "id", diff.id, "block", diff.block, "elapsed", time.Since(now))
}(time.Now())

var (
Expand Down Expand Up @@ -204,6 +295,97 @@ func (l *lookup) addLayer(diff *diffLayer) {
}
}
}()

wg.Add(1)
go func() {
defer wg.Done()
l.addAccountNodes(state, diff.nodes.accountNodes)
}()

wg.Add(1)
go func() {
defer wg.Done()
l.addStorageNodes(state, diff.nodes.storageNodes)
}()

states := len(diff.states.accountData)
for _, slots := range diff.states.storageData {
states += len(slots)
}
lookupStateMeter.Mark(int64(states))

trienodes := len(diff.nodes.accountNodes)
for _, nodes := range diff.nodes.storageNodes {
trienodes += len(nodes)
}
lookupTrienodeMeter.Mark(int64(trienodes))

wg.Wait()
}

// addStorageNodes appends the given state root to the mutation history of
// every storage trie node contained in nodes (keyed by account hash, then by
// node path).
//
// The nodes are first grouped by lookup shard (see getNodeShardIndex) and each
// non-empty shard is then updated by its own goroutine, so every goroutine
// writes to a different shard map. The call returns once all of them are done.
func (l *lookup) addStorageNodes(state common.Hash, nodes map[common.Hash]map[string]*trienode.Node) {
	defer func(start time.Time) {
		lookupAddTrienodeLayerTimer.UpdateSince(start)
	}(time.Now())

	var (
		wg    sync.WaitGroup
		tasks = make([][]shardTask, trienodeShardCount)
	)
	// Partition the work by shard so that each shard map has a single writer.
	for accountHash, slots := range nodes {
		for path := range slots {
			shardIndex := getNodeShardIndex(path)
			tasks[shardIndex] = append(tasks[shardIndex], shardTask{
				accountHash: accountHash,
				path:        path,
			})
		}
	}
	for i := range tasks {
		batch := tasks[i]
		if len(batch) == 0 {
			continue
		}
		// Resolve the shard before spawning the goroutine. The closure must
		// not read the loop counter: with pre-1.22 loop semantics the counter
		// is shared by all iterations and may already have advanced.
		shard := l.storageNodes[i]

		wg.Add(1)
		go func() {
			defer wg.Done()
			for _, task := range batch {
				key := makeTrienodeKey(task.accountHash, task.path)
				shard[key] = append(shard[key], state)
			}
		}()
	}
	wg.Wait()
}

// addAccountNodes appends the given state root to the mutation history of
// every account trie node contained in nodes (keyed by node path).
//
// The paths are first grouped by lookup shard (see getNodeShardIndex) and each
// non-empty shard is then updated by its own goroutine, so every goroutine
// writes to a different shard map. The call returns once all of them are done.
func (l *lookup) addAccountNodes(state common.Hash, nodes map[string]*trienode.Node) {
	defer func(start time.Time) {
		lookupAddTrienodeLayerTimer.UpdateSince(start)
	}(time.Now())

	var (
		wg    sync.WaitGroup
		tasks = make([][]string, trienodeShardCount)
	)
	// Partition the work by shard so that each shard map has a single writer.
	for path := range nodes {
		shardIndex := getNodeShardIndex(path)
		tasks[shardIndex] = append(tasks[shardIndex], path)
	}
	for i := range tasks {
		batch := tasks[i]
		if len(batch) == 0 {
			continue
		}
		// Resolve the shard before spawning the goroutine. The closure must
		// not read the loop counter: with pre-1.22 loop semantics the counter
		// is shared by all iterations and may already have advanced.
		shard := l.accountNodes[i]

		wg.Add(1)
		go func() {
			defer wg.Done()
			for _, path := range batch {
				shard[path] = append(shard[path], state)
			}
		}()
	}
	wg.Wait()
}

Expand Down Expand Up @@ -236,6 +418,7 @@ func removeFromList(list []common.Hash, element common.Hash) (bool, []common.Has
func (l *lookup) removeLayer(diff *diffLayer) error {
defer func(now time.Time) {
lookupRemoveLayerTimer.UpdateSince(now)
log.Debug("PathDB lookup remove layer", "id", diff.id, "block", diff.block, "elapsed", time.Since(now))
}(time.Now())

var (
Expand Down Expand Up @@ -274,5 +457,93 @@ func (l *lookup) removeLayer(diff *diffLayer) error {
}
return nil
})

eg.Go(func() error {
return l.removeAccountNodes(state, diff.nodes.accountNodes)
})

eg.Go(func() error {
return l.removeStorageNodes(state, diff.nodes.storageNodes)
})
return eg.Wait()
}

// removeStorageNodes removes the given state root from the mutation history
// of every storage trie node contained in nodes (keyed by account hash, then
// by node path). A history that becomes empty is deleted from its shard.
//
// The nodes are grouped by lookup shard (see getNodeShardIndex) and each
// non-empty shard is processed by its own goroutine. An error is returned if
// the state is missing from the history of any node; entries processed before
// the failure are not restored.
func (l *lookup) removeStorageNodes(state common.Hash, nodes map[common.Hash]map[string]*trienode.Node) error {
	defer func(start time.Time) {
		lookupRemoveTrienodeLayerTimer.UpdateSince(start)
	}(time.Now())

	var (
		eg    errgroup.Group
		tasks = make([][]shardTask, trienodeShardCount)
	)
	// Partition the work by shard so that each shard map has a single writer.
	for accountHash, slots := range nodes {
		for path := range slots {
			shardIndex := getNodeShardIndex(path)
			tasks[shardIndex] = append(tasks[shardIndex], shardTask{
				accountHash: accountHash,
				path:        path,
			})
		}
	}
	for i := range tasks {
		batch := tasks[i]
		if len(batch) == 0 {
			continue
		}
		// Resolve the shard before spawning the goroutine. The closure must
		// not read the loop counter: with pre-1.22 loop semantics the counter
		// is shared by all iterations and may already have advanced.
		shard := l.storageNodes[i]

		eg.Go(func() error {
			for _, task := range batch {
				key := makeTrienodeKey(task.accountHash, task.path)
				found, list := removeFromList(shard[key], state)
				if !found {
					return fmt.Errorf("storage lookup is not found, key: %x, state: %x", key, state)
				}
				if len(list) == 0 {
					delete(shard, key)
					continue
				}
				shard[key] = list
			}
			return nil
		})
	}
	return eg.Wait()
}

// removeAccountNodes removes the given state root from the mutation history
// of every account trie node contained in nodes (keyed by node path). A
// history that becomes empty is deleted from its shard.
//
// The paths are grouped by lookup shard (see getNodeShardIndex) and each
// non-empty shard is processed by its own goroutine. An error is returned if
// the state is missing from the history of any node; entries processed before
// the failure are not restored.
func (l *lookup) removeAccountNodes(state common.Hash, nodes map[string]*trienode.Node) error {
	defer func(start time.Time) {
		lookupRemoveTrienodeLayerTimer.UpdateSince(start)
	}(time.Now())

	var (
		eg    errgroup.Group
		tasks = make([][]string, trienodeShardCount)
	)
	// Partition the work by shard so that each shard map has a single writer.
	for path := range nodes {
		shardIndex := getNodeShardIndex(path)
		tasks[shardIndex] = append(tasks[shardIndex], path)
	}
	for i := range tasks {
		batch := tasks[i]
		if len(batch) == 0 {
			continue
		}
		// Resolve the shard before spawning the goroutine. The closure must
		// not read the loop counter: with pre-1.22 loop semantics the counter
		// is shared by all iterations and may already have advanced.
		shard := l.accountNodes[i]

		eg.Go(func() error {
			for _, path := range batch {
				found, list := removeFromList(shard[path], state)
				if !found {
					return fmt.Errorf("account lookup is not found, %x, state: %x", path, state)
				}
				if len(list) == 0 {
					delete(shard, path)
					continue
				}
				shard[path] = list
			}
			return nil
		})
	}
	return eg.Wait()
}
Loading
Loading