From 2b631d2d3617aa3c76227ea6c52faa28ef4afa3e Mon Sep 17 00:00:00 2001 From: noot <36753753+noot@users.noreply.github.com> Date: Mon, 9 May 2022 21:30:07 -0400 Subject: [PATCH] automated tester improvements and fixes (#111) --- cmd/daemon/main.go | 19 ++- cmd/daemon/main_test.go | 4 +- cmd/recover/main.go | 20 ++- cmd/recover/main_test.go | 2 +- cmd/tester/main.go | 218 +++++++++++++++++++++--------- common/rpcclient/wsclient.go | 64 +++++++-- monero/utils.go | 12 +- monero/utils_test.go | 2 +- protocol/alice/instance.go | 6 + protocol/alice/message_handler.go | 13 +- protocol/alice/recovery.go | 17 +-- protocol/alice/recovery_test.go | 2 +- protocol/alice/swap_state.go | 24 ++++ protocol/bob/message_handler.go | 10 -- protocol/bob/offers.go | 7 + protocol/bob/swap_state.go | 5 +- recover/recovery.go | 5 +- recover/recovery_test.go | 21 +-- rpc/server.go | 1 + rpc/ws.go | 12 ++ scripts/run-integration-tests.sh | 4 +- 21 files changed, 330 insertions(+), 138 deletions(-) diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index 3b7b6244..ea6a199d 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -69,8 +69,10 @@ const ( flagGasPrice = "gas-price" flagGasLimit = "gas-limit" - flagDevAlice = "dev-alice" - flagDevBob = "dev-bob" + flagDevAlice = "dev-alice" + flagDevBob = "dev-bob" + flagDeploy = "deploy" + flagTransferBack = "transfer-back" flagLog = "log" ) @@ -157,6 +159,14 @@ var ( Name: flagDevBob, Usage: "run in development mode and use XMR provider default values", }, + &cli.BoolFlag{ + Name: flagDeploy, + Usage: "deploy an instance of the swap contract; defaults to false", + }, + &cli.BoolFlag{ + Name: flagTransferBack, + Usage: "when receiving XMR in a swap, transfer it back to the original wallet.", + }, &cli.StringFlag{ Name: flagLog, Usage: "set log level: one of [error|warn|info|debug]", @@ -430,7 +440,9 @@ func getProtocolInstances(ctx context.Context, c *cli.Context, env common.Enviro } var contract *swapfactory.SwapFactory - if !devBob { + deploy := c.Bool(flagDeploy) + + if !devBob || deploy { contract, contractAddr, err = getOrDeploySwapFactory(contractAddr, env, cfg.Basepath, big.NewInt(chainID), pk, ec) if err != nil { @@ -458,6 +470,7 @@ func getProtocolInstances(ctx context.Context, c *cli.Context, env common.Enviro SwapManager: sm, SwapContract: contract, SwapContractAddress: contractAddr, + TransferBack: c.Bool(flagTransferBack), } a, err = alice.NewInstance(aliceCfg) diff --git a/cmd/daemon/main_test.go b/cmd/daemon/main_test.go index abf71d41..f934848c 100644 --- a/cmd/daemon/main_test.go +++ b/cmd/daemon/main_test.go @@ -84,8 +84,8 @@ func TestDaemon_DevAlice(t *testing.T) { func TestDaemon_DevBob(t *testing.T) { c := newTestContext(t, "test --dev-bob", - []string{flagDevBob}, - []interface{}{true}, + []string{flagDevBob, flagDeploy}, + []interface{}{true, true}, ) ctx, cancel := context.WithCancel(context.Background()) diff --git a/cmd/recover/main.go b/cmd/recover/main.go index e6826118..96f06ed9 100644 --- a/cmd/recover/main.go +++ b/cmd/recover/main.go @@ -6,6 +6,7 @@ import ( "math/big" "os" + ethcommon "github.com/ethereum/go-ethereum/common" ethcrypto "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethclient" "github.com/urfave/cli" @@ -16,6 +17,7 @@ import ( "github.com/noot/atomic-swap/protocol/alice" "github.com/noot/atomic-swap/protocol/bob" recovery "github.com/noot/atomic-swap/recover" + "github.com/noot/atomic-swap/swapfactory" logging "github.com/ipfs/go-log" ) @@ -108,8 +110,8 @@ func main() { // Recoverer is implemented by a backend which is able to recover monero type Recoverer interface { WalletFromSecrets(aliceSecret, bobSecret string) (mcrypto.Address, error) - RecoverFromBobSecretAndContract(b *bob.Instance, bobSecret, contractAddr string, swapID *big.Int) (*bob.RecoveryResult, error) //nolint:lll - RecoverFromAliceSecretAndContract(a *alice.Instance, aliceSecret, contractAddr string, swapID *big.Int) (*alice.RecoveryResult, error) //nolint:lll + RecoverFromBobSecretAndContract(b *bob.Instance, bobSecret, contractAddr string, swapID *big.Int) (*bob.RecoveryResult, error) //nolint:lll + RecoverFromAliceSecretAndContract(a *alice.Instance, aliceSecret string, swapID *big.Int) (*alice.RecoveryResult, error) //nolint:lll } type instance struct { @@ -188,12 +190,13 @@ func (inst *instance) recover(c *cli.Context) error { } if as != "" && contractAddr != "" { - a, err := createAliceInstance(context.Background(), c, env, cfg) + addr := ethcommon.HexToAddress(contractAddr) + a, err := createAliceInstance(context.Background(), c, env, cfg, addr) if err != nil { return err } - res, err := r.RecoverFromAliceSecretAndContract(a, as, contractAddr, swapID) + res, err := r.RecoverFromAliceSecretAndContract(a, as, swapID) if err != nil { return err } @@ -238,7 +241,7 @@ func getRecoverer(c *cli.Context, env common.Environment) (Recoverer, error) { } func createAliceInstance(ctx context.Context, c *cli.Context, env common.Environment, - cfg common.Config) (*alice.Instance, error) { + cfg common.Config, contractAddr ethcommon.Address) (*alice.Instance, error) { var ( moneroEndpoint, ethEndpoint string ) @@ -281,6 +284,11 @@ func createAliceInstance(ctx context.Context, c *cli.Context, env common.Environ return nil, err } + contract, err := swapfactory.NewSwapFactory(contractAddr, ec) + if err != nil { + return nil, err + } + aliceCfg := &alice.Config{ Ctx: ctx, Basepath: cfg.Basepath, @@ -291,6 +299,8 @@ func createAliceInstance(ctx context.Context, c *cli.Context, env common.Environ ChainID: big.NewInt(chainID), GasPrice: gasPrice, GasLimit: uint64(c.Uint(flagGasLimit)), + SwapContract: contract, + SwapContractAddress: contractAddr, } return alice.NewInstance(aliceCfg) diff --git a/cmd/recover/main_test.go b/cmd/recover/main_test.go index 22bcf912..607f950e 100644 --- a/cmd/recover/main_test.go +++ b/cmd/recover/main_test.go @@ -80,7 +80,7 @@ func (r *mockRecoverer) RecoverFromBobSecretAndContract(b *bob.Instance, bobSecr }, nil } -func (r *mockRecoverer) RecoverFromAliceSecretAndContract(a *alice.Instance, aliceSecret, contractAddr string, +func (r *mockRecoverer) RecoverFromAliceSecretAndContract(a *alice.Instance, aliceSecret string, swapID *big.Int) (*alice.RecoveryResult, error) { return &alice.RecoveryResult{ Claimed: true, diff --git a/cmd/tester/main.go b/cmd/tester/main.go index f40a10de..28cf8d00 100644 --- a/cmd/tester/main.go +++ b/cmd/tester/main.go @@ -10,6 +10,7 @@ import ( mrand "math/rand" "os" "path/filepath" + "strings" "sync" "time" @@ -24,21 +25,14 @@ import ( const ( flagConfig = "config" flagTimeout = "timeout" + flagLog = "log" defaultConfigFile = "testerconfig.json" ) var defaultTimeout = time.Minute * 15 -var ( - log = logging.Logger("cmd") - _ = logging.SetLogLevel("alice", "debug") - _ = logging.SetLogLevel("bob", "debug") - _ = logging.SetLogLevel("common", "debug") - _ = logging.SetLogLevel("cmd", "debug") - _ = logging.SetLogLevel("net", "debug") - _ = logging.SetLogLevel("rpc", "debug") -) +var log = logging.Logger("cmd") var ( app = &cli.App{ @@ -54,6 +48,10 @@ var ( Name: flagTimeout, Usage: "time for which to run tester, in minutes; default=15mins", }, + &cli.StringFlag{ + Name: flagLog, + Usage: "set log level: one of [error|warn|info|debug]", + }, }, } ) @@ -65,7 +63,42 @@ func main() { } } +func setLogLevels(c *cli.Context) error { + const ( + levelError = "error" + levelWarn = "warn" + levelInfo = "info" + levelDebug = "debug" + ) + + _ = logging.SetLogLevel("cmd", levelInfo) + + level := c.String(flagLog) + if level == "" { + level = levelInfo + } + + switch level { + case levelError, levelWarn, levelInfo, levelDebug: + default: + return fmt.Errorf("invalid log level") + } + + _ = logging.SetLogLevel("alice", level) + _ = logging.SetLogLevel("bob", level) + _ = logging.SetLogLevel("common", level) + _ = logging.SetLogLevel("net", level) + _ = logging.SetLogLevel("rpc", level) + _ = logging.SetLogLevel("rpcclient", level) + return nil +} + func runTester(c *cli.Context) error { + err := setLogLevels(c) + if err != nil { + return err + } + var timeout time.Duration timeoutMins := c.Uint(flagTimeout) @@ -75,10 +108,15 @@ func runTester(c *cli.Context) error { timeout = time.Minute * time.Duration(timeoutMins) } - log.Infof("starting to test, total duration is %dmins", timeout.Minutes()) + log.Infof("starting to test, total duration is %vmins", timeout.Minutes()) - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() + timer := time.After(timeout) + done := make(chan struct{}) + + go func() { + <-timer + close(done) + }() config := c.String(flagConfig) if config == "" { @@ -103,14 +141,17 @@ func runTester(c *cli.Context) error { errChs := make([]chan error, len(endpoints)) for i, endpoint := range endpoints { + errChs[i] = make(chan error, 16) + d := &daemon{ rsl: rsl, endpoint: endpoint, errCh: errChs[i], wg: &wg, idx: i, + stop: make(chan struct{}), } - go d.test(ctx) + go d.test(done) } wg.Wait() @@ -133,39 +174,37 @@ func getRandomExchangeRate() types.ExchangeRate { type daemon struct { rsl *resultLogger endpoint string - wsc rpcclient.WsClient errCh chan error wg *sync.WaitGroup idx int + stop chan struct{} + swapMu sync.Mutex } -func (d *daemon) test(ctx context.Context) { +func (d *daemon) test(done <-chan struct{}) { log.Infof("starting tester for node %s at index %d...", d.endpoint, d.idx) defer d.wg.Done() - go d.logErrors(ctx) - - var err error - d.wsc, err = rpcclient.NewWsClient(ctx, d.endpoint) - if err != nil { - d.errCh <- err - return - } + go d.logErrors(done) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() + var sleep int for { - d.makeOffer() - if ctx.Err() != nil { + select { + case <-time.After(time.Second * time.Duration(sleep)): + d.makeOffer(done) + case <-done: + return + case <-d.stop: return } - sleep := getRandomInt(10) - time.Sleep(time.Second * time.Duration(sleep)) + sleep = getRandomInt(60) + 3 } }() @@ -173,35 +212,49 @@ func (d *daemon) test(ctx context.Context) { defer wg.Done() for { - sleep := getRandomInt(10) - time.Sleep(time.Second * time.Duration(sleep)) + sleep := getRandomInt(60) + 3 - d.takeOffer() - if ctx.Err() != nil { + select { + case <-time.After(time.Second * time.Duration(sleep)): + d.takeOffer(done) + case <-done: + return + case <-d.stop: return } } }() wg.Wait() + log.Warnf("node %d returning from d.test", d.idx) } -func (d *daemon) logErrors(ctx context.Context) { +func (d *daemon) logErrors(done <-chan struct{}) { for { select { - case <-ctx.Done(): + case <-done: return case err := <-d.errCh: - log.Errorf("endpoint %d: %w", d.idx, err) + log.Errorf("endpoint %d: %s", d.idx, err) + if strings.Contains(err.Error(), "connection refused") { + close(d.stop) + } } } } -func (d *daemon) takeOffer() { +func (d *daemon) takeOffer(done <-chan struct{}) { log.Debugf("node %d discovering offers...", d.idx) + wsc, err := rpcclient.NewWsClient(context.Background(), d.endpoint) + if err != nil { + d.errCh <- err + return + } + + defer wsc.Close() const defaultDiscoverTimeout = uint64(3) // 3s - providers, err := d.wsc.Discover(types.ProvidesXMR, defaultDiscoverTimeout) + providers, err := wsc.Discover(types.ProvidesXMR, defaultDiscoverTimeout) if err != nil { d.errCh <- err return @@ -216,12 +269,16 @@ func (d *daemon) takeOffer() { log.Debugf("node %d querying peer %s...", d.idx, peer) - resp, err := d.wsc.Query(peer) + resp, err := wsc.Query(peer) if err != nil { d.errCh <- err return } + if len(resp.Offers) == 0 { + return + } + offerIdx := getRandomInt(len(resp.Offers)) offer := resp.Offers[offerIdx] @@ -232,26 +289,34 @@ func (d *daemon) takeOffer() { start := time.Now() log.Infof("node %d taking offer %s", d.idx, offer.GetID().String()) - _, takerStatusCh, err := d.wsc.TakeOfferAndSubscribe(peer, + _, takerStatusCh, err := wsc.TakeOfferAndSubscribe(peer, offer.GetID().String(), providesAmount) if err != nil { d.errCh <- err return } - for status := range takerStatusCh { - log.Debugf("> taker %d got status:", d.idx, status) - if status.IsOngoing() { - continue - } + d.swapMu.Lock() + defer d.swapMu.Unlock() - if status != types.CompletedSuccess { - d.errCh <- fmt.Errorf("swap did not complete successfully for taker: got %s", status) - } + for { + select { + case <-done: + return + case status := <-takerStatusCh: + log.Infof("> taker (node %d) got status: %s", d.idx, status) + if status.IsOngoing() { + continue + } - d.rsl.logTakerStatus(status) - d.rsl.logSwapDuration(time.Since(start)) - return + if status != types.CompletedSuccess { + d.errCh <- fmt.Errorf("swap did not complete successfully for taker: got %s", status) + } + + d.rsl.logTakerStatus(status) + d.rsl.logSwapDuration(time.Since(start)) + return + } } } @@ -260,40 +325,61 @@ func getRandomInt(max int) int { return int(i.Int64()) } -func (d *daemon) makeOffer() { +func (d *daemon) makeOffer(done <-chan struct{}) { log.Infof("node %d making offer...", d.idx) + wsc, err := rpcclient.NewWsClient(context.Background(), d.endpoint) + if err != nil { + d.errCh <- err + return + } - offerID, takenCh, statusCh, err := d.wsc.MakeOfferAndSubscribe(minProvidesAmount, + defer wsc.Close() + + offerID, takenCh, statusCh, err := wsc.MakeOfferAndSubscribe(minProvidesAmount, maxProvidesAmount, getRandomExchangeRate(), ) if err != nil { + log.Errorf("failed to make offer (node %d): %s", d.idx, err) d.errCh <- err return } log.Infof("node %d made offer %s", d.idx, offerID) - taken := <-takenCh - if taken == nil { + select { + case <-done: return + case taken := <-takenCh: + if taken == nil { + log.Warn("got nil from takenCh") + return + } } + d.swapMu.Lock() + defer d.swapMu.Unlock() + start := time.Now() - for status := range statusCh { - log.Debugf("> maker %d got status:", d.idx, status) - if status.IsOngoing() { - continue - } + for { + select { + case <-done: + return + case status := <-statusCh: + log.Infof("> maker (node %d) got status: %s", d.idx, status) + if status.IsOngoing() { + continue + } - if status != types.CompletedSuccess { - d.errCh <- fmt.Errorf("swap did not complete successfully for maker: exit status was %s", status) - } + if status != types.CompletedSuccess { + d.errCh <- fmt.Errorf("swap did not complete successfully for maker: exit status was %s", status) + } - d.rsl.logMakerStatus(status) - d.rsl.logSwapDuration(time.Since(start)) - return + d.rsl.logMakerStatus(status) + d.rsl.logSwapDuration(time.Since(start)) + return + } } } @@ -324,6 +410,10 @@ func (l *resultLogger) logSwapDuration(duration time.Duration) { } func (l *resultLogger) averageDuration() time.Duration { + if len(l.durations) == 0 { + return 0 + } + sum := time.Duration(0) for _, dur := range l.durations { sum += dur diff --git a/common/rpcclient/wsclient.go b/common/rpcclient/wsclient.go index 5ded6ba5..806f00f2 100644 --- a/common/rpcclient/wsclient.go +++ b/common/rpcclient/wsclient.go @@ -17,6 +17,7 @@ var log = logging.Logger("rpcclient") // WsClient ... type WsClient interface { + Close() Discover(provides types.ProvidesCoin, searchTime uint64) ([][]string, error) Query(maddr string) (*rpctypes.QueryPeerResponse, error) SubscribeSwapStatus(id uint64) (<-chan types.Status, error) @@ -27,7 +28,8 @@ type WsClient interface { } type wsClient struct { - sync.Mutex + wmu sync.Mutex + rmu sync.Mutex conn *websocket.Conn } @@ -47,12 +49,27 @@ func NewWsClient(ctx context.Context, endpoint string) (*wsClient, error) { ///n }, nil } +func (c *wsClient) Close() { + _ = c.conn.Close() +} + func (c *wsClient) writeJSON(msg *rpctypes.Request) error { - c.Lock() - defer c.Unlock() + c.wmu.Lock() + defer c.wmu.Unlock() return c.conn.WriteJSON(msg) } +func (c *wsClient) read() ([]byte, error) { + c.rmu.Lock() + defer c.rmu.Unlock() + _, message, err := c.conn.ReadMessage() + if err != nil { + return nil, err + } + + return message, nil +} + func (c *wsClient) Discover(provides types.ProvidesCoin, searchTime uint64) ([][]string, error) { params := &rpctypes.DiscoverRequest{ Provides: provides, @@ -75,8 +92,7 @@ func (c *wsClient) Discover(provides types.ProvidesCoin, searchTime uint64) ([][ return nil, err } - // read ID from connection - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { return nil, fmt.Errorf("failed to read websockets message: %s", err) } @@ -122,7 +138,7 @@ func (c *wsClient) Query(maddr string) (*rpctypes.QueryPeerResponse, error) { } // read ID from connection - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { return nil, fmt.Errorf("failed to read websockets message: %s", err) } @@ -175,7 +191,7 @@ func (c *wsClient) SubscribeSwapStatus(id uint64) (<-chan types.Status, error) { defer close(respCh) for { - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { log.Warnf("failed to read websockets message: %s", err) break @@ -200,7 +216,11 @@ func (c *wsClient) SubscribeSwapStatus(id uint64) (<-chan types.Status, error) { break } - respCh <- types.NewStatus(status.Status) + s := types.NewStatus(status.Status) + respCh <- s + if !s.IsOngoing() { + return + } } }() @@ -232,7 +252,7 @@ func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string, } // read ID from connection - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { return 0, nil, fmt.Errorf("failed to read websockets message: %s", err) } @@ -259,7 +279,7 @@ func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string, defer close(respCh) for { - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { log.Warnf("failed to read websockets message: %s", err) break @@ -284,7 +304,11 @@ func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string, break } - respCh <- types.NewStatus(status.Status) + s := types.NewStatus(status.Status) + respCh <- s + if !s.IsOngoing() { + return + } } }() @@ -316,16 +340,22 @@ func (c *wsClient) MakeOfferAndSubscribe(min, max float64, ID: 0, } + log.Debug("writing net_makeOfferAndSubscribe") + if err = c.writeJSON(req); err != nil { return "", nil, nil, err } + log.Debugf("wrote") + // read ID from connection - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { return "", nil, nil, fmt.Errorf("failed to read websockets message: %s", err) } + log.Debugf("got response") + var resp *rpctypes.Response err = json.Unmarshal(message, &resp) if err != nil { @@ -351,7 +381,7 @@ func (c *wsClient) MakeOfferAndSubscribe(min, max float64, defer close(takenCh) // read if swap was taken - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { log.Warnf("failed to read websockets message: %s", err) return @@ -379,7 +409,7 @@ func (c *wsClient) MakeOfferAndSubscribe(min, max float64, takenCh <- taken for { - _, message, err := c.conn.ReadMessage() + message, err := c.read() if err != nil { log.Warnf("failed to read websockets message: %s", err) break @@ -404,7 +434,11 @@ func (c *wsClient) MakeOfferAndSubscribe(min, max float64, break } - respCh <- types.NewStatus(status.Status) + s := types.NewStatus(status.Status) + respCh <- s + if !s.IsOngoing() { + return + } } }() diff --git a/monero/utils.go b/monero/utils.go index 52eabbdb..aafe2d0d 100644 --- a/monero/utils.go +++ b/monero/utils.go @@ -20,27 +20,31 @@ var ( ) // WaitForBlocks waits for a new block to arrive. -func WaitForBlocks(client Client) error { +func WaitForBlocks(client Client) (uint, error) { prevHeight, err := client.GetHeight() if err != nil { - return fmt.Errorf("failed to get height: %w", err) + return 0, fmt.Errorf("failed to get height: %w", err) } for i := 0; i < maxRetries; i++ { + if err := client.Refresh(); err != nil { + return 0, err + } + height, err := client.GetHeight() if err != nil { continue } if height > prevHeight { - return nil + return height, nil } log.Infof("waiting for next block, current height=%d", height) time.Sleep(blockSleepDuration) } - return fmt.Errorf("timed out waiting for next block") + return 0, fmt.Errorf("timed out waiting for next block") } // CreateMoneroWallet creates a monero wallet from a private keypair. diff --git a/monero/utils_test.go b/monero/utils_test.go index 06c43be1..cd7211f7 100644 --- a/monero/utils_test.go +++ b/monero/utils_test.go @@ -20,7 +20,7 @@ func TestWaitForBlocks(t *testing.T) { _ = daemon.callGenerateBlocks(addr.Address, 181) }() - err = WaitForBlocks(c) + _, err = WaitForBlocks(c) require.NoError(t, err) } diff --git a/protocol/alice/instance.go b/protocol/alice/instance.go index b8a8d3af..cc41a252 100644 --- a/protocol/alice/instance.go +++ b/protocol/alice/instance.go @@ -87,6 +87,8 @@ type Config struct { func NewInstance(cfg *Config) (*Instance, error) { if cfg.Environment == common.Development { defaultTimeoutDuration = time.Minute + } else if cfg.Environment == common.Stagenet { + defaultTimeoutDuration = time.Hour } pub := cfg.EthereumPrivateKey.Public().(*ecdsa.PublicKey) @@ -105,6 +107,10 @@ func NewInstance(cfg *Config) (*Instance, error) { } } + if cfg.SwapContract == nil || (cfg.SwapContractAddress == ethcommon.Address{}) { + return nil, fmt.Errorf("must provide swap contract and address") + } + // TODO: check that Alice's monero-wallet-cli endpoint has wallet-dir configured return &Instance{ ctx: cfg.Ctx, diff --git a/protocol/alice/message_handler.go b/protocol/alice/message_handler.go index 323595d0..360730f2 100644 --- a/protocol/alice/message_handler.go +++ b/protocol/alice/message_handler.go @@ -243,17 +243,26 @@ func (s *swapState) handleNotifyXMRLock(msg *message.NotifyXMRLock) (net.Message log.Debugf("generated view-only wallet to check funds: %s", walletName) if s.alice.env != common.Development { + log.Infof("waiting for new blocks...") // wait for 2 new blocks, otherwise balance might be 0 // TODO: check transaction hash - if err := monero.WaitForBlocks(s.alice.client); err != nil { + height, err := monero.WaitForBlocks(s.alice.client) + if err != nil { return nil, err } - if err := monero.WaitForBlocks(s.alice.client); err != nil { + log.Infof("monero block height: %d", height) + + height, err = monero.WaitForBlocks(s.alice.client) + if err != nil { return nil, err } + + log.Infof("monero block height: %d", height) } + log.Debug("refreshing client...") + if err := s.alice.client.Refresh(); err != nil { return nil, fmt.Errorf("failed to refresh client: %w", err) } diff --git a/protocol/alice/recovery.go b/protocol/alice/recovery.go index 95f6bc76..34453f38 100644 --- a/protocol/alice/recovery.go +++ b/protocol/alice/recovery.go @@ -20,14 +20,13 @@ import ( var claimedTopic = ethcommon.HexToHash("0xd5a2476fc450083bbb092dd3f4be92698ffdc2d213e6f1e730c7f44a52f1ccfc") type recoveryState struct { - ss *swapState - contractAddr ethcommon.Address + ss *swapState } // NewRecoveryState returns a new *bob.recoveryState, // which has methods to either claim ether or reclaim monero from an initiated swap. func NewRecoveryState(a *Instance, secret *mcrypto.PrivateSpendKey, - contractAddr ethcommon.Address, contractSwapID *big.Int) (*recoveryState, error) { //nolint:revive + contractSwapID *big.Int) (*recoveryState, error) { //nolint:revive txOpts, err := bind.NewKeyedTransactorWithChainID(a.ethPrivKey, a.chainID) if err != nil { return nil, err @@ -64,10 +63,6 @@ func NewRecoveryState(a *Instance, secret *mcrypto.PrivateSpendKey, ss: s, } - if err := rs.setContract(contractAddr); err != nil { - return nil, err - } - if err := rs.ss.setTimeouts(); err != nil { return nil, err } @@ -127,14 +122,6 @@ func (rs *recoveryState) ClaimOrRefund() (*RecoveryResult, error) { }, nil } -// setContract sets the contract in which Alice has locked her ETH. -func (rs *recoveryState) setContract(address ethcommon.Address) error { - var err error - rs.contractAddr = address - rs.ss.alice.contract, err = swapfactory.NewSwapFactory(address, rs.ss.alice.ethClient) - return err -} - func (s *swapState) filterForClaim() (*mcrypto.PrivateSpendKey, error) { const claimedEvent = "Claimed" diff --git a/protocol/alice/recovery_test.go b/protocol/alice/recovery_test.go index e1820fd0..66ab22e9 100644 --- a/protocol/alice/recovery_test.go +++ b/protocol/alice/recovery_test.go @@ -36,7 +36,7 @@ func newTestRecoveryState(t *testing.T) *recoveryState { _, err = s.lockETH(common.NewEtherAmount(1)) require.NoError(t, err) - rs, err := NewRecoveryState(inst, s.privkeys.SpendKey(), inst.contractAddr, s.contractSwapID) + rs, err := NewRecoveryState(inst, s.privkeys.SpendKey(), s.contractSwapID) require.NoError(t, err) return rs } diff --git a/protocol/alice/swap_state.go b/protocol/alice/swap_state.go index 4d6e7932..db25fc6d 100644 --- a/protocol/alice/swap_state.go +++ b/protocol/alice/swap_state.go @@ -68,6 +68,10 @@ type swapState struct { func newSwapState(a *Instance, infofile string, providesAmount common.EtherAmount, receivedAmount common.MoneroAmount, exhangeRate types.ExchangeRate) (*swapState, error) { + if a.contract == nil { + return nil, errors.New("no swap contract found") + } + txOpts, err := bind.NewKeyedTransactorWithChainID(a.ethPrivKey, a.chainID) if err != nil { return nil, err @@ -107,9 +111,29 @@ func newSwapState(a *Instance, infofile string, providesAmount common.EtherAmoun return nil, fmt.Errorf("failed to write contract address to file: %w", err) } + go s.waitForSendKeysMessage() + return s, nil } +func (s *swapState) waitForSendKeysMessage() { + waitDuration := time.Minute + timer := time.After(waitDuration) + select { + case <-s.ctx.Done(): + return + case <-timer: + } + + // check if we've received a response from the counterparty yet + if s.nextExpectedMessage != (&net.SendKeysMessage{}) { + return + } + + // if not, just exit the swap + _ = s.Exit() +} + // SendKeysMessage ... func (s *swapState) SendKeysMessage() (*net.SendKeysMessage, error) { if err := s.generateAndSetKeys(); err != nil { diff --git a/protocol/bob/message_handler.go b/protocol/bob/message_handler.go index 1beddb12..bf7b5dba 100644 --- a/protocol/bob/message_handler.go +++ b/protocol/bob/message_handler.go @@ -143,27 +143,19 @@ func (s *swapState) handleNotifyETHLocked(msg *message.NotifyETHLocked) (net.Mes return nil, fmt.Errorf("failed to instantiate contract instance: %w", err) } - log.Infof("contract set") - if err := pcommon.WriteContractAddressToFile(s.infofile, msg.Address); err != nil { return nil, fmt.Errorf("failed to write contract address to file: %w", err) } - log.Infof("wrote to file") - if err := s.checkContract(ethcommon.HexToHash(msg.TxHash)); err != nil { return nil, err } - log.Infof("checked contract") - addrAB, err := s.lockFunds(common.MoneroToPiconero(s.info.ProvidedAmount())) if err != nil { return nil, fmt.Errorf("failed to lock funds: %w", err) } - log.Infof("locked funds") - out := &message.NotifyXMRLock{ Address: string(addrAB), } @@ -173,8 +165,6 @@ func (s *swapState) handleNotifyETHLocked(msg *message.NotifyETHLocked) (net.Mes return nil, err } - log.Infof("timeouts set") - go func() { until := time.Until(s.t0) diff --git a/protocol/bob/offers.go b/protocol/bob/offers.go index 597abf52..a0a172ee 100644 --- a/protocol/bob/offers.go +++ b/protocol/bob/offers.go @@ -86,3 +86,10 @@ func (b *Instance) GetOffers() []*types.Offer { } return offers } + +// ClearOffers clears all offers. +func (b *Instance) ClearOffers() { + b.swapMu.Lock() + defer b.swapMu.Unlock() + b.offerManager.offers = make(map[types.Hash]*offerWithExtra) +} diff --git a/protocol/bob/swap_state.go b/protocol/bob/swap_state.go index 8f73bb59..e187c4c0 100644 --- a/protocol/bob/swap_state.go +++ b/protocol/bob/swap_state.go @@ -482,9 +482,12 @@ func (s *swapState) lockFunds(amount common.MoneroAmount) (mcrypto.Address, erro _ = s.bob.daemonClient.GenerateBlocks(bobAddr.Address, 2) } else { // otherwise, wait for new blocks - if err := monero.WaitForBlocks(s.bob.client); err != nil { + height, err := monero.WaitForBlocks(s.bob.client) + if err != nil { return "", err } + + log.Infof("monero block height: %d", height) } if err := s.bob.client.Refresh(); err != nil { diff --git a/recover/recovery.go b/recover/recovery.go index 058cc59c..0cdf4d36 100644 --- a/recover/recovery.go +++ b/recover/recovery.go @@ -90,7 +90,7 @@ func (r *recoverer) RecoverFromBobSecretAndContract(b *bob.Instance, // RecoverFromAliceSecretAndContract recovers funds by either claiming locked monero or refunding ether. func (r *recoverer) RecoverFromAliceSecretAndContract(a *alice.Instance, - aliceSecret, contractAddr string, swapID *big.Int) (*alice.RecoveryResult, error) { + aliceSecret string, swapID *big.Int) (*alice.RecoveryResult, error) { as, err := hex.DecodeString(aliceSecret) if err != nil { return nil, fmt.Errorf("failed to decode Alice's secret: %w", err) @@ -101,8 +101,7 @@ func (r *recoverer) RecoverFromAliceSecretAndContract(a *alice.Instance, return nil, err } - addr := ethcommon.HexToAddress(contractAddr) - rs, err := alice.NewRecoveryState(a, ak, addr, swapID) + rs, err := alice.NewRecoveryState(a, ak, swapID) if err != nil { return nil, err } diff --git a/recover/recovery_test.go b/recover/recovery_test.go index 62e21c88..6c9b71e1 100644 --- a/recover/recovery_test.go +++ b/recover/recovery_test.go @@ -27,7 +27,8 @@ func newRecoverer(t *testing.T) *recoverer { return r } -func newSwap(t *testing.T, claimKey, refundKey [32]byte, setReady bool) (ethcommon.Address, *big.Int) { +func newSwap(t *testing.T, claimKey, refundKey [32]byte, + setReady bool) (ethcommon.Address, *swapfactory.SwapFactory, *big.Int) { tm := big.NewInt(defaulTimeout) pk, err := ethcrypto.HexToECDSA(common.DefaultPrivKeyAlice) @@ -59,10 +60,10 @@ func newSwap(t *testing.T, claimKey, refundKey [32]byte, setReady bool) (ethcomm require.NoError(t, err) } - return addr, swapID + return addr, contract, swapID } -func newAliceInstance(t *testing.T) *alice.Instance { +func newAliceInstance(t *testing.T, addr ethcommon.Address, contract *swapfactory.SwapFactory) *alice.Instance { pk, err := ethcrypto.HexToECDSA(common.DefaultPrivKeyAlice) require.NoError(t, err) @@ -76,6 +77,8 @@ func newAliceInstance(t *testing.T) *alice.Instance { EthereumClient: ec, ChainID: big.NewInt(common.GanacheChainID), MoneroWalletEndpoint: common.DefaultAliceMoneroEndpoint, + SwapContract: contract, + SwapContractAddress: addr, } a, err := alice.NewInstance(cfg) @@ -128,7 +131,7 @@ func TestRecoverer_RecoverFromBobSecretAndContract_Claim(t *testing.T) { b := newBobInstance(t) claimKey := keys.Secp256k1PublicKey.Keccak256() - addr, swapID := newSwap(t, claimKey, [32]byte{}, true) + addr, _, swapID := newSwap(t, claimKey, [32]byte{}, true) r := newRecoverer(t) res, err := r.RecoverFromBobSecretAndContract(b, keys.PrivateKeyPair.SpendKey().Hex(), addr.String(), swapID) @@ -147,7 +150,7 @@ func TestRecoverer_RecoverFromBobSecretAndContract_Claim_afterTimeout(t *testing b := newBobInstance(t) claimKey := keys.Secp256k1PublicKey.Keccak256() - addr, swapID := newSwap(t, claimKey, [32]byte{}, false) + addr, _, swapID := newSwap(t, claimKey, [32]byte{}, false) r := newRecoverer(t) res, err := r.RecoverFromBobSecretAndContract(b, keys.PrivateKeyPair.SpendKey().Hex(), addr.String(), swapID) @@ -163,13 +166,13 @@ func TestRecoverer_RecoverFromAliceSecretAndContract_Refund(t *testing.T) { keys, err := pcommon.GenerateKeysAndProof() require.NoError(t, err) - a := newAliceInstance(t) - refundKey := keys.Secp256k1PublicKey.Keccak256() - addr, swapID := newSwap(t, [32]byte{}, refundKey, false) + addr, contract, swapID := newSwap(t, [32]byte{}, refundKey, false) + + a := newAliceInstance(t, addr, contract) r := newRecoverer(t) - res, err := r.RecoverFromAliceSecretAndContract(a, keys.PrivateKeyPair.SpendKey().Hex(), addr.String(), swapID) + res, err := r.RecoverFromAliceSecretAndContract(a, keys.PrivateKeyPair.SpendKey().Hex(), swapID) require.NoError(t, err) require.True(t, res.Refunded) } diff --git a/rpc/server.go b/rpc/server.go index 19de7a07..3cf301b7 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -125,6 +125,7 @@ type Bob interface { MakeOffer(offer *types.Offer) (*types.OfferExtra, error) SetMoneroWalletFile(file, password string) error GetOffers() []*types.Offer + ClearOffers() } // SwapManager ... diff --git a/rpc/ws.go b/rpc/ws.go index bbece475..e6cdb4c7 100644 --- a/rpc/ws.go +++ b/rpc/ws.go @@ -160,6 +160,10 @@ func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn, if err := writeResponse(conn, resp); err != nil { return err } + + if !status.IsOngoing() { + return nil + } case <-ctx.Done(): return nil } @@ -215,6 +219,10 @@ func (s *wsServer) subscribeMakeOffer(ctx context.Context, conn *websocket.Conn, if err := writeResponse(conn, resp); err != nil { return err } + + if !status.IsOngoing() { + return nil + } case <-ctx.Done(): return nil } @@ -245,6 +253,10 @@ func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn if err := writeResponse(conn, resp); err != nil { return err } + + if !status.IsOngoing() { + return nil + } case <-ctx.Done(): return nil } diff --git a/scripts/run-integration-tests.sh b/scripts/run-integration-tests.sh index a212aeb0..e815a4eb 100755 --- a/scripts/run-integration-tests.sh +++ b/scripts/run-integration-tests.sh @@ -38,11 +38,11 @@ bash scripts/build.sh ALICE_PID=$! sleep 3 echo "starting bob, logs in ./tests/bob.log" -./swapd --dev-bob --bootnodes /ip4/127.0.0.1/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2 --wallet-file test-wallet &> ./tests/bob.log & +./swapd --dev-bob --bootnodes /ip4/127.0.0.1/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2 --wallet-file test-wallet --deploy &> ./tests/bob.log & BOB_PID=$! sleep 3 echo "starting charlie, logs in ./tests/charlie.log" -./swapd --libp2p-port 9955 --rpc-port 5003 --ws-port 8083 --bootnodes /ip4/127.0.0.1/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2 &> ./tests/charlie.log & +./swapd --libp2p-port 9955 --rpc-port 5003 --ws-port 8083 --bootnodes /ip4/127.0.0.1/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2 --deploy &> ./tests/charlie.log & CHARLIE_PID=$! sleep 3