diff --git a/cmd/client/main.go b/cmd/client/main.go index ce337bbe..7f16fbfa 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -1,11 +1,13 @@ package main import ( + "context" "errors" "fmt" "os" "github.com/noot/atomic-swap/cmd/client/client" + "github.com/noot/atomic-swap/common/rpcclient" "github.com/noot/atomic-swap/common/types" logging "github.com/ipfs/go-log" @@ -101,6 +103,10 @@ var ( Name: "provides-amount", Usage: "amount of coin to send in the swap", }, + &cli.BoolFlag{ + Name: "subscribe", + Usage: "subscribe to push notifications about the swap's status", + }, daemonAddrFlag, }, }, @@ -278,6 +284,29 @@ func runTake(ctx *cli.Context) error { endpoint = defaultSwapdAddress } + if ctx.Bool("subscribe") { + c, err := rpcclient.NewWsClient(context.Background(), endpoint) + if err != nil { + return err + } + + id, statusCh, err := c.TakeOfferAndSubscribe(maddr, offerID, providesAmount) + if err != nil { + return err + } + + fmt.Printf("Initiated swap with ID=%d\n", id) + + for stage := range statusCh { + fmt.Printf("> Stage updated: %s\n", stage) + if !stage.IsOngoing() { + return nil + } + } + + return nil + } + c := client.NewClient(endpoint) id, err := c.TakeOffer(maddr, offerID, providesAmount) if err != nil { diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index fd733264..fc75ba05 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -309,6 +309,7 @@ func (d *daemon) make(c *cli.Context) error { } rpcCfg := &rpc.Config{ + Ctx: d.ctx, Port: rpcPort, WsPort: wsPort, Net: host, diff --git a/common/interfaces.go b/common/interfaces.go index 3549d57c..d2ec51a1 100644 --- a/common/interfaces.go +++ b/common/interfaces.go @@ -21,60 +21,5 @@ type SwapStateNet interface { type SwapStateRPC interface { SendKeysMessage() (*message.SendKeysMessage, error) ID() uint64 - Stage() Stage -} - -// Stage represents the stage that a swap is at. -type Stage byte - -const ( - ExpectingKeysStage Stage = iota //nolint:revive - KeysExchangedStage - ContractDeployedStage - XMRLockedStage - ContractReadyStage - ClaimOrRefundStage - UnknownStage -) - -const unknownString string = "unknown" - -// String ... -func (s Stage) String() string { - switch s { - case ExpectingKeysStage: - return "ExpectingKeys" - case KeysExchangedStage: - return "KeysExchanged" - case ContractDeployedStage: - return "ContractDeployed" - case XMRLockedStage: - return "XMRLocked" - case ContractReadyStage: - return "ContractReady" - case ClaimOrRefundStage: - return "ClaimOrRefund" - default: - return unknownString - } -} - -// Info returns a description of the swap stage. -func (s Stage) Info() string { - switch s { - case ExpectingKeysStage: - return "keys have not yet been exchanged" - case KeysExchangedStage: - return "keys have been exchanged, but no value has been locked" - case ContractDeployedStage: - return "the ETH provider has locked their ether, but no XMR has been locked" - case XMRLockedStage: - return "both the XMR and ETH providers have locked their funds" - case ContractReadyStage: - return "the locked ether is ready to be claimed" - case ClaimOrRefundStage: - return "the locked funds have been claimed or refunded" - default: - return unknownString - } + //Status() types.Status } diff --git a/common/rpcclient/types.go b/common/rpcclient/types.go new file mode 100644 index 00000000..8b2f7531 --- /dev/null +++ b/common/rpcclient/types.go @@ -0,0 +1,46 @@ +package rpcclient + +import ( + "encoding/json" + "fmt" +) + +// Request represents a JSON-RPC request +type Request struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params map[string]interface{} `json:"params"` + ID uint64 `json:"id"` +} + +// Response is the JSON format of a response +type Response struct { + // JSON-RPC Version + Version string `json:"jsonrpc"` + // Resulting values + Result json.RawMessage `json:"result"` + // Any generated errors + Error *Error `json:"error"` + // Request id + ID *json.RawMessage `json:"id"` +} + +// ErrCode is a int type used for the rpc error codes +type ErrCode int + +// Error is a struct that holds the error message and the error code for a error +type Error struct { + Message string `json:"message"` + ErrorCode ErrCode `json:"code"` + Data map[string]interface{} `json:"data"` +} + +// Error ... +func (e *Error) Error() string { + return fmt.Sprintf("message=%s; code=%d; data=%v", e.Message, e.ErrorCode, e.Data) +} + +// SubscribeSwapStatusResponse ... +type SubscribeSwapStatusResponse struct { + Stage string `json:"stage"` +} diff --git a/common/rpcclient/utils.go b/common/rpcclient/utils.go index 81382a7d..cc042e1d 100644 --- a/common/rpcclient/utils.go +++ b/common/rpcclient/utils.go @@ -28,35 +28,8 @@ var ( } ) -// ServerResponse is the JSON format of a response -type ServerResponse struct { - // JSON-RPC Version - Version string `json:"jsonrpc"` - // Resulting values - Result json.RawMessage `json:"result"` - // Any generated errors - Error *Error `json:"error"` - // Request id - ID *json.RawMessage `json:"id"` -} - -// ErrCode is a int type used for the rpc error codes -type ErrCode int - -// Error is a struct that holds the error message and the error code for a error -type Error struct { - Message string `json:"message"` - ErrorCode ErrCode `json:"code"` - Data map[string]interface{} `json:"data"` -} - -// Error ... -func (e *Error) Error() string { - return fmt.Sprintf("message=%s; code=%d; data=%v", e.Message, e.ErrorCode, e.Data) -} - // PostRPC posts a JSON-RPC call to the given endpoint. -func PostRPC(endpoint, method, params string) (*ServerResponse, error) { +func PostRPC(endpoint, method, params string) (*Response, error) { data := []byte(`{"jsonrpc":"2.0","method":"` + method + `","params":` + params + `,"id":0}`) buf := &bytes.Buffer{} _, err := buf.Write(data) @@ -88,7 +61,7 @@ func PostRPC(endpoint, method, params string) (*ServerResponse, error) { return nil, fmt.Errorf("failed to read response body: %w", err) } - var sv *ServerResponse + var sv *Response if err = json.Unmarshal(body, &sv); err != nil { return nil, err } diff --git a/common/rpcclient/wsclient.go b/common/rpcclient/wsclient.go new file mode 100644 index 00000000..b3fe5beb --- /dev/null +++ b/common/rpcclient/wsclient.go @@ -0,0 +1,181 @@ +package rpcclient + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/noot/atomic-swap/common/types" + + "github.com/gorilla/websocket" + logging "github.com/ipfs/go-log" +) + +// DefaultJSONRPCVersion ... +const DefaultJSONRPCVersion = "2.0" + +var log = logging.Logger("rpcclient") + +// WsClient ... +type WsClient interface { + SubscribeSwapStatus(id uint64) (<-chan types.Status, error) + TakeOfferAndSubscribe(multiaddr, offerID string, + providesAmount float64) (id uint64, ch <-chan types.Status, err error) +} + +type wsClient struct { + conn *websocket.Conn +} + +// NewWsClient ... +func NewWsClient(ctx context.Context, endpoint string) (*wsClient, error) { ///nolint:revive + conn, resp, err := websocket.DefaultDialer.DialContext(ctx, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("failed to dial endpoint: %w", err) + } + + if err = resp.Body.Close(); err != nil { + return nil, err + } + + return &wsClient{ + conn: conn, + }, nil +} + +// SubscribeSwapStatus returns a channel that is written to each time the swap's status updates. +// If there is no swap with the given ID, it returns an error. +func (c *wsClient) SubscribeSwapStatus(id uint64) (<-chan types.Status, error) { + req := &Request{ + JSONRPC: DefaultJSONRPCVersion, + Method: "swap_subscribeStatus", + Params: map[string]interface{}{ + "id": id, + }, + ID: 0, + } + + if err := c.conn.WriteJSON(req); err != nil { + return nil, err + } + + respCh := make(chan types.Status) + + go func() { + defer close(respCh) + + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + log.Warnf("failed to read websockets message: %s", err) + break + } + + var resp *Response + err = json.Unmarshal(message, &resp) + if err != nil { + log.Warnf("failed to unmarshal response: %s", err) + break + } + + if resp.Error != nil { + log.Warnf("websocket server returned error: %s", resp.Error) + break + } + + log.Debugf("received message over websockets: %s", message) + var status *SubscribeSwapStatusResponse + if err := json.Unmarshal(resp.Result, &status); err != nil { + log.Warnf("failed to unmarshal response: %s", err) + break + } + + respCh <- types.NewStatus(status.Stage) + } + }() + + return respCh, nil +} + +func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string, + providesAmount float64) (id uint64, ch <-chan types.Status, err error) { + req := &Request{ + JSONRPC: DefaultJSONRPCVersion, + Method: "net_takeOfferAndSubscribe", + Params: map[string]interface{}{ + "multiaddr": multiaddr, + "offerID": offerID, + "providesAmount": providesAmount, + }, + ID: 0, + } + + if err = c.conn.WriteJSON(req); err != nil { + return 0, nil, err + } + + // read ID from connection + _, message, err := c.conn.ReadMessage() + if err != nil { + return 0, nil, fmt.Errorf("failed to read websockets message: %s", err) + } + + var resp *Response + err = json.Unmarshal(message, &resp) + if err != nil { + return 0, nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + if resp.Error != nil { + return 0, nil, fmt.Errorf("websocket server returned error: %w", resp.Error) + } + + log.Debugf("received message over websockets: %s", message) + var idResp map[string]uint64 + if err := json.Unmarshal(resp.Result, &idResp); err != nil { + return 0, nil, fmt.Errorf("failed to unmarshal response: %s", err) + } + + id, ok := idResp["id"] + if !ok { + return 0, nil, errors.New("websocket response did not contain ID") + } + + respCh := make(chan types.Status) + + go func() { + defer close(respCh) + + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + log.Warnf("failed to read websockets message: %s", err) + break + } + + var resp *Response + err = json.Unmarshal(message, &resp) + if err != nil { + log.Warnf("failed to unmarshal response: %s", err) + break + } + + if resp.Error != nil { + log.Warnf("websocket server returned error: %s", resp.Error) + break + } + + log.Debugf("received message over websockets: %s", message) + var status *SubscribeSwapStatusResponse + if err := json.Unmarshal(resp.Result, &status); err != nil { + log.Warnf("failed to unmarshal response: %s", err) + break + } + + respCh <- types.NewStatus(status.Stage) + } + }() + + return id, respCh, nil +} diff --git a/common/types/status.go b/common/types/status.go new file mode 100644 index 00000000..b66eab20 --- /dev/null +++ b/common/types/status.go @@ -0,0 +1,104 @@ +package types + +// Status represents the stage that a swap is at. +type Status byte + +const ( + ExpectingKeys Status = iota //nolint:revive + KeysExchanged + ContractDeployed + XMRLocked + ContractReady + // CompletedSuccess represents a successful swap. + CompletedSuccess + // CompletedRefund represents a swap that was refunded. + CompletedRefund + // CompletedAbort represents the case where the swap aborts before any funds are locked. + CompletedAbort + UnknownStatus +) + +const unknownString string = "unknown" + +// NewStatus returns a Status from the given string. +// If there is no match, it returns UnknownStatus +func NewStatus(str string) Status { + switch str { + case "ExpectingKeys": + return ExpectingKeys + case "KeysExchanged": + return KeysExchanged + case "ContractDeployed": + return ContractDeployed + case "XMRLocked": + return XMRLocked + case "ContractReady": + return ContractReady + case "Success": + return CompletedSuccess + case "Refunded": + return CompletedRefund + case "Aborted": + return CompletedAbort + default: + return UnknownStatus + } +} + +// String ... +func (s Status) String() string { + switch s { + case ExpectingKeys: + return "ExpectingKeys" + case KeysExchanged: + return "KeysExchanged" + case ContractDeployed: + return "ContractDeployed" + case XMRLocked: + return "XMRLocked" + case ContractReady: + return "ContractReady" + case CompletedSuccess: + return "Success" + case CompletedRefund: + return "Refunded" + case CompletedAbort: + return "Aborted" + default: + return unknownString + } +} + +// Info returns a description of the swap stage. +func (s Status) Info() string { + switch s { + case ExpectingKeys: + return "keys have not yet been exchanged" + case KeysExchanged: + return "keys have been exchanged, but no value has been locked" + case ContractDeployed: + return "the ETH provider has locked their ether, but no XMR has been locked" + case XMRLocked: + return "both the XMR and ETH providers have locked their funds" + case ContractReady: + return "the locked ether is ready to be claimed" + case CompletedSuccess: + return "the locked funds have been claimed and the swap has completed successfully" + case CompletedRefund: + return "the locked funds have been refunded and the swap has completed" + case CompletedAbort: + return "the swap was aborted before any funds were locked" + default: + return unknownString + } +} + +// IsOngoing returns true if the status means the swap has not completed +func (s Status) IsOngoing() bool { + switch s { + case ExpectingKeys, KeysExchanged, ContractDeployed, XMRLocked, ContractReady, UnknownStatus: + return true + default: + return false + } +} diff --git a/protocol/alice/message_handler.go b/protocol/alice/message_handler.go index 8cecde46..3f92da9f 100644 --- a/protocol/alice/message_handler.go +++ b/protocol/alice/message_handler.go @@ -12,7 +12,6 @@ import ( "github.com/noot/atomic-swap/net" "github.com/noot/atomic-swap/net/message" pcommon "github.com/noot/atomic-swap/protocol" - pswap "github.com/noot/atomic-swap/protocol/swap" "github.com/noot/atomic-swap/swapfactory" ethcommon "github.com/ethereum/go-ethereum/common" @@ -54,14 +53,30 @@ func (s *swapState) HandleProtocolMessage(msg net.Message) (net.Message, bool, e close(s.claimedCh) log.Info("successfully created monero wallet from our secrets: address=", address) - s.nextExpectedMessage = nil - s.info.SetStatus(pswap.Success) + s.clearNextExpectedMessage(types.CompletedSuccess) return nil, true, nil default: return nil, false, errors.New("unexpected message type") } } +func (s *swapState) clearNextExpectedMessage(status types.Status) { + s.nextExpectedMessage = nil + s.info.SetStatus(status) + if s.statusCh != nil { + s.statusCh <- status + } +} + +func (s *swapState) setNextExpectedMessage(msg net.Message) { + s.nextExpectedMessage = msg + // TODO: check stage is not unknown (ie. swap completed) + stage := pcommon.GetStatus(msg.Type()) + if s.statusCh != nil { + s.statusCh <- stage + } +} + func (s *swapState) checkMessageType(msg net.Message) error { if msg.Type() != s.nextExpectedMessage.Type() { return errors.New("received unexpected message") @@ -150,7 +165,7 @@ func (s *swapState) handleSendKeysMessage(msg *net.SendKeysMessage) (net.Message }() - s.nextExpectedMessage = &message.NotifyXMRLock{} + s.setNextExpectedMessage(&message.NotifyXMRLock{}) out := &message.NotifyContractDeployed{ Address: s.alice.contractAddr.String(), @@ -278,7 +293,7 @@ func (s *swapState) handleNotifyXMRLock(msg *message.NotifyXMRLock) (net.Message } }() - s.nextExpectedMessage = &message.NotifyClaimed{} + s.setNextExpectedMessage(&message.NotifyClaimed{}) return &message.NotifyReady{}, nil } diff --git a/protocol/alice/swap_state.go b/protocol/alice/swap_state.go index 8c0ed3ff..ff889391 100644 --- a/protocol/alice/swap_state.go +++ b/protocol/alice/swap_state.go @@ -34,7 +34,8 @@ type swapState struct { cancel context.CancelFunc sync.Mutex - info *pswap.Info + info *pswap.Info + statusCh chan types.Status // our keys for this session dleqProof *dleq.Proof @@ -70,7 +71,10 @@ func newSwapState(a *Instance, providesAmount common.EtherAmount) (*swapState, e txOpts.GasPrice = a.gasPrice txOpts.GasLimit = a.gasLimit - info := pswap.NewInfo(types.ProvidesETH, providesAmount.AsEther(), 0, 0, pswap.Ongoing) + stage := types.ExpectingKeys + statusCh := make(chan types.Status, 16) + statusCh <- stage + info := pswap.NewInfo(types.ProvidesETH, providesAmount.AsEther(), 0, 0, stage, statusCh) if err := a.swapManager.AddSwap(info); err != nil { return nil, err } @@ -85,6 +89,7 @@ func newSwapState(a *Instance, providesAmount common.EtherAmount) (*swapState, e xmrLockedCh: make(chan struct{}), claimedCh: make(chan struct{}), info: info, + statusCh: statusCh, } return s, nil @@ -117,14 +122,6 @@ func (s *swapState) receivedAmountInPiconero() common.MoneroAmount { return common.MoneroToPiconero(s.info.ReceivedAmount()) } -func (s *swapState) Stage() common.Stage { - if s.nextExpectedMessage == nil { - return pcommon.GetStage(message.NilType) - } - - return pcommon.GetStage(s.nextExpectedMessage.Type()) -} - // ID returns the ID of the swap func (s *swapState) ID() uint64 { return s.info.ID() @@ -143,13 +140,13 @@ func (s *swapState) ProtocolExited() error { s.alice.swapManager.CompleteOngoingSwap() }() - if s.info.Status() == pswap.Success { + if s.info.Status() == types.CompletedSuccess { str := color.New(color.Bold).Sprintf("**swap completed successfully: id=%d**", s.info.ID()) log.Info(str) return nil } - if s.info.Status() == pswap.Refunded { + if s.info.Status() == types.CompletedRefund { str := color.New(color.Bold).Sprintf("**swap refunded successfully! id=%d**", s.info.ID()) log.Info(str) return nil @@ -158,30 +155,30 @@ func (s *swapState) ProtocolExited() error { switch s.nextExpectedMessage.(type) { case *net.SendKeysMessage: // we are fine, as we only just initiated the protocol. - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) return errSwapAborted case *message.NotifyXMRLock: // we already deployed the contract, so we should call Refund(). txHash, err := s.tryRefund() if err != nil { - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) log.Errorf("failed to refund: err=%s", err) return err } - s.info.SetStatus(pswap.Refunded) + s.clearNextExpectedMessage(types.CompletedRefund) log.Infof("refunded ether: transaction hash=%s", txHash) case *message.NotifyClaimed: // the XMR has been locked, but the ETH hasn't been claimed. // we should also refund in this case. txHash, err := s.tryRefund() if err != nil { - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) log.Errorf("failed to refund: err=%s", err) return err } - s.info.SetStatus(pswap.Refunded) + s.clearNextExpectedMessage(types.CompletedRefund) log.Infof("refunded ether: transaction hash=%s", txHash) case nil: skA, err := s.filterForClaim() @@ -197,7 +194,7 @@ func (s *swapState) ProtocolExited() error { log.Infof("claimed monero: address=%s", addr) default: log.Errorf("unexpected nextExpectedMessage in ProtocolExited: type=%T", s.nextExpectedMessage) - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) return errUnexpectedMessageType } @@ -213,12 +210,12 @@ func (s *swapState) doRefund() (ethcommon.Hash, error) { // we can refund in this case. txHash, err := s.tryRefund() if err != nil { - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) log.Errorf("failed to refund: err=%s", err) return ethcommon.Hash{}, err } - s.info.SetStatus(pswap.Refunded) + s.clearNextExpectedMessage(types.CompletedRefund) log.Infof("refunded ether: transaction hash=%s", txHash) // send NotifyRefund msg @@ -341,7 +338,7 @@ func (s *swapState) lockETH(amount common.EtherAmount) error { return fmt.Errorf("failed to deploy Swap.sol: %w", err) } - log.Debugf("deploying Swap.sol, amount=%s txHash=%s", amount, tx.Hash()) + log.Debugf("instantiating swap on-chain: amount=%s txHash=%s", amount, tx.Hash()) receipt, err := common.WaitForReceipt(s.ctx, s.alice.ethClient, tx.Hash()) if err != nil { return fmt.Errorf("failed to call new_swap in contract: %w", err) @@ -395,7 +392,7 @@ func (s *swapState) refund() (ethcommon.Hash, error) { return ethcommon.Hash{}, fmt.Errorf("failed to call Refund function in contract: %w", err) } - s.info.SetStatus(pswap.Refunded) + s.clearNextExpectedMessage(types.CompletedRefund) return tx.Hash(), nil } diff --git a/protocol/alice/swap_state_test.go b/protocol/alice/swap_state_test.go index e821c85f..91aed2d3 100644 --- a/protocol/alice/swap_state_test.go +++ b/protocol/alice/swap_state_test.go @@ -12,6 +12,7 @@ import ( "github.com/ethereum/go-ethereum/ethclient" "github.com/noot/atomic-swap/common" + "github.com/noot/atomic-swap/common/types" mcrypto "github.com/noot/atomic-swap/crypto/monero" "github.com/noot/atomic-swap/monero" "github.com/noot/atomic-swap/net" @@ -138,8 +139,15 @@ func TestSwapState_HandleProtocolMessage_SendKeysMessage_Refund(t *testing.T) { require.Equal(t, bobKeysAndProof.PublicKeyPair.SpendKey().Hex(), s.bobPublicSpendKey.Hex()) require.Equal(t, bobKeysAndProof.PrivateKeyPair.ViewKey().Hex(), s.bobPrivateViewKey.Hex()) + for status := range s.statusCh { + if status == types.CompletedRefund { + break + } else if !status.IsOngoing() { + t.Fatalf("got wrong exit status %s, expected CompletedRefund", status) + } + } + // ensure we refund before t0 - time.Sleep(time.Second * 15) require.NotNil(t, s.alice.net.(*mockNet).msg) require.Equal(t, message.NotifyRefundType, s.alice.net.(*mockNet).msg.Type()) @@ -225,7 +233,14 @@ func TestSwapState_NotifyXMRLock_Refund(t *testing.T) { _, ok := resp.(*message.NotifyReady) require.True(t, ok) - time.Sleep(time.Second * 25) + for status := range s.statusCh { + if status == types.CompletedRefund { + break + } else if !status.IsOngoing() { + t.Fatalf("got wrong exit status %s, expected CompletedRefund", status) + } + } + require.NotNil(t, s.alice.net.(*mockNet).msg) require.Equal(t, message.NotifyRefundType, s.alice.net.(*mockNet).msg.Type()) @@ -273,7 +288,7 @@ func TestSwapState_NotifyClaimed(t *testing.T) { daemonClient := monero.NewClient(common.DefaultMoneroDaemonEndpoint) _ = daemonClient.GenerateBlocks(bobAddr.Address, 60) - amt := common.MoneroAmount(333) + amt := common.MoneroAmount(1) s.info.SetReceivedAmount(amt.AsMonero()) kp := mcrypto.SumSpendAndViewKeys(s.pubkeys, s.pubkeys) xmrAddr := kp.Address(common.Mainnet) @@ -327,7 +342,7 @@ func TestProtocolExited_afterSendKeysMessage(t *testing.T) { err := s.ProtocolExited() require.Equal(t, errSwapAborted, err) info := s.alice.swapManager.GetPastSwap(s.info.ID()) - require.Equal(t, pswap.Aborted, info.Status()) + require.Equal(t, types.CompletedAbort, info.Status()) } func TestProtocolExited_afterNotifyXMRLock(t *testing.T) { @@ -350,7 +365,7 @@ func TestProtocolExited_afterNotifyXMRLock(t *testing.T) { err = s.ProtocolExited() require.NoError(t, err) info := s.alice.swapManager.GetPastSwap(s.info.ID()) - require.Equal(t, pswap.Refunded, info.Status()) + require.Equal(t, types.CompletedRefund, info.Status()) } func TestProtocolExited_afterNotifyClaimed(t *testing.T) { @@ -373,7 +388,7 @@ func TestProtocolExited_afterNotifyClaimed(t *testing.T) { err = s.ProtocolExited() require.NoError(t, err) info := s.alice.swapManager.GetPastSwap(s.info.ID()) - require.Equal(t, pswap.Refunded, info.Status()) + require.Equal(t, types.CompletedRefund, info.Status()) } func TestProtocolExited_invalidNextMessageType(t *testing.T) { @@ -397,5 +412,5 @@ func TestProtocolExited_invalidNextMessageType(t *testing.T) { err = s.ProtocolExited() require.Equal(t, errUnexpectedMessageType, err) info := s.alice.swapManager.GetPastSwap(s.info.ID()) - require.Equal(t, pswap.Aborted, info.Status()) + require.Equal(t, types.CompletedAbort, info.Status()) } diff --git a/protocol/bob/message_handler.go b/protocol/bob/message_handler.go index 8ef758e2..cfb0e1ac 100644 --- a/protocol/bob/message_handler.go +++ b/protocol/bob/message_handler.go @@ -8,11 +8,11 @@ import ( ethcommon "github.com/ethereum/go-ethereum/common" "github.com/noot/atomic-swap/common" + "github.com/noot/atomic-swap/common/types" mcrypto "github.com/noot/atomic-swap/crypto/monero" "github.com/noot/atomic-swap/net" "github.com/noot/atomic-swap/net/message" pcommon "github.com/noot/atomic-swap/protocol" - pswap "github.com/noot/atomic-swap/protocol/swap" "github.com/noot/atomic-swap/swapfactory" ) @@ -60,7 +60,7 @@ func (s *swapState) HandleProtocolMessage(msg net.Message) (net.Message, bool, e TxHash: txHash.String(), } - s.info.SetStatus(pswap.Success) + s.clearNextExpectedMessage(types.CompletedSuccess) return out, true, nil case *message.NotifyRefund: // generate monero wallet, regaining control over locked funds @@ -69,7 +69,7 @@ func (s *swapState) HandleProtocolMessage(msg net.Message) (net.Message, bool, e return nil, false, err } - s.info.SetStatus(pswap.Refunded) + s.clearNextExpectedMessage(types.CompletedRefund) log.Infof("regained control over monero account %s", addr) return nil, true, nil default: @@ -77,6 +77,23 @@ func (s *swapState) HandleProtocolMessage(msg net.Message) (net.Message, bool, e } } +func (s *swapState) clearNextExpectedMessage(status types.Status) { + s.nextExpectedMessage = nil + s.info.SetStatus(status) + if s.statusCh != nil { + s.statusCh <- status + } +} + +func (s *swapState) setNextExpectedMessage(msg net.Message) { + s.nextExpectedMessage = msg + // TODO: check stage is not unknown (ie. swap completed) + stage := pcommon.GetStatus(msg.Type()) + if s.statusCh != nil { + s.statusCh <- stage + } +} + func (s *swapState) checkMessageType(msg net.Message) error { // Alice might refund anytime before t0 or after t1, so we should allow this. if _, ok := msg.(*message.NotifyRefund); ok { @@ -147,7 +164,7 @@ func (s *swapState) handleNotifyContractDeployed(msg *message.NotifyContractDepl } log.Debug("funds claimed!") - s.info.SetStatus(pswap.Success) + s.clearNextExpectedMessage(types.CompletedSuccess) // send *message.NotifyClaimed if err := s.bob.net.SendSwapMessage(&message.NotifyClaimed{ @@ -160,7 +177,7 @@ func (s *swapState) handleNotifyContractDeployed(msg *message.NotifyContractDepl } }() - s.nextExpectedMessage = &message.NotifyReady{} + s.setNextExpectedMessage(&message.NotifyReady{}) return out, nil } @@ -183,7 +200,7 @@ func (s *swapState) handleSendKeysMessage(msg *net.SendKeysMessage) error { } s.setAlicePublicKeys(kp, secp256k1Pub) - s.nextExpectedMessage = &message.NotifyContractDeployed{} + s.setNextExpectedMessage(&message.NotifyContractDeployed{}) return nil } diff --git a/protocol/bob/recovery_test.go b/protocol/bob/recovery_test.go index c345f900..458747da 100644 --- a/protocol/bob/recovery_test.go +++ b/protocol/bob/recovery_test.go @@ -44,6 +44,10 @@ func TestClaimOrRecover_Claim(t *testing.T) { } func TestClaimOrRecover_Recover(t *testing.T) { + if testing.Short() { + t.Skip() // TODO: fails on CI w/ "not enough money" + } + // test case where Bob is able to reclaim his monero, after Alice refunds rs := newTestRecoveryState(t) @@ -54,7 +58,7 @@ func TestClaimOrRecover_Recover(t *testing.T) { // lock XMR rs.ss.setAlicePublicKeys(rs.ss.pubkeys, nil) - addrAB, err := rs.ss.lockFunds(333) + addrAB, err := rs.ss.lockFunds(1) require.NoError(t, err) // call refund w/ Alice's spend key diff --git a/protocol/bob/swap_state.go b/protocol/bob/swap_state.go index d58f1ded..901d6ca6 100644 --- a/protocol/bob/swap_state.go +++ b/protocol/bob/swap_state.go @@ -42,8 +42,9 @@ type swapState struct { cancel context.CancelFunc sync.Mutex - info *pswap.Info - offerID types.Hash + info *pswap.Info + offerID types.Hash + statusCh chan types.Status // our keys for this session dleqProof *dleq.Proof @@ -82,8 +83,11 @@ func newSwapState(b *Instance, offerID types.Hash, providesAmount common.MoneroA txOpts.GasLimit = b.gasLimit exchangeRate := types.ExchangeRate(providesAmount.AsMonero() / desiredAmount.AsEther()) + stage := types.ExpectingKeys + statusCh := make(chan types.Status, 7) + statusCh <- stage info := pswap.NewInfo(types.ProvidesXMR, providesAmount.AsMonero(), desiredAmount.AsEther(), - exchangeRate, pswap.Ongoing) + exchangeRate, stage, statusCh) if err := b.swapManager.AddSwap(info); err != nil { return nil, err } @@ -98,6 +102,7 @@ func newSwapState(b *Instance, offerID types.Hash, providesAmount common.MoneroA readyCh: make(chan struct{}), txOpts: txOpts, info: info, + statusCh: statusCh, } return s, nil @@ -124,14 +129,6 @@ func (s *swapState) ReceivedAmount() float64 { return s.info.ReceivedAmount() } -func (s *swapState) Stage() common.Stage { - if s.nextExpectedMessage == nil { - return pcommon.GetStage(message.NilType) - } - - return pcommon.GetStage(s.nextExpectedMessage.Type()) -} - // ID returns the ID of the swap func (s *swapState) ID() uint64 { return s.info.ID() @@ -150,7 +147,7 @@ func (s *swapState) ProtocolExited() error { s.bob.swapManager.CompleteOngoingSwap() }() - if s.info.Status() == pswap.Success { + if s.info.Status() == types.CompletedSuccess { str := color.New(color.Bold).Sprintf("**swap completed successfully: id=%d**", s.ID()) log.Info(str) @@ -159,7 +156,7 @@ func (s *swapState) ProtocolExited() error { return nil } - if s.info.Status() == pswap.Refunded { + if s.info.Status() == types.CompletedRefund { str := color.New(color.Bold).Sprintf("**swap refunded successfully: id=%d**", s.ID()) log.Info(str) return nil @@ -168,12 +165,12 @@ func (s *swapState) ProtocolExited() error { switch s.nextExpectedMessage.(type) { case *net.SendKeysMessage: // we are fine, as we only just initiated the protocol. - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) return errSwapAborted case *message.NotifyContractDeployed: // we were waiting for the contract to be deployed, but haven't // locked out funds yet, so we're fine. - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) return errSwapAborted case *message.NotifyReady: // we already locked our funds - need to wait until we can claim @@ -194,12 +191,12 @@ func (s *swapState) ProtocolExited() error { return err } - s.info.SetStatus(pswap.Refunded) + s.clearNextExpectedMessage(types.CompletedRefund) s.moneroReclaimAddress = address log.Infof("regained private key to monero wallet, address=%s", address) return nil default: - s.info.SetStatus(pswap.Aborted) + s.clearNextExpectedMessage(types.CompletedAbort) log.Errorf("unexpected nextExpectedMessage in ProtocolExited: type=%T", s.nextExpectedMessage) return errUnexpectedMessageType } diff --git a/protocol/bob/swap_state_test.go b/protocol/bob/swap_state_test.go index 302dfb48..2eefb070 100644 --- a/protocol/bob/swap_state_test.go +++ b/protocol/bob/swap_state_test.go @@ -69,7 +69,7 @@ func newTestBob(t *testing.T) *Instance { bobAddr, err := bob.client.GetAddress(0) require.NoError(t, err) - _ = bob.daemonClient.GenerateBlocks(bobAddr.Address, 121) + _ = bob.daemonClient.GenerateBlocks(bobAddr.Address, 256) err = bob.client.Refresh() require.NoError(t, err) return bob @@ -149,7 +149,7 @@ func TestSwapState_ClaimFunds(t *testing.T) { txHash, err := swapState.claimFunds() require.NoError(t, err) require.NotEqual(t, "", txHash) - require.Equal(t, pswap.Ongoing, swapState.info.Status()) + require.True(t, swapState.info.Status().IsOngoing()) } func TestSwapState_handleSendKeysMessage(t *testing.T) { @@ -167,7 +167,7 @@ func TestSwapState_handleSendKeysMessage(t *testing.T) { require.Equal(t, &message.NotifyContractDeployed{}, s.nextExpectedMessage) require.Equal(t, alicePubKeys.SpendKey().Hex(), s.alicePublicKeys.SpendKey().Hex()) require.Equal(t, alicePubKeys.ViewKey().Hex(), s.alicePublicKeys.ViewKey().Hex()) - require.Equal(t, pswap.Ongoing, s.info.Status()) + require.True(t, s.info.Status().IsOngoing()) } func TestSwapState_HandleProtocolMessage_NotifyContractDeployed_ok(t *testing.T) { @@ -206,7 +206,7 @@ func TestSwapState_HandleProtocolMessage_NotifyContractDeployed_ok(t *testing.T) require.Equal(t, addr, s.contractAddr) require.Equal(t, duration, s.t1.Sub(s.t0)) require.Equal(t, &message.NotifyReady{}, s.nextExpectedMessage) - require.Equal(t, pswap.Ongoing, s.info.Status()) + require.True(t, s.info.Status().IsOngoing()) } func TestSwapState_HandleProtocolMessage_NotifyContractDeployed_timeout(t *testing.T) { @@ -247,10 +247,16 @@ func TestSwapState_HandleProtocolMessage_NotifyContractDeployed_timeout(t *testi require.Equal(t, duration, s.t1.Sub(s.t0)) require.Equal(t, &message.NotifyReady{}, s.nextExpectedMessage) - // TODO: fix this, it's sometimes nil - time.Sleep(duration * 3) + for status := range s.statusCh { + if status == types.CompletedSuccess { + break + } else if !status.IsOngoing() { + t.Fatalf("got wrong exit status %s, expected CompletedSuccess", status) + } + } + require.NotNil(t, s.bob.net.(*mockNet).msg) - require.Equal(t, pswap.Success, s.info.Status()) + require.Equal(t, types.CompletedSuccess, s.info.Status()) } func TestSwapState_HandleProtocolMessage_NotifyReady(t *testing.T) { @@ -274,7 +280,7 @@ func TestSwapState_HandleProtocolMessage_NotifyReady(t *testing.T) { require.True(t, done) require.NotNil(t, resp) require.Equal(t, message.NotifyClaimedType, resp.Type()) - require.Equal(t, pswap.Success, s.info.Status()) + require.Equal(t, types.CompletedSuccess, s.info.Status()) } func TestSwapState_handleRefund(t *testing.T) { @@ -346,7 +352,7 @@ func TestSwapState_HandleProtocolMessage_NotifyRefund(t *testing.T) { require.NoError(t, err) require.True(t, done) require.Nil(t, resp) - require.Equal(t, pswap.Refunded, s.info.Status()) + require.Equal(t, types.CompletedRefund, s.info.Status()) } // test that if the protocol exits early, and Alice refunds, Bob can reclaim his monero @@ -391,7 +397,7 @@ func TestSwapState_ProtocolExited_Reclaim(t *testing.T) { balance, err := bob.client.GetBalance(0) require.NoError(t, err) require.Equal(t, common.MoneroToPiconero(s.info.ProvidedAmount()).Uint64(), uint64(balance.Balance)) - require.Equal(t, pswap.Refunded, s.info.Status()) + require.Equal(t, types.CompletedRefund, s.info.Status()) } func TestSwapState_ProtocolExited_Aborted(t *testing.T) { @@ -399,17 +405,17 @@ func TestSwapState_ProtocolExited_Aborted(t *testing.T) { s.nextExpectedMessage = &message.SendKeysMessage{} err := s.ProtocolExited() require.Equal(t, errSwapAborted, err) - require.Equal(t, pswap.Aborted, s.info.Status()) + require.Equal(t, types.CompletedAbort, s.info.Status()) s.nextExpectedMessage = &message.NotifyContractDeployed{} err = s.ProtocolExited() require.Equal(t, errSwapAborted, err) - require.Equal(t, pswap.Aborted, s.info.Status()) + require.Equal(t, types.CompletedAbort, s.info.Status()) s.nextExpectedMessage = nil err = s.ProtocolExited() require.Equal(t, errUnexpectedMessageType, err) - require.Equal(t, pswap.Aborted, s.info.Status()) + require.Equal(t, types.CompletedAbort, s.info.Status()) } func TestSwapState_ProtocolExited_Success(t *testing.T) { @@ -423,7 +429,7 @@ func TestSwapState_ProtocolExited_Success(t *testing.T) { b.MakeOffer(offer) s.offerID = offer.GetID() - s.info.SetStatus(pswap.Success) + s.info.SetStatus(types.CompletedSuccess) err := s.ProtocolExited() require.NoError(t, err) require.Nil(t, b.offerManager.getOffer(offer.GetID())) @@ -440,7 +446,7 @@ func TestSwapState_ProtocolExited_Refunded(t *testing.T) { b.MakeOffer(offer) s.offerID = offer.GetID() - s.info.SetStatus(pswap.Refunded) + s.info.SetStatus(types.CompletedRefund) err := s.ProtocolExited() require.NoError(t, err) require.NotNil(t, b.offerManager.getOffer(offer.GetID())) diff --git a/protocol/stage.go b/protocol/stage.go index 07348758..f8a969b7 100644 --- a/protocol/stage.go +++ b/protocol/stage.go @@ -1,26 +1,24 @@ package protocol import ( - "github.com/noot/atomic-swap/common" + "github.com/noot/atomic-swap/common/types" "github.com/noot/atomic-swap/net/message" ) -// GetStage returns the stage corresponding to a next expected message type. -func GetStage(t message.Type) common.Stage { +// GetStatus returns the status corresponding to a next expected message type. +func GetStatus(t message.Type) types.Status { switch t { case message.SendKeysType: - return common.ExpectingKeysStage + return types.ExpectingKeys case message.NotifyContractDeployedType: - return common.KeysExchangedStage + return types.KeysExchanged case message.NotifyXMRLockType: - return common.ContractDeployedStage + return types.ContractDeployed case message.NotifyReadyType: - return common.XMRLockedStage + return types.XMRLocked case message.NotifyClaimedType: - return common.ContractReadyStage - case message.NilType: - return common.ClaimOrRefundStage + return types.ContractReady default: - return common.UnknownStage + return types.UnknownStatus } } diff --git a/protocol/swap/manager.go b/protocol/swap/manager.go index d225d195..e7a30d09 100644 --- a/protocol/swap/manager.go +++ b/protocol/swap/manager.go @@ -8,36 +8,10 @@ import ( var nextID uint64 -// Status represents the status of a swap. -type Status byte - -const ( - // Ongoing represents an ongoing swap. - Ongoing Status = iota - // Success represents a successful swap. - Success - // Refunded represents a swap that was refunded. - Refunded - // Aborted represents the case where the swap aborts before any funds are locked. - Aborted +type ( + Status = types.Status //nolint:revive ) -// String ... -func (s Status) String() string { - switch s { - case Ongoing: - return "ongoing" - case Success: - return "success" - case Refunded: - return "refunded" - case Aborted: - return "aborted" - default: - return "unknown" - } -} - // Info contains the details of the swap as well as its status. type Info struct { id uint64 // ID number of the swap (not the swap offer ID!) @@ -46,6 +20,7 @@ type Info struct { receivedAmount float64 exchangeRate types.ExchangeRate status Status + statusCh <-chan types.Status } // ID returns the swap ID. @@ -86,6 +61,11 @@ func (i *Info) Status() Status { return i.status } +// StatusCh returns the swap's status update channel. +func (i *Info) StatusCh() <-chan types.Status { + return i.statusCh +} + // SetReceivedAmount ... func (i *Info) SetReceivedAmount(a float64) { i.receivedAmount = a @@ -107,7 +87,7 @@ func (i *Info) SetStatus(s Status) { // NewInfo ... func NewInfo(provides types.ProvidesCoin, providedAmount, receivedAmount float64, - exchangeRate types.ExchangeRate, status Status) *Info { + exchangeRate types.ExchangeRate, status Status, statusCh <-chan types.Status) *Info { info := &Info{ id: nextID, provides: provides, @@ -115,6 +95,7 @@ func NewInfo(provides types.ProvidesCoin, providedAmount, receivedAmount float64 receivedAmount: receivedAmount, exchangeRate: exchangeRate, status: status, + statusCh: statusCh, } nextID++ return info @@ -139,8 +120,8 @@ func (m *Manager) AddSwap(info *Info) error { m.Lock() defer m.Unlock() - switch info.status { - case Ongoing: + switch info.status.IsOngoing() { + case true: if m.ongoing != nil { return errHaveOngoingSwap } diff --git a/protocol/swap/manager_test.go b/protocol/swap/manager_test.go index 45101caa..eb02c015 100644 --- a/protocol/swap/manager_test.go +++ b/protocol/swap/manager_test.go @@ -10,7 +10,7 @@ import ( func TestManager_AddSwap_Ongoing(t *testing.T) { m := NewManager() - info := NewInfo(types.ProvidesXMR, 1, 1, 0.1, Ongoing) + info := NewInfo(types.ProvidesXMR, 1, 1, 0.1, types.ExpectingKeys, nil) err := m.AddSwap(info) require.NoError(t, err) @@ -32,7 +32,7 @@ func TestManager_AddSwap_Past(t *testing.T) { info := &Info{ id: 1, - status: Success, + status: types.CompletedSuccess, } err := m.AddSwap(info) @@ -41,7 +41,7 @@ func TestManager_AddSwap_Past(t *testing.T) { info = &Info{ id: 2, - status: Success, + status: types.CompletedSuccess, } err = m.AddSwap(info) diff --git a/rpc/net.go b/rpc/net.go index ff372633..127f02d1 100644 --- a/rpc/net.go +++ b/rpc/net.go @@ -1,6 +1,7 @@ package rpc import ( + "errors" "fmt" "net/http" "time" @@ -136,31 +137,46 @@ type TakeOfferResponse struct { // TakeOffer initiates a swap with the given peer by taking an offer they've made. func (s *NetService) TakeOffer(_ *http.Request, req *TakeOfferRequest, resp *TakeOfferResponse) error { - swapState, err := s.alice.InitiateProtocol(req.ProvidesAmount) + id, _, err := s.takeOffer(req.Multiaddr, req.OfferID, req.ProvidesAmount) if err != nil { return err } + resp.ID = id + return nil +} + +func (s *NetService) takeOffer(multiaddr, offerID string, + providesAmount float64) (uint64, <-chan types.Status, error) { + swapState, err := s.alice.InitiateProtocol(providesAmount) + if err != nil { + return 0, nil, err + } + skm, err := swapState.SendKeysMessage() if err != nil { - return err + return 0, nil, err } - skm.OfferID = req.OfferID - skm.ProvidedAmount = req.ProvidesAmount + skm.OfferID = offerID + skm.ProvidedAmount = providesAmount - who, err := net.StringToAddrInfo(req.Multiaddr) + who, err := net.StringToAddrInfo(multiaddr) if err != nil { - return err + return 0, nil, err } if err = s.net.Initiate(who, skm, swapState); err != nil { _ = swapState.ProtocolExited() - return err + return 0, nil, err } - resp.ID = swapState.ID() - return nil + info := s.sm.GetOngoingSwap() + if info == nil { + return 0, nil, errors.New("failed to get swap info after initiating") + } + + return swapState.ID(), info.StatusCh(), nil } // TakeOfferSyncResponse ... diff --git a/rpc/server.go b/rpc/server.go index e4943337..f552e7e5 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -1,6 +1,7 @@ package rpc import ( + "context" "fmt" "net/http" @@ -28,6 +29,7 @@ type Server struct { // Config ... type Config struct { + Ctx context.Context Port uint16 WsPort uint16 Net Net @@ -40,7 +42,9 @@ type Config struct { func NewServer(cfg *Config) (*Server, error) { s := rpc.NewServer() s.RegisterCodec(NewCodec(), "application/json") - if err := s.RegisterService(NewNetService(cfg.Net, cfg.Alice, cfg.Bob, cfg.SwapManager), "net"); err != nil { //nolint:lll + + ns := NewNetService(cfg.Net, cfg.Alice, cfg.Bob, cfg.SwapManager) + if err := s.RegisterService(ns, "net"); err != nil { return nil, err } @@ -54,7 +58,7 @@ func NewServer(cfg *Config) (*Server, error) { return &Server{ s: s, - wsServer: newWsServer(cfg.SwapManager, cfg.Alice, cfg.Bob), + wsServer: newWsServer(cfg.Ctx, cfg.SwapManager, cfg.Alice, cfg.Bob, ns), port: cfg.Port, wsPort: cfg.WsPort, }, nil diff --git a/rpc/swap.go b/rpc/swap.go index 655610fc..df809058 100644 --- a/rpc/swap.go +++ b/rpc/swap.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" - "github.com/noot/atomic-swap/common" "github.com/noot/atomic-swap/common/types" ) @@ -129,19 +128,7 @@ func (s *SwapService) GetStage(_ *http.Request, _ *interface{}, resp *GetStageRe return errors.New("no current ongoing swap") } - var swapState common.SwapState - switch info.Provides() { - case types.ProvidesETH: - swapState = s.alice.GetOngoingSwapState() - case types.ProvidesXMR: - swapState = s.bob.GetOngoingSwapState() - } - - if swapState == nil { - return errors.New("failed to get current swap state") - } - - resp.Stage = swapState.Stage().String() - resp.Info = swapState.Stage().Info() + resp.Stage = info.Status().String() + resp.Info = info.Status().Info() return nil } diff --git a/rpc/ws.go b/rpc/ws.go index a3eebea1..cd0b9b74 100644 --- a/rpc/ws.go +++ b/rpc/ws.go @@ -1,13 +1,12 @@ package rpc import ( + "context" "encoding/json" "errors" "fmt" "net/http" - "time" - "github.com/noot/atomic-swap/common" "github.com/noot/atomic-swap/common/rpcclient" "github.com/noot/atomic-swap/common/types" @@ -15,25 +14,35 @@ import ( ) const ( - defaultJSONRPCVersion = "2.0" - subscribeNewPeer = "net_subscribeNewPeer" + subscribeTakeOffer = "net_takeOfferAndSubscribe" subscribeSwapStatus = "swap_subscribeStatus" ) var upgrader = websocket.Upgrader{} +//nolint:revive +type ( + Request = rpcclient.Request + Response = rpcclient.Response + SubscribeSwapStatusResponse = rpcclient.SubscribeSwapStatusResponse +) + type wsServer struct { + ctx context.Context sm SwapManager alice Alice bob Bob + ns *NetService } -func newWsServer(sm SwapManager, a Alice, b Bob) *wsServer { +func newWsServer(ctx context.Context, sm SwapManager, a Alice, b Bob, ns *NetService) *wsServer { return &wsServer{ + ctx: ctx, sm: sm, alice: a, bob: b, + ns: ns, } } @@ -69,14 +78,6 @@ func (s *wsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -// Request represents a JSON-RPC request -type Request struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]interface{} `json:"params"` - ID uint64 `json:"id"` -} - func (s *wsServer) handleRequest(conn *websocket.Conn, req *Request) error { switch req.Method { case subscribeNewPeer: @@ -92,80 +93,135 @@ func (s *wsServer) handleRequest(conn *websocket.Conn, req *Request) error { return fmt.Errorf("failed to cast id parameter to float64: got %T", idi) } - return s.subscribeSwapStatus(conn, uint64(id)) + return s.subscribeSwapStatus(s.ctx, conn, uint64(id)) + case subscribeTakeOffer: + maddri, has := req.Params["multiaddr"] + if !has { + return errors.New("params missing multiaddr field") + } + + maddr, ok := maddri.(string) + if !ok { + return fmt.Errorf("failed to cast multiaddr parameter to string: got %T", maddri) + } + + offerIDi, has := req.Params["offerID"] + if !has { + return errors.New("params missing offerID field") + } + + offerID, ok := offerIDi.(string) + if !ok { + return fmt.Errorf("failed to cast multiaddr parameter to string: got %T", offerIDi) + } + + providesi, has := req.Params["providesAmount"] + if !has { + return errors.New("params missing providesAmount field") + } + + providesAmount, ok := providesi.(float64) + if !ok { + return fmt.Errorf("failed to cast providesAmount parameter to float64: got %T", providesi) + } + + id, ch, err := s.ns.takeOffer(maddr, offerID, providesAmount) + if err != nil { + return err + } + + return s.subscribeTakeOffer(s.ctx, conn, id, ch) default: return errors.New("invalid method") } } -// SubscribeSwapStatusResponse ... -type SubscribeSwapStatusResponse struct { - Stage string `json:"stage"` -} +func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn, + id uint64, statusCh <-chan types.Status) error { + // firstly write swap ID + idMsg := map[string]uint64{ + "id": id, + } + + if err := writeResponse(conn, idMsg); err != nil { + return err + } -// subscribeSwapStatus writes the swap's stage to the connection every time it updates. -// when the swap completes, it writes the final status then closes the connection. -// example: `{"jsonrpc":"2.0", "method":"swap_subscribeStatus", "params": {"id": 0}, "id": 0}` -func (s *wsServer) subscribeSwapStatus(conn *websocket.Conn, id uint64) error { - var prevStage common.Stage for { - info := s.sm.GetOngoingSwap() - if info == nil { - info = s.sm.GetPastSwap(id) - if info == nil { - return errors.New("unable to find swap with given ID") + select { + case status, ok := <-statusCh: + if !ok { + return nil } resp := &SubscribeSwapStatusResponse{ - Stage: info.Status().String(), + Stage: status.String(), } if err := writeResponse(conn, resp); err != nil { return err } - + case <-ctx.Done(): return nil } - - var swapState common.SwapState - switch info.Provides() { - case types.ProvidesETH: - swapState = s.alice.GetOngoingSwapState() - case types.ProvidesXMR: - swapState = s.bob.GetOngoingSwapState() - } - - if swapState == nil { - // we probably completed the swap, continue to call GetPastSwap - continue - } - - currStage := swapState.Stage() - if currStage == prevStage { - time.Sleep(time.Millisecond * 10) - continue - } - - resp := &SubscribeSwapStatusResponse{ - Stage: currStage.String(), - } - - if err := writeResponse(conn, resp); err != nil { - return err - } - - prevStage = currStage } } +// subscribeSwapStatus writes the swap's stage to the connection every time it updates. +// when the swap completes, it writes the final status then closes the connection. +// example: `{"jsonrpc":"2.0", "method":"swap_subscribeStatus", "params": {"id": 0}, "id": 0}` +func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn, id uint64) error { + info := s.sm.GetOngoingSwap() + if info == nil { + return s.writeSwapExitStatus(conn, id) + } + + statusCh := info.StatusCh() + for { + select { + case status, ok := <-statusCh: + if !ok { + return nil + } + + resp := &SubscribeSwapStatusResponse{ + Stage: status.String(), + } + + if err := writeResponse(conn, resp); err != nil { + return err + } + case <-ctx.Done(): + return nil + } + } +} + +func (s *wsServer) writeSwapExitStatus(conn *websocket.Conn, id uint64) error { + info := s.sm.GetPastSwap(id) + if info == nil { + return errors.New("unable to find swap with given ID") + } + + resp := &SubscribeSwapStatusResponse{ + Stage: info.Status().String(), + } + + if err := writeResponse(conn, resp); err != nil { + return err + } + + return nil +} + func writeResponse(conn *websocket.Conn, result interface{}) error { bz, err := json.Marshal(result) if err != nil { return err } - resp := &rpcclient.ServerResponse{ - Version: defaultJSONRPCVersion, + resp := &Response{ + Version: rpcclient.DefaultJSONRPCVersion, Result: bz, } @@ -173,8 +229,8 @@ func writeResponse(conn *websocket.Conn, result interface{}) error { } func writeError(conn *websocket.Conn, err error) error { - resp := &rpcclient.ServerResponse{ - Version: defaultJSONRPCVersion, + resp := &Response{ + Version: rpcclient.DefaultJSONRPCVersion, Error: &rpcclient.Error{ Message: err.Error(), }, diff --git a/scripts/build.sh b/scripts/build.sh index 01122615..4a0fd2cc 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -1,9 +1,20 @@ #!/bin/bash -cd cmd/daemon && go build -o swapd +cd cmd/daemon +if ! go build -o swapd ; then + exit 1 +fi mv swapd ../.. -cd ../client && go build -o swapcli + +cd ../client +if ! go build -o swapcli ; then + exit 1 +fi mv swapcli ../.. -cd ../recover && go build -o swaprecover + +cd ../recover +if ! go build -o swaprecover ; then + exit 1 +fi mv swaprecover ../.. cd ../.. \ No newline at end of file diff --git a/scripts/install_lint.sh b/scripts/install_lint.sh index 30b7c6ba..b8af8af9 100755 --- a/scripts/install_lint.sh +++ b/scripts/install_lint.sh @@ -6,7 +6,7 @@ fi if ! command -v golangci-lint &> /dev/null then - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.41.0 + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.44.2 fi export PATH=$PATH:$(go env GOPATH)/bin \ No newline at end of file diff --git a/scripts/setup-env.sh b/scripts/setup-env.sh new file mode 100644 index 00000000..45de91f4 --- /dev/null +++ b/scripts/setup-env.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# install monero and run daemon and wallet RPC servers for alice and bob +bash ./scripts/install-monero-linux.sh +echo "starting monerod..." +./monero-x86_64-linux-gnu-v0.17.3.0/monerod --detach --regtest --offline --fixed-difficulty=1 --rpc-bind-port 18081 & +sleep 5 + +echo "starting monero-wallet-rpc on port 18083..." +mkdir bob-test-keys +./monero-x86_64-linux-gnu-v0.17.3.0/monero-wallet-rpc --rpc-bind-port 18083 --disable-rpc-login --wallet-dir ./bob-test-keys &> monero-wallet-cli-bob.log & +MONERO_WALLET_CLI_BOB_PID=$! + +sleep 5 +curl http://localhost:18083/json_rpc -d '{"jsonrpc":"2.0","id":"0","method":"create_wallet","params":{"filename":"test-wallet","password":"","language":"English"}}' -H 'Content-Type: application/json' + +echo "starting monero-wallet-rpc on port 18084..." +mkdir alice-test-keys +./monero-x86_64-linux-gnu-v0.17.3.0/monero-wallet-rpc --rpc-bind-port 18084 --disable-rpc-login --wallet-dir ./alice-test-keys &> monero-wallet-cli-alice.log & +MONERO_WALLET_CLI_ALICE_PID=$! + +# install ganache and run +echo "installing and starting ganache-cli..." +if ! command -v golangci-lint &> /dev/null; then + npm i -g ganache-cli +fi +export NODE_OPTIONS=--max_old_space_size=8192 +ganache-cli -d &> ganache-cli.log & +GANACHE_CLI_PID=$! + +# wait for servers to start +sleep 10 diff --git a/tests/integration.go b/tests/integration.go new file mode 100644 index 00000000..e1362487 --- /dev/null +++ b/tests/integration.go @@ -0,0 +1,3 @@ +package tests + +// need this file or we get `no non-test Go files in ~/go/src/github.com/noot/atomic-swap/tests` diff --git a/tests/integration_test.go b/tests/integration_test.go index e22262fe..c90cbc92 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -32,9 +32,9 @@ func TestMain(m *testing.M) { } cmd := exec.Command("../scripts/build.sh") - err := cmd.Run() + out, err := cmd.CombinedOutput() if err != nil { - panic(err) + panic(fmt.Sprintf("%s\n%s", out, err)) } os.Exit(m.Run())