diff --git a/sharding/mainchain/interfaces.go b/sharding/mainchain/interfaces.go new file mode 100644 index 0000000000..73c93dbd1f --- /dev/null +++ b/sharding/mainchain/interfaces.go @@ -0,0 +1,46 @@ +package mainchain + +import ( + "context" + "math/big" + + ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/sharding/contracts" +) + +// Signer defines an interface that can read from the Ethereum mainchain as well call +// read-only methods and functions from the Sharding Manager Contract. +type Signer interface { + Sign(hash common.Hash) ([]byte, error) +} + +// ContractManager specifies an interface that defines both read/write +// operations on a contract in the Ethereum mainchain. +type ContractManager interface { + ContractCaller + ContractTransactor +} + +// ContractCaller defines an interface that can read from a contract on the +// Ethereum mainchain as well as call its read-only methods and functions. +type ContractCaller interface { + SMCCaller() *contracts.SMCCaller + GetShardCount() (int64, error) +} + +// ContractTransactor defines an interface that can transact with a contract on the +// Ethereum mainchain as well as call its methods and functions. +type ContractTransactor interface { + SMCTransactor() *contracts.SMCTransactor + CreateTXOpts(value *big.Int) (*bind.TransactOpts, error) +} + +// Reader defines an interface for a struct that can read mainchain information +// such as blocks, transactions, receipts, and more. Useful for testing. +type Reader interface { + BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) + SubscribeNewHead(ctx context.Context, ch chan<- *types.Header) (ethereum.Subscription, error) +} diff --git a/sharding/mainchain/smc_client.go b/sharding/mainchain/smc_client.go index 01eb2ee0a2..9c02a6ba7c 100644 --- a/sharding/mainchain/smc_client.go +++ b/sharding/mainchain/smc_client.go @@ -29,23 +29,6 @@ import ( // ClientIdentifier tells us what client the node we interact with over RPC is running. const ClientIdentifier = "geth" -// Client contains useful methods for a sharding node to interact with -// an Ethereum client running on the mainchain. -type Client interface { - Account() *accounts.Account - CreateTXOpts(value *big.Int) (*bind.TransactOpts, error) - SMCCaller() *contracts.SMCCaller - SMCTransactor() *contracts.SMCTransactor - SMCFilterer() *contracts.SMCFilterer - TransactionReceipt(common.Hash) (*types.Receipt, error) - ChainReader() ethereum.ChainReader - DepositFlag() bool - SetDepositFlag(deposit bool) - DataDirPath() string - Sign(hash common.Hash) ([]byte, error) - GetShardCount() (int64, error) -} - // SMCClient defines a struct that interacts with a // mainchain node via RPC. Specifically, it aids in SMC bindings that are useful // to other sharding services. diff --git a/sharding/mainchain/smc_client_test.go b/sharding/mainchain/smc_client_test.go index d13db9b41e..5cf0a4a395 100644 --- a/sharding/mainchain/smc_client_test.go +++ b/sharding/mainchain/smc_client_test.go @@ -2,8 +2,5 @@ package mainchain import "github.com/ethereum/go-ethereum/sharding" -// Verifies that SMCClient implements the Client interface. -var _ = Client(&SMCClient{}) - // Verifies that SMCCLient implements the sharding Service inteface. var _ = sharding.Service(&SMCClient{}) diff --git a/sharding/notary/notary.go b/sharding/notary/notary.go index ddd2507bee..5d2e463a24 100644 --- a/sharding/notary/notary.go +++ b/sharding/notary/notary.go @@ -2,10 +2,10 @@ package notary import ( "context" - "errors" "fmt" "math/big" + "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" @@ -14,15 +14,15 @@ import ( shardparams "github.com/ethereum/go-ethereum/sharding/params" ) -// SubscribeBlockHeaders checks incoming block headers and determines if +// subscribeBlockHeaders checks incoming block headers and determines if // we are an eligible notary for collations. Then, it finds the pending tx's // from the running geth node and sorts them by descending order of gas price, // eliminates those that ask for too much gas, and routes them over // to the SMC to create a collation. -func subscribeBlockHeaders(client mainchain.Client) error { +func subscribeBlockHeaders(reader mainchain.Reader, caller mainchain.ContractCaller, account *accounts.Account) error { headerChan := make(chan *types.Header, 16) - _, err := client.ChainReader().SubscribeNewHead(context.Background(), headerChan) + _, err := reader.SubscribeNewHead(context.Background(), headerChan) if err != nil { return fmt.Errorf("unable to subscribe to incoming headers. %v", err) } @@ -36,13 +36,13 @@ func subscribeBlockHeaders(client mainchain.Client) error { log.Info(fmt.Sprintf("Received new header: %v", head.Number.String())) // Check if we are in the notary pool before checking if we are an eligible notary. - v, err := isAccountInNotaryPool(client) + v, err := isAccountInNotaryPool(caller, account) if err != nil { return fmt.Errorf("unable to verify client in notary pool. %v", err) } if v { - if err := checkSMCForNotary(client, head); err != nil { + if err := checkSMCForNotary(caller, account, head); err != nil { return fmt.Errorf("unable to watch shards. %v", err) } } @@ -53,22 +53,22 @@ func subscribeBlockHeaders(client mainchain.Client) error { // collation for the available shards in the SMC. The function calls // getEligibleNotary from the SMC and notary a collation if // conditions are met. -func checkSMCForNotary(client mainchain.Client, head *types.Header) error { +func checkSMCForNotary(caller mainchain.ContractCaller, account *accounts.Account, head *types.Header) error { log.Info("Checking if we are an eligible collation notary for a shard...") - shardCount, err := client.GetShardCount() + shardCount, err := caller.GetShardCount() if err != nil { return fmt.Errorf("can't get shard count from smc: %v", err) } for s := int64(0); s < shardCount; s++ { // Checks if we are an eligible notary according to the SMC. - addr, err := client.SMCCaller().GetNotaryInCommittee(&bind.CallOpts{}, big.NewInt(s)) + addr, err := caller.SMCCaller().GetNotaryInCommittee(&bind.CallOpts{}, big.NewInt(s)) if err != nil { return err } // If output is non-empty and the addr == coinbase. - if addr == client.Account().Address { + if addr == account.Address { log.Info(fmt.Sprintf("Selected as notary on shard: %d", s)) err := submitCollation(s) if err != nil { @@ -76,7 +76,7 @@ func checkSMCForNotary(client mainchain.Client, head *types.Header) error { } // If the account is selected as notary, submit collation. - if addr == client.Account().Address { + if addr == account.Address { log.Info(fmt.Sprintf("Selected as notary on shard: %d", s)) err := submitCollation(s) if err != nil { @@ -93,10 +93,9 @@ func checkSMCForNotary(client mainchain.Client, head *types.Header) error { // we can't guarantee our tx for deposit will be in the next block header we receive. // The function calls IsNotaryDeposited from the SMC and returns true if // the user is in the notary pool. -func isAccountInNotaryPool(client mainchain.Client) (bool, error) { - account := client.Account() +func isAccountInNotaryPool(caller mainchain.ContractCaller, account *accounts.Account) (bool, error) { // Checks if our deposit has gone through according to the SMC. - nreg, err := client.SMCCaller().NotaryRegistry(&bind.CallOpts{}, account.Address) + nreg, err := caller.SMCCaller().NotaryRegistry(&bind.CallOpts{}, account.Address) if !nreg.Deposited && err != nil { log.Warn(fmt.Sprintf("Account %s not in notary pool.", account.Address.String())) } @@ -136,22 +135,18 @@ func submitCollation(shardID int64) error { // joinNotaryPool checks if the deposit flag is true and the account is a // notary in the SMC. If the account is not in the set, it will deposit ETH // into contract. -func joinNotaryPool(config *shardparams.Config, client mainchain.Client) error { - if !client.DepositFlag() { - return errors.New("joinNotaryPool called when deposit flag was not set") - } - - if b, err := isAccountInNotaryPool(client); b || err != nil { +func joinNotaryPool(manager mainchain.ContractManager, account *accounts.Account, config *shardparams.Config) error { + if b, err := isAccountInNotaryPool(manager, account); b || err != nil { return err } log.Info("Joining notary pool") - txOps, err := client.CreateTXOpts(shardparams.DefaultConfig.NotaryDeposit) + txOps, err := manager.CreateTXOpts(shardparams.DefaultConfig.NotaryDeposit) if err != nil { return fmt.Errorf("unable to initiate the deposit transaction: %v", err) } - tx, err := client.SMCTransactor().RegisterNotary(txOps) + tx, err := manager.SMCTransactor().RegisterNotary(txOps) if err != nil { return fmt.Errorf("unable to deposit eth and become a notary: %v", err) } diff --git a/sharding/notary/service.go b/sharding/notary/service.go index fb2d280e3e..eae73a96eb 100644 --- a/sharding/notary/service.go +++ b/sharding/notary/service.go @@ -39,17 +39,20 @@ func (n *Notary) Stop() error { return nil } +// notarizeCollations checks incoming block headers and determines if +// we are an eligible notary for collations. func (n *Notary) notarizeCollations() { + // TODO: handle this better through goroutines. Right now, these methods // are blocking. if n.smcClient.DepositFlag() { - if err := joinNotaryPool(n.config, n.smcClient); err != nil { + if err := joinNotaryPool(n.smcClient, n.smcClient.Account(), n.config); err != nil { log.Error(fmt.Sprintf("Could not fetch current block number: %v", err)) return } } - if err := subscribeBlockHeaders(n.smcClient); err != nil { + if err := subscribeBlockHeaders(n.smcClient.ChainReader(), n.smcClient, n.smcClient.Account()); err != nil { log.Error(fmt.Sprintf("Could not fetch current block number: %v", err)) return } diff --git a/sharding/notary/service_test.go b/sharding/notary/service_test.go index ae016d4c66..fca4e0f112 100644 --- a/sharding/notary/service_test.go +++ b/sharding/notary/service_test.go @@ -8,14 +8,11 @@ import ( "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind/backends" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/sharding" "github.com/ethereum/go-ethereum/sharding/contracts" "github.com/ethereum/go-ethereum/sharding/params" - cli "gopkg.in/urfave/cli.v1" ) var ( @@ -46,54 +43,16 @@ func (s *smcClient) ChainReader() ethereum.ChainReader { return nil } -func (s *smcClient) Context() *cli.Context { - return nil -} - func (s *smcClient) SMCTransactor() *contracts.SMCTransactor { return &s.smc.SMCTransactor } -func (s *smcClient) SMCFilterer() *contracts.SMCFilterer { - return &s.smc.SMCFilterer -} - -func (s *smcClient) TransactionReceipt(hash common.Hash) (*types.Receipt, error) { - return nil, nil -} - func (s *smcClient) CreateTXOpts(value *big.Int) (*bind.TransactOpts, error) { txOpts := transactOpts() txOpts.Value = value return txOpts, nil } -func (s *smcClient) DepositFlag() bool { - return s.depositFlag -} - -func (s *smcClient) SetDepositFlag(deposit bool) { - s.depositFlag = deposit -} - -func (m *smcClient) Sign(hash common.Hash) ([]byte, error) { - return nil, nil -} - -// Unused mockClient methods. -func (m *smcClient) Start() error { - m.t.Fatal("Start called") - return nil -} - -func (m *smcClient) Close() { - m.t.Fatal("Close called") -} - -func (s *smcClient) DataDirPath() string { - return "/tmp/datadir" -} - func (s *smcClient) GetShardCount() (int64, error) { return 100, nil } @@ -117,7 +76,7 @@ func TestIsAccountInNotaryPool(t *testing.T) { client := &smcClient{smc: smc, t: t} // address should not be in pool initially. - b, err := isAccountInNotaryPool(client) + b, err := isAccountInNotaryPool(client, client.Account()) if err != nil { t.Fatal(err) } @@ -132,7 +91,7 @@ func TestIsAccountInNotaryPool(t *testing.T) { t.Fatalf("Failed to deposit: %v", err) } backend.Commit() - b, err = isAccountInNotaryPool(client) + b, err = isAccountInNotaryPool(client, client.Account()) if err != nil { t.Fatal(err) } @@ -153,14 +112,7 @@ func TestJoinNotaryPool(t *testing.T) { t.Fatalf("Unexpected number of notaries. Got %d, wanted 0.", numNotaries) } - client.SetDepositFlag(false) - err = joinNotaryPool(params.DefaultConfig, client) - if err == nil { - t.Error("Joined notary pool while --deposit was not present") - } - - client.SetDepositFlag(true) - err = joinNotaryPool(params.DefaultConfig, client) + err = joinNotaryPool(client, client.Account(), params.DefaultConfig) if err != nil { t.Fatal(err) } @@ -176,7 +128,7 @@ func TestJoinNotaryPool(t *testing.T) { } // Trying to join while deposited should do nothing - err = joinNotaryPool(params.DefaultConfig, client) + err = joinNotaryPool(client, client.Account(), params.DefaultConfig) if err != nil { t.Error(err) } diff --git a/sharding/proposer/proposer.go b/sharding/proposer/proposer.go index aaf56df4b5..0a2b3b6450 100644 --- a/sharding/proposer/proposer.go +++ b/sharding/proposer/proposer.go @@ -4,6 +4,7 @@ import ( "fmt" "math/big" + "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" @@ -15,9 +16,9 @@ import ( // and body. Header consists of shardID, ChunkRoot, period, // proposer addr and signatures. Body contains serialized blob // of a collations transactions. -func createCollation(client mainchain.Client, shardID *big.Int, period *big.Int, txs []*types.Transaction) (*sharding.Collation, error) { +func createCollation(caller mainchain.ContractCaller, account *accounts.Account, signer mainchain.Signer, shardID *big.Int, period *big.Int, txs []*types.Transaction) (*sharding.Collation, error) { // shardId has to be within range - shardCount, err := client.GetShardCount() + shardCount, err := caller.GetShardCount() if err != nil { return nil, fmt.Errorf("can't get shard count from smc: %v", err) } @@ -26,7 +27,7 @@ func createCollation(client mainchain.Client, shardID *big.Int, period *big.Int, } // check with SMC to see if we can add the header. - if a, _ := checkHeaderAdded(client, shardID, period); !a { + if a, _ := checkHeaderAdded(caller, shardID, period); !a { return nil, fmt.Errorf("can't create collation, collation with same period has already been added") } @@ -37,13 +38,13 @@ func createCollation(client mainchain.Client, shardID *big.Int, period *big.Int, } // construct the header, leave chunkRoot and signature fields empty, to be filled later. - addr := client.Account().Address + addr := account.Address header := sharding.NewCollationHeader(shardID, nil, period, &addr, nil) // construct the body with header, blobs(serialized txs) and txs. collation := sharding.NewCollation(header, blobs, txs) collation.CalculateChunkRoot() - sig, err := client.Sign(collation.Header().Hash()) + sig, err := signer.Sign(collation.Header().Hash()) if err != nil { return nil, fmt.Errorf("can't create collation, sign collationHeader failed: %v", err) } @@ -58,10 +59,10 @@ func createCollation(client mainchain.Client, shardID *big.Int, period *big.Int, // an addHeader transaction to the sharding manager contract. // There can only exist one header per period per shard, it's proposer's // responsibility to check if a header has been added. -func addHeader(client mainchain.Client, collation *sharding.Collation) error { +func addHeader(transactor mainchain.ContractTransactor, collation *sharding.Collation) error { log.Info("Adding header to SMC") - txOps, err := client.CreateTXOpts(big.NewInt(0)) + txOps, err := transactor.CreateTXOpts(big.NewInt(0)) if err != nil { return fmt.Errorf("unable to initiate add header transaction: %v", err) } @@ -70,7 +71,7 @@ func addHeader(client mainchain.Client, collation *sharding.Collation) error { var chunkRoot [32]byte copy(chunkRoot[:], collation.Header().ChunkRoot().Bytes()) - tx, err := client.SMCTransactor().AddHeader(txOps, collation.Header().ShardID(), collation.Header().Period(), chunkRoot) + tx, err := transactor.SMCTransactor().AddHeader(txOps, collation.Header().ShardID(), collation.Header().Period(), chunkRoot) if err != nil { return fmt.Errorf("unable to add header to SMC: %v", err) } @@ -82,9 +83,9 @@ func addHeader(client mainchain.Client, collation *sharding.Collation) error { // submitted to the main chain. There can only be one header per shard // per period, proposer should check if a header's already submitted, // checkHeaderAdded returns true if it's available, false if it's unavailable. -func checkHeaderAdded(client mainchain.Client, shardID *big.Int, period *big.Int) (bool, error) { +func checkHeaderAdded(caller mainchain.ContractCaller, shardID *big.Int, period *big.Int) (bool, error) { // Get the period of the last header. - lastPeriod, err := client.SMCCaller().LastSubmittedCollation(&bind.CallOpts{}, shardID) + lastPeriod, err := caller.SMCCaller().LastSubmittedCollation(&bind.CallOpts{}, shardID) if err != nil { return false, fmt.Errorf("unable to get the period of last submitted collation: %v", err) } diff --git a/sharding/proposer/service.go b/sharding/proposer/service.go index 62eb19e9f3..754e4ab6c9 100644 --- a/sharding/proposer/service.go +++ b/sharding/proposer/service.go @@ -69,7 +69,7 @@ func (p *Proposer) proposeCollations() { period := new(big.Int).Div(blockNumber.Number(), big.NewInt(p.config.PeriodLength)) // Create collation. - collation, err := createCollation(p.client, big.NewInt(int64(p.shardID)), period, txs) + collation, err := createCollation(p.client, p.client.Account(), p.client, big.NewInt(int64(p.shardID)), period, txs) if err != nil { log.Error(fmt.Sprintf("Could not create collation: %v", err)) return diff --git a/sharding/proposer/service_test.go b/sharding/proposer/service_test.go index a57d862b0c..c6155adfa2 100644 --- a/sharding/proposer/service_test.go +++ b/sharding/proposer/service_test.go @@ -16,7 +16,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/sharding/contracts" "github.com/ethereum/go-ethereum/sharding/params" - "gopkg.in/urfave/cli.v1" ) var ( @@ -33,66 +32,32 @@ type mockNode struct { backend *backends.SimulatedBackend } -func (s *mockNode) Account() *accounts.Account { +func (m *mockNode) Account() *accounts.Account { return &accounts.Account{Address: addr} } -func (s *mockNode) SMCCaller() *contracts.SMCCaller { - return &s.smc.SMCCaller +func (m *mockNode) SMCCaller() *contracts.SMCCaller { + return &m.smc.SMCCaller } -func (s *mockNode) ChainReader() ethereum.ChainReader { +func (m *mockNode) ChainReader() ethereum.ChainReader { return nil } -func (s *mockNode) Context() *cli.Context { - return nil +func (m *mockNode) SMCTransactor() *contracts.SMCTransactor { + return &m.smc.SMCTransactor } -func (s *mockNode) SMCTransactor() *contracts.SMCTransactor { - return &s.smc.SMCTransactor -} - -func (s *mockNode) SMCFilterer() *contracts.SMCFilterer { - return &s.smc.SMCFilterer -} - -func (s *mockNode) TransactionReceipt(hash common.Hash) (*types.Receipt, error) { - return nil, nil -} - -func (s *mockNode) CreateTXOpts(value *big.Int) (*bind.TransactOpts, error) { +func (m *mockNode) CreateTXOpts(value *big.Int) (*bind.TransactOpts, error) { txOpts := transactOpts() txOpts.Value = value return txOpts, nil } -func (s *mockNode) DepositFlag() bool { - return false -} - -func (s *mockNode) SetDepositFlag(deposit bool) { - s.depositFlag = deposit -} - func (m *mockNode) Sign(hash common.Hash) ([]byte, error) { return nil, nil } -// Unused mockClient methods. -func (m *mockNode) Start() error { - m.t.Fatal("Start called") - return nil -} - -func (m *mockNode) Close() { - m.t.Fatal("Close called") -} - -func (m *mockNode) DataDirPath() string { - return "/tmp/datadir" -} - func (m *mockNode) GetShardCount() (int64, error) { return 100, nil } @@ -122,7 +87,7 @@ func TestCreateCollation(t *testing.T) { nil, 0, nil, data)) } - collation, err := createCollation(node, big.NewInt(0), big.NewInt(1), txs) + collation, err := createCollation(node, node.Account(), node, big.NewInt(0), big.NewInt(1), txs) if err != nil { t.Fatalf("Create collation failed: %v", err) } @@ -133,7 +98,7 @@ func TestCreateCollation(t *testing.T) { } // negative test case #1: create collation with shard > shardCount. - collation, err = createCollation(node, big.NewInt(101), big.NewInt(2), txs) + collation, err = createCollation(node, node.Account(), node, big.NewInt(101), big.NewInt(2), txs) if err == nil { t.Errorf("Create collation should have failed with invalid shard number") } @@ -145,13 +110,13 @@ func TestCreateCollation(t *testing.T) { badTxs = append(badTxs, types.NewTransaction(0, common.HexToAddress("0x0"), nil, 0, nil, data)) } - collation, err = createCollation(node, big.NewInt(0), big.NewInt(2), badTxs) + collation, err = createCollation(node, node.Account(), node, big.NewInt(0), big.NewInt(2), badTxs) if err == nil { t.Errorf("Create collation should have failed with Txs longer than collation body limit") } // normal test case #1 create collation with correct parameters. - collation, err = createCollation(node, big.NewInt(5), big.NewInt(5), txs) + collation, err = createCollation(node, node.Account(), node, big.NewInt(5), big.NewInt(5), txs) if err != nil { t.Errorf("Create collation failed: %v", err) } @@ -180,7 +145,7 @@ func TestAddCollation(t *testing.T) { nil, 0, nil, data)) } - collation, err := createCollation(node, big.NewInt(0), big.NewInt(1), txs) + collation, err := createCollation(node, node.Account(), node, big.NewInt(0), big.NewInt(1), txs) if err != nil { t.Errorf("Create collation failed: %v", err) } @@ -210,7 +175,7 @@ func TestAddCollation(t *testing.T) { } // negative test case #1 create the same collation that just got added to SMC. - collation, err = createCollation(node, big.NewInt(0), big.NewInt(1), txs) + collation, err = createCollation(node, node.Account(), node, big.NewInt(0), big.NewInt(1), txs) if err == nil { t.Errorf("Create collation should fail due to same collation in SMC") } @@ -228,7 +193,7 @@ func TestCheckCollation(t *testing.T) { nil, 0, nil, data)) } - collation, err := createCollation(node, big.NewInt(0), big.NewInt(1), txs) + collation, err := createCollation(node, node.Account(), node, big.NewInt(0), big.NewInt(1), txs) if err != nil { t.Errorf("Create collation failed: %v", err) }