diff --git a/consensus/dummy/consensus.go b/consensus/dummy/consensus.go index 3f6d861c8b..1af811f2da 100644 --- a/consensus/dummy/consensus.go +++ b/consensus/dummy/consensus.go @@ -34,10 +34,39 @@ type Mode struct { ModeSkipCoinbase bool } -type DummyEngine struct { - consensusMode Mode - desiredDelayExcess *acp226.DelayExcess -} +type ( + OnFinalizeAndAssembleCallbackType = func( + header *types.Header, + parent *types.Header, + state *state.StateDB, + txs []*types.Transaction, + ) ( + extraData []byte, + blockFeeContribution *big.Int, + extDataGasUsed *big.Int, + err error, + ) + + OnExtraStateChangeType = func( + block *types.Block, + parent *types.Header, + statedb *state.StateDB, + ) ( + blockFeeContribution *big.Int, + extDataGasUsed *big.Int, + err error, + ) + + ConsensusCallbacks struct { + OnFinalizeAndAssemble OnFinalizeAndAssembleCallbackType + OnExtraStateChange OnExtraStateChangeType + } + + DummyEngine struct { + consensusMode Mode + desiredDelayExcess *acp226.DelayExcess + } +) func NewDummyEngine( mode Mode, diff --git a/plugin/evm/block_test.go b/plugin/evm/block_test.go index 0c8bebcf25..6b0a0e2465 100644 --- a/plugin/evm/block_test.go +++ b/plugin/evm/block_test.go @@ -16,6 +16,7 @@ import ( "github.com/ava-labs/subnet-evm/params" "github.com/ava-labs/subnet-evm/params/extras" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/precompile/precompileconfig" ) @@ -26,8 +27,9 @@ func TestHandlePrecompileAccept(t *testing.T) { db := rawdb.NewMemoryDatabase() vm := &VM{ - chaindb: db, - chainConfig: params.TestChainConfig, + chaindb: db, + chainConfig: params.TestChainConfig, + extensionConfig: &extension.Config{}, } precompileAddr := common.Address{0x05} diff --git a/plugin/evm/config/config.go b/plugin/evm/config/config.go index c10f0d2232..a37ac973bd 100644 --- a/plugin/evm/config/config.go +++ b/plugin/evm/config/config.go @@ -236,7 +236,7 @@ func (d Duration) MarshalJSON() ([]byte, error) { } // validate returns an error if this is an invalid config. 
-func (c *Config) validate(_ uint32) error { +func (c *Config) validate(uint32) error { if c.PopulateMissingTries != nil && (c.OfflinePruning || c.Pruning) { return fmt.Errorf("cannot enable populate missing tries while offline pruning (enabled: %t)/pruning (enabled: %t) are enabled", c.OfflinePruning, c.Pruning) } diff --git a/plugin/evm/extension/config.go b/plugin/evm/extension/config.go index 544d13a214..83f07500ac 100644 --- a/plugin/evm/extension/config.go +++ b/plugin/evm/extension/config.go @@ -4,13 +4,30 @@ package extension import ( + "context" "errors" + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/versiondb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p" + "github.com/ava-labs/avalanchego/snow/consensus/snowman" + "github.com/ava-labs/avalanchego/snow/engine/snowman/block" "github.com/ava-labs/avalanchego/utils/timer/mockable" + "github.com/ava-labs/libevm/common" + "github.com/ava-labs/libevm/core/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/ava-labs/subnet-evm/consensus/dummy" + "github.com/ava-labs/subnet-evm/core" + "github.com/ava-labs/subnet-evm/params" + "github.com/ava-labs/subnet-evm/params/extras" + "github.com/ava-labs/subnet-evm/plugin/evm/config" "github.com/ava-labs/subnet-evm/plugin/evm/message" "github.com/ava-labs/subnet-evm/plugin/evm/sync" "github.com/ava-labs/subnet-evm/sync/handlers" + + avalanchecommon "github.com/ava-labs/avalanchego/snow/engine/common" ) var ( @@ -20,6 +37,85 @@ var ( errNilClock = errors.New("nil clock") ) +type ExtensibleVM interface { + // SetExtensionConfig sets the configuration for the VM extension. + // It must be called before any other method and only once. + SetExtensionConfig(config *Config) error + // NewClient returns a client used to send messages for the given protocol + NewClient(protocol uint64) *p2p.Client + // AddHandler registers a server handler for an application protocol + AddHandler(protocol uint64, handler p2p.Handler) error + // GetExtendedBlock returns the ExtendedBlock for the given ID or an error if the block is not found + GetExtendedBlock(context.Context, ids.ID) (ExtendedBlock, error) + // LastAcceptedExtendedBlock returns the last accepted ExtendedBlock + LastAcceptedExtendedBlock() ExtendedBlock + // ChainConfig returns the chain config for the VM + ChainConfig() *params.ChainConfig + // P2PValidators returns the validators for the network + P2PValidators() *p2p.Validators + // Blockchain returns the blockchain client + Blockchain() *core.BlockChain + // Config returns the configuration for the VM + Config() config.Config + // MetricRegistry returns the metric registry for the VM + MetricRegistry() *prometheus.Registry + // ReadLastAccepted returns the last accepted block hash and height + ReadLastAccepted() (common.Hash, uint64, error) + // VersionDB returns the versioned database for the VM + VersionDB() *versiondb.Database +} + +// InnerVM is the interface that must be implemented by the VM +// that's being wrapped by the extension +type InnerVM interface { + ExtensibleVM + avalanchecommon.VM + block.ChainVM + block.BuildBlockWithContextChainVM + block.StateSyncableVM +} + +// ExtendedBlock is a block that can be used by the extension +type ExtendedBlock interface { + snowman.Block + GetEthBlock() *types.Block + GetBlockExtension() BlockExtension +} + +type BlockExtender interface { + // NewBlockExtension is called when a new block is created + NewBlockExtension(b ExtendedBlock) (BlockExtension, error) +} +
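For illustration, a minimal sketch of how a wrapping VM might satisfy the `BlockExtender` hook, returning an implementation of the `BlockExtension` interface defined just below. The `noopExtender`/`noopExtension` names and their do-nothing behavior are hypothetical, not part of this change:

```go
package myvm // hypothetical wrapping-VM package

import (
	"github.com/ava-labs/avalanchego/database"

	"github.com/ava-labs/subnet-evm/params/extras"
	"github.com/ava-labs/subnet-evm/plugin/evm/extension"
)

var (
	_ extension.BlockExtender  = (*noopExtender)(nil)
	_ extension.BlockExtension = (*noopExtension)(nil)
)

// noopExtender hands out a fresh extension for each newly created block.
type noopExtender struct{}

func (*noopExtender) NewBlockExtension(extension.ExtendedBlock) (extension.BlockExtension, error) {
	return &noopExtension{}, nil
}

// noopExtension accepts every block and performs no extra verification; a
// real extension would validate and persist chain-specific block data here.
type noopExtension struct{}

func (*noopExtension) SyntacticVerify(extras.Rules) error { return nil }
func (*noopExtension) SemanticVerify() error              { return nil }
func (*noopExtension) CleanupVerified()                   {}

// Accept flushes the accepted batch, as the interface contract requires.
func (*noopExtension) Accept(acceptedBatch database.Batch) error { return acceptedBatch.Write() }
func (*noopExtension) Reject() error                             { return nil }
```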
+// BlockExtension allows the VM extension to handle block processing events. +type BlockExtension interface { + // SyntacticVerify verifies the block syntactically. + // It can be implemented to extend inner block verification. + SyntacticVerify(rules extras.Rules) error + // SemanticVerify verifies the block semantically. + // It can be implemented to extend inner block verification. + SemanticVerify() error + // CleanupVerified is called when a block has passed SemanticVerify and SyntacticVerify, + // but must be cleaned up, either because an error occurred or because verification ran in + // non-write mode. It does not return an error because the block has already been verified. + CleanupVerified() + // Accept is called when a block is accepted by the block manager. Accept takes a + // database.Batch that contains the changes that were made to the database as a result + // of accepting the block. The changes in the batch should be flushed to the database in this method. + Accept(acceptedBatch database.Batch) error + // Reject is called when a block is rejected by the block manager + Reject() error +} + +// BuilderMempool is a mempool that's used in the block builder +type BuilderMempool interface { + // PendingLen returns the number of pending transactions + // that are waiting to be included in a block + PendingLen() int + // SubscribePendingTxs returns a channel that's signaled when there are pending transactions + SubscribePendingTxs() <-chan struct{} +} + // LeafRequestConfig is the configuration to handle leaf requests // in the network and syncer type LeafRequestConfig struct { @@ -33,13 +129,29 @@ type LeafRequestConfig struct { // Config is the configuration for the VM extension type Config struct { + // ConsensusCallbacks are the consensus callbacks the VM + // passes to the consensus engine. + // Callback functions can be nil. + ConsensusCallbacks dummy.ConsensusCallbacks // SyncSummaryProvider is the sync summary provider to use // for the VM to be used in syncer. // It's required and should be non-nil SyncSummaryProvider sync.SummaryProvider + // SyncExtender can extend the syncer to handle custom sync logic. + // It's optional and can be nil + SyncExtender sync.Extender // SyncableParser is to parse summary messages from the network. // It's required and should be non-nil SyncableParser message.SyncableParser + // BlockExtender allows the VM extension to create an extension to handle block processing events. + // It's optional and can be nil + BlockExtender BlockExtender + // ExtraSyncLeafHandlerConfig is the extra configuration to handle leaf requests + // in the network and syncer. It's optional and can be nil + ExtraSyncLeafHandlerConfig *LeafRequestConfig + // ExtraMempool is the mempool to be used in the block builder. + // It's optional and can be nil + ExtraMempool BuilderMempool // Clock is the clock to use for time related operations. // It's optional and can be nil Clock *mockable.Clock diff --git a/plugin/evm/message/block_sync_summary_test.go b/plugin/evm/message/block_sync_summary_test.go new file mode 100644 index 0000000000..40aa926d7c --- /dev/null +++ b/plugin/evm/message/block_sync_summary_test.go @@ -0,0 +1,44 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms.
+ +package message + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/ava-labs/avalanchego/snow/engine/snowman/block" + "github.com/ava-labs/libevm/common" + "github.com/stretchr/testify/require" +) + +func TestMarshalBlockSyncSummary(t *testing.T) { + blockSyncSummary, err := NewBlockSyncSummary(common.Hash{1}, 2, common.Hash{3}) + require.NoError(t, err) + + require.Equal(t, common.Hash{1}, blockSyncSummary.GetBlockHash()) + require.Equal(t, uint64(2), blockSyncSummary.Height()) + require.Equal(t, common.Hash{3}, blockSyncSummary.GetBlockRoot()) + + expectedBase64Bytes := "AAAAAAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + require.Equal(t, expectedBase64Bytes, base64.StdEncoding.EncodeToString(blockSyncSummary.Bytes())) + + parser := NewBlockSyncSummaryParser() + called := false + acceptImplTest := func(Syncable) (block.StateSyncMode, error) { + called = true + return block.StateSyncSkipped, nil + } + s, err := parser.Parse(blockSyncSummary.Bytes(), acceptImplTest) + require.NoError(t, err) + require.Equal(t, blockSyncSummary.GetBlockHash(), s.GetBlockHash()) + require.Equal(t, blockSyncSummary.Height(), s.Height()) + require.Equal(t, blockSyncSummary.GetBlockRoot(), s.GetBlockRoot()) + require.Equal(t, blockSyncSummary.Bytes(), s.Bytes()) + + mode, err := s.Accept(context.TODO()) + require.NoError(t, err) + require.Equal(t, block.StateSyncSkipped, mode) + require.True(t, called) +} diff --git a/plugin/evm/message/codec.go b/plugin/evm/message/codec.go index 59a2632582..ba71af2f32 100644 --- a/plugin/evm/message/codec.go +++ b/plugin/evm/message/codec.go @@ -22,13 +22,11 @@ func init() { c := linearcodec.NewDefault() // Skip registration to keep registeredTypes unchanged after legacy gossip deprecation - c.SkipRegistrations(1) + // Gossip types and sync summary type removed from codec + c.SkipRegistrations(2) errs := wrappers.Errs{} errs.Add( - // Types for state sync frontier consensus - c.RegisterType(BlockSyncSummary{}), - // state sync types c.RegisterType(BlockRequest{}), c.RegisterType(BlockResponse{}), diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go index dd655bcbf4..5d123f2307 100644 --- a/plugin/evm/vm.go +++ b/plugin/evm/vm.go @@ -31,7 +31,6 @@ import ( "github.com/ava-labs/avalanchego/utils/profiler" "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/units" - "github.com/ava-labs/avalanchego/version" "github.com/ava-labs/avalanchego/vms/components/chain" "github.com/ava-labs/avalanchego/vms/evm/acp226" "github.com/ava-labs/avalanchego/vms/evm/uptimetracker" @@ -187,6 +186,7 @@ type VM struct { config config.Config + chainID *big.Int genesisHash common.Hash chainConfig *params.ChainConfig ethConfig ethconfig.Config @@ -279,7 +279,7 @@ func (vm *VM) Initialize( appSender commonEng.AppSender, ) error { vm.ctx = chainCtx - vm.stateSyncDone = make(chan struct{}) + cfg, deprecateMsg, err := config.GetConfig(configBytes, vm.ctx.NetworkID) if err != nil { return fmt.Errorf("failed to get config: %w", err) @@ -341,14 +341,22 @@ func (vm *VM) Initialize( return err } + // vm.ChainConfig() should be available for wrapping VMs before vm.initializeChain() + vm.chainConfig = g.Config + vm.chainID = g.Config.ChainID + vm.ethConfig = ethconfig.NewDefaultConfig() vm.ethConfig.Genesis = g - // NetworkID here is different than Avalanche's NetworkID. - // Avalanche's NetworkID represents the Avalanche network is running on - // like Fuji, Mainnet, Local, etc. 
- // The NetworkId here is kept same as ChainID to be compatible with - // Ethereum tooling. - vm.ethConfig.NetworkId = g.Config.ChainID.Uint64() + vm.ethConfig.NetworkId = vm.chainID.Uint64() + vm.genesisHash = vm.ethConfig.Genesis.ToBlock().Hash() // must create genesis hash before [vm.ReadLastAccepted] + lastAcceptedHash, lastAcceptedHeight, err := vm.ReadLastAccepted() + if err != nil { + return err + } + log.Info("read last accepted", + "hash", lastAcceptedHash, + "height", lastAcceptedHeight, + ) // Set minimum price for mining and default gas price oracle value to the min // gas price to prevent so transactions and blocks all use the correct fees @@ -439,20 +447,6 @@ func (vm *VM) Initialize( vm.ethConfig.Miner.Etherbase = constants.BlackholeAddr } - vm.chainConfig = g.Config - - // create genesisHash after applying upgradeBytes in case - // upgradeBytes modifies genesis. - vm.genesisHash = vm.ethConfig.Genesis.ToBlock().Hash() // must create genesis hash before [vm.readLastAccepted] - lastAcceptedHash, lastAcceptedHeight, err := vm.readLastAccepted() - if err != nil { - return err - } - log.Info("read last accepted", - "hash", lastAcceptedHash, - "height", lastAcceptedHeight, - ) - vm.networkCodec = message.Codec vm.Network, err = network.NewNetwork(vm.ctx, appSender, vm.networkCodec, vm.config.MaxOutboundActiveRequests, vm.sdkMetrics) if err != nil { @@ -495,7 +489,6 @@ func (vm *VM) Initialize( if err != nil { return err } - if err := vm.initializeChain(lastAcceptedHash, vm.ethConfig); err != nil { return err } @@ -731,6 +724,7 @@ func (vm *VM) initializeStateSync(lastAcceptedHeight uint64) error { MetadataDB: vm.metadataDB, Acceptor: vm, Parser: vm.extensionConfig.SyncableParser, + Extender: nil, }) // If StateSync is disabled, clear any ongoing summary so that we will not attempt to resume @@ -804,7 +798,6 @@ func (vm *VM) onBootstrapStarted() error { // Ensure snapshots are initialized before bootstrapping (i.e., if state sync is skipped). // Note calling this function has no effect if snapshots are already initialized. vm.blockChain.InitializeSnapshots() - return nil } @@ -1302,7 +1295,7 @@ func (vm *VM) startContinuousProfiler() { // last accepted block hash and height by reading directly from [vm.chaindb] instead of relying // on [chain]. // Note: assumes [vm.chaindb] and [vm.genesisHash] have been initialized. -func (vm *VM) readLastAccepted() (common.Hash, uint64, error) { +func (vm *VM) ReadLastAccepted() (common.Hash, uint64, error) { // Attempt to load last accepted block to determine if it is necessary to // initialize state with the genesis block. 
lastAcceptedBytes, lastAcceptedErr := vm.acceptedBlockDB.Get(lastAcceptedKey) @@ -1370,28 +1363,6 @@ func attachEthService(handler *rpc.Server, apis []rpc.API, names []string) error return nil } -func (vm *VM) Connected(ctx context.Context, nodeID ids.NodeID, version *version.Application) error { - vm.vmLock.Lock() - defer vm.vmLock.Unlock() - - if err := vm.uptimeTracker.Connect(nodeID); err != nil { - return fmt.Errorf("uptime tracker failed to connect node %s: %w", nodeID, err) - } - - return vm.Network.Connected(ctx, nodeID, version) -} - -func (vm *VM) Disconnected(ctx context.Context, nodeID ids.NodeID) error { - vm.vmLock.Lock() - defer vm.vmLock.Unlock() - - if err := vm.uptimeTracker.Disconnect(nodeID); err != nil { - return fmt.Errorf("uptime tracker failed to disconnect node %s: %w", nodeID, err) - } - - return vm.Network.Disconnected(ctx, nodeID) -} - func (vm *VM) PutLastAcceptedID(id ids.ID) error { return vm.acceptedBlockDB.Put(lastAcceptedKey, id[:]) } diff --git a/plugin/evm/vm_extensible.go b/plugin/evm/vm_extensible.go new file mode 100644 index 0000000000..2fae0e156e --- /dev/null +++ b/plugin/evm/vm_extensible.go @@ -0,0 +1,92 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package evm + +import ( + "context" + "errors" + + "github.com/ava-labs/avalanchego/database/versiondb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p" + "github.com/prometheus/client_golang/prometheus" + + "github.com/ava-labs/subnet-evm/core" + "github.com/ava-labs/subnet-evm/params" + "github.com/ava-labs/subnet-evm/plugin/evm/config" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" + + vmsync "github.com/ava-labs/subnet-evm/plugin/evm/sync" +) + +var _ extension.InnerVM = (*VM)(nil) + +var ( + errVMAlreadyInitialized = errors.New("vm already initialized") + errExtensionConfigAlreadySet = errors.New("extension config already set") +) + +func (vm *VM) SetExtensionConfig(config *extension.Config) error { + if vm.ctx != nil { + return errVMAlreadyInitialized + } + if vm.extensionConfig != nil { + return errExtensionConfigAlreadySet + } + vm.extensionConfig = config + return nil +} + +// All the methods below assume that the VM is already initialized + +func (vm *VM) GetExtendedBlock(ctx context.Context, blkID ids.ID) (extension.ExtendedBlock, error) { + // Since each internal handler used by [vm.State] always returns a block + // with a non-nil ethBlock value, GetBlockInternal should never return a + // (*wrappedBlock) with a nil ethBlock value. + blk, err := vm.GetBlockInternal(ctx, blkID) + if err != nil { + return nil, err + } + + return blk.(*wrappedBlock), nil +} + +func (vm *VM) LastAcceptedExtendedBlock() extension.ExtendedBlock { + lastAcceptedBlock := vm.LastAcceptedBlockInternal() + if lastAcceptedBlock == nil { + return nil + } + return lastAcceptedBlock.(*wrappedBlock) +} + +// ChainConfig returns the chain config for the VM. +// Although it is also available through Blockchain().Config(), +// ChainConfig() is available before the blockchain is initialized.
+func (vm *VM) ChainConfig() *params.ChainConfig { + return vm.chainConfig +} + +func (vm *VM) Blockchain() *core.BlockChain { + return vm.blockChain +} + +func (vm *VM) Config() config.Config { + return vm.config +} + +func (vm *VM) MetricRegistry() *prometheus.Registry { + return vm.sdkMetrics +} + +func (vm *VM) Validators() *p2p.Validators { + return vm.P2PValidators() +} + +func (vm *VM) VersionDB() *versiondb.Database { + return vm.versiondb +} + +func (vm *VM) SyncerClient() vmsync.Client { + return vm.Client +} diff --git a/plugin/evm/vm_test.go b/plugin/evm/vm_test.go index e4172cfad3..f1bc1069af 100644 --- a/plugin/evm/vm_test.go +++ b/plugin/evm/vm_test.go @@ -59,6 +59,7 @@ import ( "github.com/ava-labs/subnet-evm/plugin/evm/customheader" "github.com/ava-labs/subnet-evm/plugin/evm/customrawdb" "github.com/ava-labs/subnet-evm/plugin/evm/customtypes" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/plugin/evm/vmerrors" "github.com/ava-labs/subnet-evm/precompile/allowlist" "github.com/ava-labs/subnet-evm/precompile/contracts/deployerallowlist" @@ -1682,7 +1683,7 @@ func testEmptyBlock(t *testing.T, scheme string) { } // Create empty block from blkA - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() emptyEthBlock := types.NewBlock( types.CopyHeader(ethBlock.Header()), @@ -1905,7 +1906,7 @@ func testAcceptReorg(t *testing.T, scheme string) { t.Fatalf("Block failed verification on VM1: %s", err) } - blkBHash := vm1BlkB.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkBHash := vm1BlkB.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() if b := vm1.blockChain.CurrentBlock(); b.Hash() != blkBHash { t.Fatalf("expected current block to have hash %s but got %s", blkBHash.Hex(), b.Hash().Hex()) } @@ -1914,7 +1915,7 @@ func testAcceptReorg(t *testing.T, scheme string) { t.Fatal(err) } - blkCHash := vm1BlkC.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkCHash := vm1BlkC.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() if b := vm1.blockChain.CurrentBlock(); b.Hash() != blkCHash { t.Fatalf("expected current block to have hash %s but got %s", blkCHash.Hex(), b.Hash().Hex()) } @@ -1925,7 +1926,7 @@ func testAcceptReorg(t *testing.T, scheme string) { if err := vm1BlkD.Accept(context.Background()); err != nil { t.Fatal(err) } - blkDHash := vm1BlkD.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkDHash := vm1BlkD.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() if b := vm1.blockChain.CurrentBlock(); b.Hash() != blkDHash { t.Fatalf("expected current block to have hash %s but got %s", blkDHash.Hex(), b.Hash().Hex()) } @@ -2161,7 +2162,7 @@ func testLastAcceptedBlockNumberAllow(t *testing.T, scheme string) { } blkHeight := blk.Height() - blkHash := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkHash := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() tvm.vm.eth.APIBackend.SetAllowUnfinalizedQueries(true) @@ -2255,7 +2256,7 @@ func testBuildAllowListActivationBlock(t *testing.T, scheme string) { } // Verify that the allow list config activation was handled correctly in the first block. 
- blkState, err := tvm.vm.blockChain.StateAt(blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Root()) + blkState, err := tvm.vm.blockChain.StateAt(blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Root()) if err != nil { t.Fatal(err) } @@ -2371,7 +2372,7 @@ func TestTxAllowListSuccessfulTx(t *testing.T) { require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) // Verify that the constructed block only has the whitelisted tx - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs := block.Transactions() @@ -2395,7 +2396,7 @@ func TestTxAllowListSuccessfulTx(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - block = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() blkState, err := tvm.vm.blockChain.StateAt(block.Root()) require.NoError(t, err) @@ -2421,7 +2422,7 @@ func TestTxAllowListSuccessfulTx(t *testing.T) { require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) // Verify that the constructed block only has the whitelisted tx - block = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs = block.Transactions() require.Len(t, txs, 1) @@ -2576,7 +2577,7 @@ func TestTxAllowListDisablePrecompile(t *testing.T) { blk := issueAndAccept(t, tvm.vm) // Verify that the constructed block only has the whitelisted tx - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs := block.Transactions() if txs.Len() != 1 { t.Fatalf("Expected number of txs to be %d, but found %d", 1, txs.Len()) @@ -2598,7 +2599,7 @@ func TestTxAllowListDisablePrecompile(t *testing.T) { blk = issueAndAccept(t, tvm.vm) // Verify that the constructed block only has the previously rejected tx - block = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs = block.Transactions() if txs.Len() != 1 { t.Fatalf("Expected number of txs to be %d, but found %d", 1, txs.Len()) @@ -2704,7 +2705,7 @@ func TestFeeManagerChangeFee(t *testing.T) { t.Fatalf("Expected new block to match") } - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() feeConfig, lastChangedAt, err = tvm.vm.blockChain.GetFeeConfigAt(block.Header()) require.NoError(t, err) @@ -2786,16 +2787,16 @@ func testAllowFeeRecipientDisabled(t *testing.T, scheme string) { blk, err := tvm.vm.BuildBlock(context.Background()) require.NoError(t, err) // this won't return an error since miner will set the etherbase to blackhole address - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, constants.BlackholeAddr, ethBlock.Coinbase()) // Create empty block from blk - internalBlk := blk.(*chain.BlockWrapper).Block.(*wrappedBlock) - modifiedHeader := types.CopyHeader(internalBlk.ethBlock.Header()) + internalBlk := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock) + modifiedHeader := types.CopyHeader(internalBlk.GetEthBlock().Header()) modifiedHeader.Coinbase = 
common.HexToAddress("0x0123456789") // set non-blackhole address by force modifiedBlock := types.NewBlock( modifiedHeader, - internalBlk.ethBlock.Transactions(), + internalBlk.GetEthBlock().Transactions(), nil, nil, trie.NewStackTrie(nil), @@ -2860,7 +2861,7 @@ func TestAllowFeeRecipientEnabled(t *testing.T) { if newHead.Head.Hash() != common.Hash(blk.ID()) { t.Fatalf("Expected new block to match") } - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // Verify that etherBase has received fees blkState, err := tvm.vm.blockChain.StateAt(ethBlock.Root()) @@ -2939,7 +2940,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk := issueAndAccept(t, tvm.vm) newHead := <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // reward address is activated at this block so this is fine tx1 := types.NewTransaction(uint64(0), testEthAddrs[0], big.NewInt(2), 21000, big.NewInt(testMinGasPrice*3), nil) @@ -2954,7 +2955,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, testAddr, ethBlock.Coinbase()) // reward address was activated at previous block // Verify that etherBase has received fees blkState, err := tvm.vm.blockChain.StateAt(ethBlock.Root()) @@ -2981,7 +2982,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() // Reward manager deactivated at this block, so we expect the parent state // to determine the coinbase for this block before full deactivation in the // next block. 
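These test hunks repeat the same unwrapping chain many times; a small hypothetical helper (not part of this diff) captures the pattern the new `extension.ExtendedBlock` assertions follow:

```go
package evm

import (
	"testing"

	"github.com/ava-labs/avalanchego/snow/consensus/snowman"
	"github.com/ava-labs/avalanchego/vms/components/chain"
	"github.com/ava-labs/libevm/core/types"
	"github.com/stretchr/testify/require"

	"github.com/ava-labs/subnet-evm/plugin/evm/extension"
)

// ethBlockOf unwraps a consensus block down to its underlying *types.Block.
// Both assertions are expected to hold for any block produced by this VM.
func ethBlockOf(t *testing.T, blk snowman.Block) *types.Block {
	t.Helper()
	wrapped, ok := blk.(*chain.BlockWrapper)
	require.True(t, ok, "expected *chain.BlockWrapper, got %T", blk)
	extended, ok := wrapped.Block.(extension.ExtendedBlock)
	require.True(t, ok, "expected extension.ExtendedBlock, got %T", wrapped.Block)
	return extended.GetEthBlock()
}
```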
@@ -3002,7 +3003,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() // reward manager was disabled at previous block // so this block should revert back to enabling fee recipients require.Equal(t, etherBase, ethBlock.Coinbase()) @@ -3080,7 +3081,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk := issueAndAccept(t, tvm.vm) newHead := <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, constants.BlackholeAddr, ethBlock.Coinbase()) // reward address is activated at this block so this is fine tx1 := types.NewTransaction(uint64(0), testEthAddrs[0], big.NewInt(2), 21000, big.NewInt(testMinGasPrice*3), nil) @@ -3095,7 +3096,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // reward address was activated at previous block // Verify that etherBase has received fees blkState, err := tvm.vm.blockChain.StateAt(ethBlock.Root()) @@ -3121,7 +3122,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // reward address was activated at previous block require.GreaterOrEqual(t, int64(ethBlock.Time()), disableTime.Unix()) @@ -3138,7 +3139,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, constants.BlackholeAddr, ethBlock.Coinbase()) // reward address was activated at previous block require.Greater(t, int64(ethBlock.Time()), disableTime.Unix()) @@ -3298,7 +3299,7 @@ func TestParentBeaconRootBlock(t *testing.T) { } // Modify the block to have a parent beacon root - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() header := types.CopyHeader(ethBlock.Header()) header.ParentBeaconRoot = test.beaconRoot parentBeaconEthBlock := ethBlock.WithSeal(header) @@ -3497,7 +3498,7 @@ func TestFeeManagerRegressionMempoolMinFeeAfterRestart(t *testing.T) { require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) // check that the fee config is updated - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() feeConfig, 
lastChangedAt, err = restartedVM.blockChain.GetFeeConfigAt(block.Header()) require.NoError(t, err) require.Equal(t, restartedVM.blockChain.CurrentBlock().Number, lastChangedAt) diff --git a/plugin/evm/vm_warp_test.go b/plugin/evm/vm_warp_test.go index 4f7aea2385..0461a2bd62 100644 --- a/plugin/evm/vm_warp_test.go +++ b/plugin/evm/vm_warp_test.go @@ -43,6 +43,7 @@ import ( "github.com/ava-labs/subnet-evm/params/extras" "github.com/ava-labs/subnet-evm/params/paramstest" "github.com/ava-labs/subnet-evm/plugin/evm/customheader" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/precompile/contract" "github.com/ava-labs/subnet-evm/utils" "github.com/ava-labs/subnet-evm/warp" @@ -139,7 +140,7 @@ func testSendWarpMessage(t *testing.T, scheme string) { require.NoError(blk.Verify(context.Background())) // Verify that the constructed block contains the expected log with an unsigned warp message in the log data - ethBlock1 := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock1 := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Len(ethBlock1.Transactions(), 1) receipts := rawdb.ReadReceipts(tvm.vm.chaindb, ethBlock1.Hash(), ethBlock1.NumberU64(), ethBlock1.Time(), tvm.vm.chainConfig) require.Len(receipts, 1) @@ -466,7 +467,7 @@ func testWarpVMTransaction(t *testing.T, scheme string, unsignedMessage *avalanc require.NoError(warpBlock.Accept(context.Background())) tvm.vm.blockChain.DrainAcceptorQueue() - ethBlock := warpBlock.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := warpBlock.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() verifiedMessageReceipts := tvm.vm.blockChain.GetReceiptsByHash(ethBlock.Hash()) require.Len(verifiedMessageReceipts, 2) for i, receipt := range verifiedMessageReceipts { @@ -756,7 +757,7 @@ func testReceiveWarpMessage( require.NoError(err) // Require the block was built with a successful predicate result - ethBlock := block2.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := block2.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() headerPredicateResultsBytes := customheader.PredicateBytesFromExtra(ethBlock.Extra()) blockResults, err := predicate.ParseBlockResults(headerPredicateResultsBytes) require.NoError(err) diff --git a/plugin/evm/wrapped_block.go b/plugin/evm/wrapped_block.go index 4fb4ab76de..2d43402379 100644 --- a/plugin/evm/wrapped_block.go +++ b/plugin/evm/wrapped_block.go @@ -26,12 +26,14 @@ import ( "github.com/ava-labs/subnet-evm/params/extras" "github.com/ava-labs/subnet-evm/plugin/evm/customheader" "github.com/ava-labs/subnet-evm/plugin/evm/customtypes" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/precompile/precompileconfig" ) var ( _ snowman.Block = (*wrappedBlock)(nil) _ block.WithVerifyContext = (*wrappedBlock)(nil) + _ extension.ExtendedBlock = (*wrappedBlock)(nil) errMissingParentBlock = errors.New("missing parent block") errInvalidGasUsedRelativeToCapacity = errors.New("invalid gas used relative to capacity") @@ -53,18 +55,26 @@ var ( // wrappedBlock implements the snowman.wrappedBlock interface type wrappedBlock struct { - id ids.ID - ethBlock *types.Block - vm *VM + id ids.ID + ethBlock *types.Block + extension extension.BlockExtension + vm *VM } // wrapBlock returns a new Block wrapping the ethBlock type and implementing the snowman.Block interface -func wrapBlock(ethBlock *types.Block, vm *VM) (*wrappedBlock, error) { //nolint:unparam // this just makes 
the function compatible with the future syncs I'll do, it's temporary!! +func wrapBlock(ethBlock *types.Block, vm *VM) (*wrappedBlock, error) { b := &wrappedBlock{ id: ids.ID(ethBlock.Hash()), ethBlock: ethBlock, vm: vm, } + if vm.extensionConfig.BlockExtender != nil { + extension, err := vm.extensionConfig.BlockExtender.NewBlockExtension(b) + if err != nil { + return nil, fmt.Errorf("failed to create block extension: %w", err) + } + b.extension = extension + } return b, nil } @@ -419,6 +429,11 @@ func (b *wrappedBlock) syntacticVerify() error { } } + if b.extension != nil { + if err := b.extension.SyntacticVerify(*rulesExtra); err != nil { + return err + } + } return nil } @@ -466,3 +481,5 @@ func (b *wrappedBlock) Bytes() []byte { func (b *wrappedBlock) String() string { return fmt.Sprintf("EVM block, ID = %s", b.ID()) } func (b *wrappedBlock) GetEthBlock() *types.Block { return b.ethBlock } + +func (b *wrappedBlock) GetBlockExtension() extension.BlockExtension { return b.extension } diff --git a/sync/README.md b/sync/README.md index b96c4bcb8b..97efdaa305 100644 --- a/sync/README.md +++ b/sync/README.md @@ -42,7 +42,7 @@ When a new node wants to join the network via state sync, it will need a few pie - Number (height) and hash of the latest available syncable block, - Root of the account trie, -The above information is called a _state summary_, and each syncable block corresponds to one such summary (see `message.SyncSummary`). The engine and VM interact as follows to find a syncable state summary: +The above information is called a _state summary_, and each syncable block corresponds to one such summary (see `message.Summary`). The engine and VM interact as follows to find a syncable state summary: 1. The engine calls `StateSyncEnabled`. The VM returns `true` to initiate state sync, or `false` to start bootstrapping. In `subnet-evm`, this is controlled by the `state-sync-enabled` flag. 
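Tying the README's description back to the accessors exercised in the new `block_sync_summary_test.go` above, a state summary essentially bundles three fields. The stand-in struct below is only illustrative; the real `message.BlockSyncSummary` also carries its codec-serialized bytes and derived summary ID:

```go
package message

import "github.com/ava-labs/libevm/common"

// summaryFields is an illustrative stand-in for the data a block sync
// summary carries, per the README: the latest syncable block's hash and
// height, plus the root of the account trie at that block.
type summaryFields struct {
	BlockHash common.Hash // hash of the latest syncable block
	Height    uint64      // height of that block
	BlockRoot common.Hash // root of the account trie at that block
}
```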
diff --git a/sync/client/client_test.go b/sync/client/client_test.go index 095e52000e..f83d938294 100644 --- a/sync/client/client_test.go +++ b/sync/client/client_test.go @@ -97,7 +97,7 @@ func TestGetCode(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) for name, test := range tests { @@ -164,7 +164,7 @@ func TestGetBlocks(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), message.Codec, handlerstats.NewNoopHandlerStats()) @@ -424,7 +424,7 @@ func TestGetLeafs(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) tests := map[string]struct { @@ -789,7 +789,7 @@ func TestGetLeafsRetries(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) request := message.LeafsRequest{ @@ -849,7 +849,7 @@ func TestStateSyncNodes(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: stateSyncNodes, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/sync/client/mock_client.go b/sync/client/test_client.go similarity index 79% rename from sync/client/mock_client.go rename to sync/client/test_client.go index 0b18b4c6b7..15b8f4ab26 100644 --- a/sync/client/mock_client.go +++ b/sync/client/test_client.go @@ -19,12 +19,11 @@ import ( ) var ( - _ Client = &MockClient{} - mockBlockParser EthBlockParser = &testBlockParser{} + _ Client = (*TestClient)(nil) + _ EthBlockParser = (*testBlockParser)(nil) ) -// TODO replace with gomock library -type MockClient struct { +type TestClient struct { codec codec.Manager leafsHandler handlers.LeafRequestHandler leavesReceived int32 @@ -33,23 +32,23 @@ type MockClient struct { blocksHandler *handlers.BlockRequestHandler blocksReceived int32 // GetLeafsIntercept is called on every GetLeafs request if set to a non-nil callback. - // The returned response will be returned by MockClient to the caller. + // The returned response will be returned by TestClient to the caller. GetLeafsIntercept func(req message.LeafsRequest, res message.LeafsResponse) (message.LeafsResponse, error) // GetCodesIntercept is called on every GetCode request if set to a non-nil callback. - // The returned response will be returned by MockClient to the caller. + // The returned response will be returned by TestClient to the caller. GetCodeIntercept func(hashes []common.Hash, codeBytes [][]byte) ([][]byte, error) // GetBlocksIntercept is called on every GetBlocks request if set to a non-nil callback. - // The returned response will be returned by MockClient to the caller. + // The returned response will be returned by TestClient to the caller. 
GetBlocksIntercept func(blockReq message.BlockRequest, blocks types.Blocks) (types.Blocks, error) } -func NewMockClient( +func NewTestClient( codec codec.Manager, leafsHandler handlers.LeafRequestHandler, codesHandler *handlers.CodeRequestHandler, blocksHandler *handlers.BlockRequestHandler, -) *MockClient { - return &MockClient{ +) *TestClient { + return &TestClient{ codec: codec, leafsHandler: leafsHandler, codesHandler: codesHandler, @@ -57,7 +56,7 @@ func NewMockClient( } } -func (ml *MockClient) GetLeafs(ctx context.Context, request message.LeafsRequest) (message.LeafsResponse, error) { +func (ml *TestClient) GetLeafs(ctx context.Context, request message.LeafsRequest) (message.LeafsResponse, error) { response, err := ml.leafsHandler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request) if err != nil { return message.LeafsResponse{}, err @@ -71,18 +70,18 @@ func (ml *MockClient) GetLeafs(ctx context.Context, request message.LeafsRequest if ml.GetLeafsIntercept != nil { leafsResponse, err = ml.GetLeafsIntercept(request, leafsResponse) } - // Increment the number of leaves received by the mock client + // Increment the number of leaves received by the test client atomic.AddInt32(&ml.leavesReceived, int32(numLeaves)) return leafsResponse, err } -func (ml *MockClient) LeavesReceived() int32 { +func (ml *TestClient) LeavesReceived() int32 { return atomic.LoadInt32(&ml.leavesReceived) } -func (ml *MockClient) GetCode(ctx context.Context, hashes []common.Hash) ([][]byte, error) { +func (ml *TestClient) GetCode(ctx context.Context, hashes []common.Hash) ([][]byte, error) { if ml.codesHandler == nil { - panic("no code handler for mock client") + panic("no code handler for test client") } request := message.CodeRequest{Hashes: hashes} response, err := ml.codesHandler.OnCodeRequest(ctx, ids.GenerateTestNodeID(), 1, request) @@ -104,13 +103,13 @@ func (ml *MockClient) GetCode(ctx context.Context, hashes []common.Hash) ([][]by return code, err } -func (ml *MockClient) CodeReceived() int32 { +func (ml *TestClient) CodeReceived() int32 { return atomic.LoadInt32(&ml.codeReceived) } -func (ml *MockClient) GetBlocks(ctx context.Context, blockHash common.Hash, height uint64, numParents uint16) ([]*types.Block, error) { +func (ml *TestClient) GetBlocks(ctx context.Context, blockHash common.Hash, height uint64, numParents uint16) ([]*types.Block, error) { if ml.blocksHandler == nil { - panic("no blocks handler for mock client") + panic("no blocks handler for test client") } request := message.BlockRequest{ Hash: blockHash, @@ -122,7 +121,7 @@ func (ml *MockClient) GetBlocks(ctx context.Context, blockHash common.Hash, heig return nil, err } - client := &client{blockParser: mockBlockParser} // Hack to avoid duplicate code + client := &client{blockParser: newTestBlockParser()} // Hack to avoid duplicate code blocksRes, numBlocks, err := client.parseBlocks(ml.codec, request, response) if err != nil { return nil, err @@ -135,12 +134,16 @@ func (ml *MockClient) GetBlocks(ctx context.Context, blockHash common.Hash, heig return blocks, err } -func (ml *MockClient) BlocksReceived() int32 { +func (ml *TestClient) BlocksReceived() int32 { return atomic.LoadInt32(&ml.blocksReceived) } type testBlockParser struct{} +func newTestBlockParser() *testBlockParser { + return &testBlockParser{} +} + func (*testBlockParser) ParseEthBlock(b []byte) (*types.Block, error) { block := new(types.Block) if err := rlp.DecodeBytes(b, block); err != nil { diff --git a/sync/handlers/leafs_request.go b/sync/handlers/leafs_request.go 
index 495d07a320..111d56f924 100644 --- a/sync/handlers/leafs_request.go +++ b/sync/handlers/leafs_request.go @@ -93,7 +93,6 @@ func (lrh *leafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N lrh.stats.IncInvalidLeafsRequest() return nil, nil } - // TODO: We should know the state root that accounts correspond to, // as this information will be necessary to access storage tries when // the trie is path based. @@ -109,7 +108,6 @@ func (lrh *leafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N if limit > maxLeavesLimit { limit = maxLeavesLimit } - var leafsResponse message.LeafsResponse leafsResponse.Keys = make([][]byte, 0, limit) leafsResponse.Vals = make([][]byte, 0, limit) @@ -127,7 +125,6 @@ func (lrh *leafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N responseBuilder.snap = lrh.snapshotProvider.Snapshots() } err = responseBuilder.handleRequest(ctx) - // ensure metrics are captured properly on all return paths defer func() { lrh.stats.UpdateLeafsRequestProcessingTime(time.Since(startTime)) @@ -144,13 +141,11 @@ func (lrh *leafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N log.Debug("context err set before any leafs were iterated", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "ctxErr", ctx.Err()) return nil, nil } - responseBytes, err := lrh.codec.Marshal(message.Version, leafsResponse) if err != nil { log.Debug("failed to marshal LeafsResponse, dropping request", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "err", err) return nil, nil } - log.Debug("handled leafsRequest", "time", time.Since(startTime), "leafs", len(leafsResponse.Keys), "proofLen", len(leafsResponse.ProofVals)) return responseBytes, nil } diff --git a/sync/statesync/code_syncer_test.go b/sync/statesync/code_syncer_test.go index 86d0b06106..f5d9951255 100644 --- a/sync/statesync/code_syncer_test.go +++ b/sync/statesync/code_syncer_test.go @@ -45,7 +45,7 @@ func testCodeSyncer(t *testing.T, test codeSyncerTest) { // Set up mockClient codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats()) - mockClient := statesyncclient.NewMockClient(message.Codec, nil, codeRequestHandler, nil) + mockClient := statesyncclient.NewTestClient(message.Codec, nil, codeRequestHandler, nil) mockClient.GetCodeIntercept = test.getCodeIntercept clientDB := rawdb.NewMemoryDatabase() diff --git a/sync/statesync/statesynctest/test_sync.go b/sync/statesync/statesynctest/test_sync.go index 646e6d1c0a..ab1c507261 100644 --- a/sync/statesync/statesynctest/test_sync.go +++ b/sync/statesync/statesynctest/test_sync.go @@ -90,26 +90,6 @@ func AssertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database assert.Equal(t, trieAccountLeaves, numSnapshotAccounts) } -func FillAccountsWithStorage(t *testing.T, r *rand.Rand, serverDB ethdb.Database, serverTrieDB *triedb.Database, root common.Hash, numAccounts int) common.Hash { - newRoot, _ := FillAccounts(t, r, serverTrieDB, root, numAccounts, func(t *testing.T, _ int, account types.StateAccount) types.StateAccount { - codeBytes := make([]byte, 256) - _, err := r.Read(codeBytes) - if err != nil { - t.Fatalf("error reading random code bytes: %v", err) - } - - codeHash := crypto.Keccak256Hash(codeBytes) - rawdb.WriteCode(serverDB, codeHash, codeBytes) - account.CodeHash = codeHash[:] - - // now create state trie - numKeys := 16 - account.Root, _, _ = GenerateTrie(t, r, serverTrieDB, numKeys, common.HashLength) - return account - }) - 
return newRoot -} - // FillAccountsWithOverlappingStorage adds [numAccounts] randomly generated accounts to the secure trie at [root] // and commits it to [trieDB]. For each 3 accounts created: // - One does not have a storage trie, diff --git a/sync/statesync/sync_test.go b/sync/statesync/sync_test.go index 9a0004988d..8bf3138e7a 100644 --- a/sync/statesync/sync_test.go +++ b/sync/statesync/sync_test.go @@ -21,9 +21,11 @@ import ( "github.com/ava-labs/libevm/rlp" "github.com/ava-labs/libevm/trie" "github.com/ava-labs/libevm/triedb" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/core/state/snapshot" + "github.com/ava-labs/subnet-evm/plugin/evm/customrawdb" "github.com/ava-labs/subnet-evm/plugin/evm/message" "github.com/ava-labs/subnet-evm/sync/handlers" "github.com/ava-labs/subnet-evm/sync/statesync/statesynctest" @@ -54,7 +56,7 @@ func testSync(t *testing.T, test syncTest) { clientDB, serverDB, serverTrieDB, root := test.prepareForTest(t, r) leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()) codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats()) - mockClient := statesyncclient.NewMockClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil) + mockClient := statesyncclient.NewTestClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil) // Set intercept functions for the mock client mockClient.GetLeafsIntercept = test.GetLeafsIntercept mockClient.GetCodeIntercept = test.GetCodeIntercept @@ -69,14 +71,17 @@ func testSync(t *testing.T, test syncTest) { RequestSize: 1024, }) require.NoError(t, err, "failed to create state syncer") - // begin sync - s.Start(ctx) + + require.NoError(t, s.Start(ctx), "failed to start state syncer") + waitFor(t, context.Background(), s.Wait, test.expectedError, testSyncTimeout) + + // Only assert database consistency if the sync was expected to succeed. 
if test.expectedError != nil { return } - statesynctest.AssertDBConsistency(t, root, clientDB, serverTrieDB, triedb.NewDatabase(clientDB, nil)) + assertDBConsistency(t, root, clientDB, serverTrieDB, triedb.NewDatabase(clientDB, nil)) } // testSyncResumes tests a series of syncTests work as expected, invoking a callback function after each @@ -144,7 +149,7 @@ func TestSimpleSyncCases(t *testing.T) { prepareForTest: func(t *testing.T, r *rand.Rand) (ethdb.Database, ethdb.Database, *triedb.Database, common.Hash) { serverDB := rawdb.NewMemoryDatabase() serverTrieDB := triedb.NewDatabase(serverDB, nil) - root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccounts) + root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccounts) return rawdb.NewMemoryDatabase(), serverDB, serverTrieDB, root }, }, @@ -186,7 +191,7 @@ func TestSimpleSyncCases(t *testing.T) { prepareForTest: func(t *testing.T, r *rand.Rand) (ethdb.Database, ethdb.Database, *triedb.Database, common.Hash) { serverDB := rawdb.NewMemoryDatabase() serverTrieDB := triedb.NewDatabase(serverDB, nil) - root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccountsSmall) + root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccountsSmall) return rawdb.NewMemoryDatabase(), serverDB, serverTrieDB, root }, GetCodeIntercept: func(_ []common.Hash, _ [][]byte) ([][]byte, error) { @@ -208,7 +213,7 @@ func TestCancelSync(t *testing.T) { serverDB := rawdb.NewMemoryDatabase() serverTrieDB := triedb.NewDatabase(serverDB, nil) // Create trie with 2000 accounts (more than one leaf request) - root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000) + root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000) ctx, cancel := context.WithCancel(context.Background()) defer cancel() testSync(t, syncTest{ @@ -516,12 +521,98 @@ func testSyncerSyncsToNewRoot(t *testing.T, deleteBetweenSyncs func(*testing.T, }) } +// assertDBConsistency checks [serverTrieDB] and [clientTrieDB] have the same EVM state trie at [root], +// and that [clientTrieDB.DiskDB] has corresponding account & snapshot values. +// Also verifies any code referenced by the EVM state is present in [clientTrieDB] and the hash is correct. 
+func assertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database, serverTrieDB, clientTrieDB *triedb.Database) { + numSnapshotAccounts := 0 + accountIt := customrawdb.IterateAccountSnapshots(clientDB) + defer accountIt.Release() + for accountIt.Next() { + if !bytes.HasPrefix(accountIt.Key(), rawdb.SnapshotAccountPrefix) || len(accountIt.Key()) != len(rawdb.SnapshotAccountPrefix)+common.HashLength { + continue + } + numSnapshotAccounts++ + } + if err := accountIt.Error(); err != nil { + t.Fatal(err) + } + trieAccountLeaves := 0 + + statesynctest.AssertTrieConsistency(t, root, serverTrieDB, clientTrieDB, func(key, val []byte) error { + trieAccountLeaves++ + accHash := common.BytesToHash(key) + var acc types.StateAccount + if err := rlp.DecodeBytes(val, &acc); err != nil { + return err + } + // check snapshot consistency + snapshotVal := rawdb.ReadAccountSnapshot(clientDB, accHash) + expectedSnapshotVal := types.SlimAccountRLP(acc) + assert.Equal(t, expectedSnapshotVal, snapshotVal) + + // check code consistency + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash[:]) { + codeHash := common.BytesToHash(acc.CodeHash) + code := rawdb.ReadCode(clientDB, codeHash) + actualHash := crypto.Keccak256Hash(code) + assert.NotEmpty(t, code) + assert.Equal(t, codeHash, actualHash) + } + if acc.Root == types.EmptyRootHash { + return nil + } + + storageIt := rawdb.IterateStorageSnapshots(clientDB, accHash) + defer storageIt.Release() + + snapshotStorageKeysCount := 0 + for storageIt.Next() { + snapshotStorageKeysCount++ + } + + storageTrieLeavesCount := 0 + + // check storage trie and storage snapshot consistency + statesynctest.AssertTrieConsistency(t, acc.Root, serverTrieDB, clientTrieDB, func(key, val []byte) error { + storageTrieLeavesCount++ + snapshotVal := rawdb.ReadStorageSnapshot(clientDB, accHash, common.BytesToHash(key)) + assert.Equal(t, val, snapshotVal) + return nil + }) + + assert.Equal(t, storageTrieLeavesCount, snapshotStorageKeysCount) + return nil + }) + + // Check that the number of accounts in the snapshot matches the number of leaves in the accounts trie + assert.Equal(t, trieAccountLeaves, numSnapshotAccounts) +} + +func fillAccountsWithStorage(t *testing.T, r *rand.Rand, serverDB ethdb.Database, serverTrieDB *triedb.Database, root common.Hash, numAccounts int) common.Hash { //nolint:unparam + newRoot, _ := statesynctest.FillAccounts(t, r, serverTrieDB, root, numAccounts, func(_ *testing.T, _ int, account types.StateAccount) types.StateAccount { + codeBytes := make([]byte, 256) + _, err := r.Read(codeBytes) + require.NoError(t, err, "error reading random code bytes") + + codeHash := crypto.Keccak256Hash(codeBytes) + rawdb.WriteCode(serverDB, codeHash, codeBytes) + account.CodeHash = codeHash[:] + + // now create state trie + numKeys := 16 + account.Root, _, _ = statesynctest.GenerateTrie(t, r, serverTrieDB, numKeys, common.HashLength) + return account + }) + return newRoot +} + func TestDifferentWaitContext(t *testing.T) { r := rand.New(rand.NewSource(1)) serverDB := rawdb.NewMemoryDatabase() serverTrieDB := triedb.NewDatabase(serverDB, nil) // Create trie with many accounts to ensure sync takes time - root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000) + root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000) clientDB := rawdb.NewMemoryDatabase() // Track requests to show sync continues after Wait returns @@ -529,7 +620,7 @@ func TestDifferentWaitContext(t *testing.T) { leafsRequestHandler := 
handlers.NewLeafsRequestHandler(serverTrieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()) codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats()) - mockClient := statesyncclient.NewMockClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil) + mockClient := statesyncclient.NewTestClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil) // Intercept to track ongoing requests and add delay mockClient.GetLeafsIntercept = func(_ message.LeafsRequest, resp message.LeafsResponse) (message.LeafsResponse, error) {