allow for multiple ongoing swaps (#128)

This commit is contained in:
noot
2022-06-13 20:50:10 -04:00
committed by GitHub
parent 11e3b27c2e
commit cca0b6c771
49 changed files with 1193 additions and 745 deletions

View File

@@ -1,4 +1,4 @@
.PHONY: lint test install build build-dleq
.PHONY: lint test install build build-dleq mock
all: build-dleq install
lint:
@@ -21,4 +21,7 @@ build-all:
ALL=true ./scripts/build.sh
build-dleq:
./scripts/install-rust.sh && cd farcaster-dleq && cargo build --release && cd ..
./scripts/install-rust.sh && cd farcaster-dleq && cargo build --release && cd ..
mock:
go generate -run mockgen ./...

View File

@@ -123,15 +123,21 @@ var (
Name: "get-ongoing-swap",
Usage: "get information about ongoing swap, if there is one",
Action: runGetOngoingSwap,
Flags: []cli.Flag{daemonAddrFlag},
Flags: []cli.Flag{
&cli.StringFlag{
Name: "offer-id",
Usage: "ID of swap to retrieve info for",
},
daemonAddrFlag,
},
},
{
Name: "get-past-swap",
Usage: "get information about a past swap with the given ID",
Action: runGetPastSwap,
Flags: []cli.Flag{
&cli.UintFlag{
Name: "id",
&cli.StringFlag{
Name: "offer-id",
Usage: "ID of swap to retrieve info for",
},
daemonAddrFlag,
@@ -141,19 +147,37 @@ var (
Name: "refund",
Usage: "if we are the ETH provider for an ongoing swap, refund it if possible.",
Action: runRefund,
Flags: []cli.Flag{daemonAddrFlag},
Flags: []cli.Flag{
&cli.StringFlag{
Name: "offer-id",
Usage: "ID of swap to retrieve info for",
},
daemonAddrFlag,
},
},
{
Name: "cancel",
Usage: "cancel the ongoing swap if possible.",
Usage: "cancel a ongoing swap if possible.",
Action: runCancel,
Flags: []cli.Flag{daemonAddrFlag},
Flags: []cli.Flag{
&cli.StringFlag{
Name: "offer-id",
Usage: "ID of swap to retrieve info for",
},
daemonAddrFlag,
},
},
{
Name: "get-stage",
Usage: "get the stage of the current swap.",
Usage: "get the stage of a current swap.",
Action: runGetStage,
Flags: []cli.Flag{daemonAddrFlag},
Flags: []cli.Flag{
&cli.StringFlag{
Name: "offer-id",
Usage: "ID of swap to retrieve info for",
},
daemonAddrFlag,
},
},
{
Name: "set-swap-timeout",
@@ -280,20 +304,12 @@ func runMake(ctx *cli.Context) error {
return err
}
id, takenCh, statusCh, err := c.MakeOfferAndSubscribe(min, max, types.ExchangeRate(exchangeRate))
id, statusCh, err := c.MakeOfferAndSubscribe(min, max, types.ExchangeRate(exchangeRate))
if err != nil {
return err
}
fmt.Printf("Made offer with ID=%s\n", id)
taken := <-takenCh
if taken == nil {
fmt.Printf("connection closed\n")
return nil
}
fmt.Printf("Offer taken! Swap ID=%d\n", taken.ID)
fmt.Printf("Made offer with ID %s\n", id)
for stage := range statusCh {
fmt.Printf("> Stage updated: %s\n", stage)
@@ -342,12 +358,12 @@ func runTake(ctx *cli.Context) error {
return err
}
id, statusCh, err := c.TakeOfferAndSubscribe(maddr, offerID, providesAmount)
statusCh, err := c.TakeOfferAndSubscribe(maddr, offerID, providesAmount)
if err != nil {
return err
}
fmt.Printf("Initiated swap with ID=%d\n", id)
fmt.Printf("Initiated swap with ID %s\n", offerID)
for stage := range statusCh {
fmt.Printf("> Stage updated: %s\n", stage)
@@ -360,12 +376,12 @@ func runTake(ctx *cli.Context) error {
}
c := rpcclient.NewClient(endpoint)
id, err := c.TakeOffer(maddr, offerID, providesAmount)
err := c.TakeOffer(maddr, offerID, providesAmount)
if err != nil {
return err
}
fmt.Printf("Initiated swap with ID=%d\n", id)
fmt.Printf("Initiated swap with ID %s\n", offerID)
return nil
}
@@ -391,14 +407,18 @@ func runGetOngoingSwap(ctx *cli.Context) error {
endpoint = defaultSwapdAddress
}
offerID := ctx.String("offer-id")
if offerID == "" {
return errNoOfferID
}
c := rpcclient.NewClient(endpoint)
info, err := c.GetOngoingSwap()
info, err := c.GetOngoingSwap(offerID)
if err != nil {
return err
}
fmt.Printf("ID: %d\n Provided: %s\n ProvidedAmount: %v\n ReceivedAmount: %v\n ExchangeRate: %v\n Status: %s\n",
info.ID,
fmt.Printf("Provided: %s\n ProvidedAmount: %v\n ReceivedAmount: %v\n ExchangeRate: %v\n Status: %s\n",
info.Provided,
info.ProvidedAmount,
info.ReceivedAmount,
@@ -409,21 +429,23 @@ func runGetOngoingSwap(ctx *cli.Context) error {
}
func runGetPastSwap(ctx *cli.Context) error {
id := ctx.Uint("id")
endpoint := ctx.String("daemon-addr")
if endpoint == "" {
endpoint = defaultSwapdAddress
}
offerID := ctx.String("offer-id")
if offerID == "" {
return errNoOfferID
}
c := rpcclient.NewClient(endpoint)
info, err := c.GetPastSwap(uint64(id))
info, err := c.GetPastSwap(offerID)
if err != nil {
return err
}
fmt.Printf("ID: %d\n Provided: %s\n ProvidedAmount: %v\n ReceivedAmount: %v\n ExchangeRate: %v\n Status: %s\n",
id,
fmt.Printf("Provided: %s\n ProvidedAmount: %v\n ReceivedAmount: %v\n ExchangeRate: %v\n Status: %s\n",
info.Provided,
info.ProvidedAmount,
info.ReceivedAmount,
@@ -439,8 +461,13 @@ func runRefund(ctx *cli.Context) error {
endpoint = defaultSwapdAddress
}
offerID := ctx.String("offer-id")
if offerID == "" {
return errNoOfferID
}
c := rpcclient.NewClient(endpoint)
resp, err := c.Refund()
resp, err := c.Refund(offerID)
if err != nil {
return err
}
@@ -455,8 +482,13 @@ func runCancel(ctx *cli.Context) error {
endpoint = defaultSwapdAddress
}
offerID := ctx.String("offer-id")
if offerID == "" {
return errNoOfferID
}
c := rpcclient.NewClient(endpoint)
resp, err := c.Cancel()
resp, err := c.Cancel(offerID)
if err != nil {
return err
}
@@ -471,8 +503,13 @@ func runGetStage(ctx *cli.Context) error {
endpoint = defaultSwapdAddress
}
offerID := ctx.String("offer-id")
if offerID == "" {
return errNoOfferID
}
c := rpcclient.NewClient(endpoint)
resp, err := c.GetStage()
resp, err := c.GetStage(offerID)
if err != nil {
return err
}

View File

@@ -320,7 +320,7 @@ func (d *daemon) takeOffer(done <-chan struct{}) {
start := time.Now()
log.Infof("node %d taking offer %s", d.idx, offer.GetID().String())
_, takerStatusCh, err := wsc.TakeOfferAndSubscribe(peer,
takerStatusCh, err := wsc.TakeOfferAndSubscribe(peer,
offer.GetID().String(), providesAmount)
if err != nil {
d.errCh <- err
@@ -370,7 +370,7 @@ func (d *daemon) makeOffer(done <-chan struct{}) {
defer wsc.Close()
offerID, takenCh, statusCh, err := wsc.MakeOfferAndSubscribe(minProvidesAmount,
offerID, statusCh, err := wsc.MakeOfferAndSubscribe(minProvidesAmount,
maxProvidesAmount,
getRandomExchangeRate(),
)
@@ -394,16 +394,6 @@ func (d *daemon) makeOffer(done <-chan struct{}) {
log.Infof("node %d made offer %s", d.idx, offerID)
select {
case <-done:
return
case taken := <-takenCh:
if taken == nil {
log.Warn("got nil from takenCh")
return
}
}
d.swapMu.Lock()
defer d.swapMu.Unlock()

View File

@@ -1,6 +1,7 @@
package common
import (
"github.com/noot/atomic-swap/common/types"
"github.com/noot/atomic-swap/net/message"
)
@@ -14,13 +15,14 @@ type SwapState interface {
// It is implemented by *xmrtaker.swapState and *xmrmaker.swapState
type SwapStateNet interface {
HandleProtocolMessage(msg message.Message) (resp message.Message, done bool, err error)
ID() types.Hash
Exit() error
}
// SwapStateRPC contains the methods used by the RPC server into the SwapState.
type SwapStateRPC interface {
SendKeysMessage() (*message.SendKeysMessage, error)
ID() uint64
ID() types.Hash
InfoFile() string
Exit() error
}

View File

@@ -6,7 +6,7 @@ import (
// SubscribeSwapStatusRequest ...
type SubscribeSwapStatusRequest struct {
ID uint64 `json:"id"`
ID types.Hash `json:"id"`
}
// SubscribeSwapStatusResponse ...
@@ -45,7 +45,6 @@ type TakeOfferRequest struct {
// TakeOfferResponse ...
type TakeOfferResponse struct {
ID uint64 `json:"id"`
InfoFile string `json:"infoFile"`
}

View File

@@ -1,6 +1,7 @@
package types
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
@@ -48,8 +49,13 @@ func (o *Offer) GetID() Hash {
panic(err)
}
o.ID = sha3.Sum256(b)
// TODO: add some randomness in here
var buf [8]byte
_, err = rand.Read(buf[:])
if err != nil {
panic(err)
}
o.ID = sha3.Sum256(append(b, buf[:]...))
return o.ID
}
@@ -66,7 +72,6 @@ func (o *Offer) String() string {
// OfferExtra represents extra data that is passed when an offer is made.
type OfferExtra struct {
IDCh chan uint64
StatusCh chan Status
InfoFile string
}

View File

@@ -16,12 +16,9 @@ Returns:
Example:
```
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"net_addresses","params":{}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"addresses":["/ip4/192.168.0.101/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2","/ip4/127.0.0.1/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2","/ip4/38.88.101.233/tcp/14815/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2"]},"id":"0"}
# {"jsonrpc":"2.0","result":{"addresses":["/ip4/192.168.0.101/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2","/ip4/127.0.0.1/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2","/ip4/38.88.101.233/tcp/14815/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2"]},"id":"0"}
```
### `net_discover`
@@ -37,12 +34,9 @@ Returns:
Example:
```
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"net_discover","params":{"searchTime":3}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"peers":[["/ip4/127.0.0.1/tcp/9934/p2p/12D3KooWHLUrLnJtUbaGzTSi6azZavKhNgUZTtSiUZ9Uy12v1eZ7","/ip4/192.168.0.101/tcp/9934/p2p/12D3KooWHLUrLnJtUbaGzTSi6azZavKhNgUZTtSiUZ9Uy12v1eZ7"]]},"id":"0"}
# {"jsonrpc":"2.0","result":{"peers":[["/ip4/127.0.0.1/tcp/9934/p2p/12D3KooWHLUrLnJtUbaGzTSi6azZavKhNgUZTtSiUZ9Uy12v1eZ7","/ip4/192.168.0.101/tcp/9934/p2p/12D3KooWHLUrLnJtUbaGzTSi6azZavKhNgUZTtSiUZ9Uy12v1eZ7"]]},"id":"0"}
```
### `net_queryPeer`
@@ -57,12 +51,9 @@ Returns:
Example:
```
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"net_queryPeer","params":{"multiaddr":"/ip4/192.168.0.101/tcp/9934/p2p/12D3KooWHLUrLnJtUbaGzTSi6azZavKhNgUZTtSiUZ9Uy12v1eZ7"}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"offers":[{"ID":[207,75,240,26,7,117,160,209,63,164,27,20,81,110,75,137,3,67,0,112,122,23,84,224,217,155,101,246,203,111,255,185],"Provides":"XMR","MinimumAmount":0.1,"MaximumAmount":1,"ExchangeRate":0.05}]},"id":"0"}
# {"jsonrpc":"2.0","result":{"offers":[{"ID":[207,75,240,26,7,117,160,209,63,164,27,20,81,110,75,137,3,67,0,112,122,23,84,224,217,155,101,246,203,111,255,185],"Provides":"XMR","MinimumAmount":0.1,"MaximumAmount":1,"ExchangeRate":0.05}]},"id":"0"}
```
### `net_makeOffer`
@@ -78,12 +69,9 @@ Returns:
- `offerID`: ID of the swap offer.
Example:
```
```bash
curl -X POST http://127.0.0.1:5002 -d '{"jsonrpc":"2.0","id":"0","method":"net_makeOffer","params":{"minimumAmount":1, "maximumAmount":10, "exchangeRate": 0.1}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"offerID":"12b9d56a4c568c772a4e099aaed03a457256d6680562be2a518753f75d75b7ad"},"id":"0"}
# {"jsonrpc":"2.0","result":{"offerID":"12b9d56a4c568c772a4e099aaed03a457256d6680562be2a518753f75d75b7ad"},"id":"0"}
```
@@ -97,15 +85,12 @@ Parameters:
- `providesAmount`: amount of ETH you will be providing. Must be between the offer's `minimumAmount * exchangeRate` and `maximumAmount * exchangeRate`. For example, if the offer has a minimum of 1 XMR and a maximum of 5 XMR and an exchange rate of 0.1, you must provide between 0.1 ETH and 0.5 ETH.
Returns:
- `id`: ID of the initiated swap.
- null
Example:
```
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"net_takeOffer","params":{"multiaddr":"/ip4/192.168.0.101/tcp/9934/p2p/12D3KooWHLUrLnJtUbaGzTSi6azZavKhNgUZTtSiUZ9Uy12v1eZ7", "offerID":"12b9d56a4c568c772a4e099aaed03a457256d6680562be2a518753f75d75b7ad", "providesAmount": 0.3}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"id":1},"id":"0"}
# {"jsonrpc":"2.0","result":null,"id":"0"}
```
### `net_takeOfferSync`
@@ -118,16 +103,12 @@ Parameters:
- `providesAmount`: amount of ETH you will be providing. Must be between the offer's `minimumAmount * exchangeRate` and `maximumAmount * exchangeRate`. For example, if the offer has a minimum of 1 XMR and a maximum of 5 XMR and an exchange rate of 0.1, you must provide between 0.1 ETH and 0.5 ETH.
Returns:
- `id`: ID of the initiated swap.
- `status`: the swap's status, one of `success`, `refunded`, or `aborted`.
Example:
```
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"net_takeOffer","params":{"multiaddr":"/ip4/192.168.0.101/tcp/9934/p2p/12D3KooWHLUrLnJtUbaGzTSi6azZavKhNgUZTtSiUZ9Uy12v1eZ7", "offerID":"12b9d56a4c568c772a4e099aaed03a457256d6680562be2a518753f75d75b7ad", "providesAmount": 0.3}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"id":1,"status":"success"},"id":"0"}
# {"jsonrpc":"2.0","result":{status":"success"},"id":"0"}
```
@@ -145,15 +126,29 @@ Returns:
- none
Example:
```
```bash
curl -X POST http://127.0.0.1:5002 -d '{"jsonrpc":"2.0","id":"0","method":"personal_setMoneroWalletFile","params":{"walletFile":"test-wallet", "walletPassword": ""}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":null,"id":"0"}
#{"jsonrpc":"2.0","result":null,"id":"0"}
```
## `swap` namespace
### `swap_cancel`
Attempts to cancel an ongoing swap.
Parameters:
- `id`: id of the swap to refund
Returns:
- `status`: exit status of the swap.
Example:
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"swap_cancel","params":{"id": "17c01ad48a1f75c1456932b12cb51d430953bb14ffe097195b1f8cace7776e70"}}' -H 'Content-Type: application/json'
# {"jsonrpc":"2.0","result":{"status":"Success"},"id":"0"}
```
### `swap_getOngoing`
Gets information about the ongoing swap, if there is one.
@@ -162,7 +157,7 @@ Parameters:
- none
Returns:
- `id`: the swap's ID. **Note: this is not the same as an offer ID.**
- `id`: the swap's ID.
- `provided`: the coin provided during the swap.
- `providedAmount`: the amount of coin provided during the swap.
- `receivedAmount`: the amount of coin expected to be received during the swap.
@@ -170,11 +165,9 @@ Returns:
- `status`: the swap's status; should always be "ongoing".
Example:
```
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"swap_getOngoing","params":{}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"id":3,"provided":"ETH","providedAmount":0.05,"receivedAmount":0,"exchangeRate":0,"status":"ongoing"},"id":"0"}
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"swap_getOngoing","params":{"id":"17c01ad48a1f75c1456932b12cb51d430953bb14ffe097195b1f8cace7776e70"}}' -H 'Content-Type: application/json'
# {"jsonrpc":"2.0","result":{"id":3,"provided":"ETH","providedAmount":0.05,"receivedAmount":0,"exchangeRate":0,"status":"ongoing"},"id":"0"}
```
### `swap_getPastIDs`
@@ -188,12 +181,9 @@ Returns:
- `ids`: a list of all past swap IDs.
Example:
```
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"swap_getPastIDs","params":{}}' -H 'Content-Type: application/json'
```
```
{"jsonrpc":"2.0","result":{"ids":[2,3,0,1]},"id":"0"}
# {"jsonrpc":"2.0","result":{"ids":["7492ceb4d0f5f45ecd5d06923b35cae406d1406cd685ce1ba184f2a40c683ac2","17c01ad48a1f75c1456932b12cb51d430953bb14ffe097195b1f8cace7776e70"]},"id":"0"}
```
### `swap_getPast`
@@ -212,10 +202,27 @@ Returns:
Example:
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"swap_getPast","params":{"id": 0}}' -H 'Content-Type: application/json'
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"swap_getPast","params":{"id": "17c01ad48a1f75c1456932b12cb51d430953bb14ffe097195b1f8cace7776e70"}}' -H 'Content-Type: application/json'
# {"jsonrpc":"2.0","result":{"provided":"ETH","providedAmount":0.05,"receivedAmount":1,"exchangeRate":20,"status":"success"},"id":"0"}
```
### `swap_getStage`
Gets the stage of an ongoing swap.
Parameters:
- `id`: id of the swap to get the stage of
Returns:
- `stage`: stage of the swap
- `info`: description of the swap's stage
Example:
```bash
curl -X POST http://127.0.0.1:5001 -d '{"jsonrpc":"2.0","id":"0","method":"swap_getStage","params":{"id": "17c01ad48a1f75c1456932b12cb51d430953bb14ffe097195b1f8cace7776e70"}}' -H 'Content-Type: application/json'
# {"jsonrpc":"2.0","result":{"stage":"KeysExchanged", "info":"keys have been exchanged, but no value has been locked"},"id":"0"}
```
## websocket subscriptions
The daemon also runs a websockets server that can be used to subscribe to push notifications for updates. You can use the command-line tool `wscat` to easily connect to a websockets server.
@@ -234,7 +241,7 @@ Example:
```bash
wscat -c ws://localhost:8081
# Connected (press CTRL+C to quit)
# > {"jsonrpc":"2.0", "method":"swap_subscribeStatus", "params": {"id": 0}, "id": 0}
# > {"jsonrpc":"2.0", "method":"swap_subscribeStatus", "params": {"id": "7492ceb4d0f5f45ecd5d06923b35cae406d1406cd685ce1ba184f2a40c683ac2"}, "id": 0}
# < {"jsonrpc":"2.0","result":{"stage":"ETHLocked"},"error":null,"id":null}
# < {"jsonrpc":"2.0","result":{"stage":"refunded"},"error":null,"id":null}
```

View File

@@ -1,5 +1,4 @@
const { expect } = require("chai");
const secp = require('noble-secp256k1');
const arrayify = ethers.utils.arrayify;

View File

@@ -1,6 +1,8 @@
package monero
import (
"sync"
"github.com/noot/atomic-swap/common"
"github.com/noot/atomic-swap/common/rpctypes"
mcrypto "github.com/noot/atomic-swap/crypto/monero"
@@ -8,6 +10,8 @@ import (
// Client represents a monero-wallet-rpc client.
type Client interface {
LockClient() // can't use Lock/Unlock due to name conflict
UnlockClient()
GetAccounts() (*GetAccountsResponse, error)
GetAddress(idx uint) (*GetAddressResponse, error)
GetBalance(idx uint) (*GetBalanceResponse, error)
@@ -23,6 +27,7 @@ type Client interface {
}
type client struct {
sync.Mutex
endpoint string
}
@@ -33,6 +38,14 @@ func NewClient(endpoint string) *client { //nolint:revive
}
}
func (c *client) LockClient() {
c.Lock()
}
func (c *client) UnlockClient() {
c.Unlock()
}
func (c *client) GetAccounts() (*GetAccountsResponse, error) {
return c.callGetAccounts()
}

View File

@@ -36,10 +36,15 @@ type Host interface {
Discover(provides types.ProvidesCoin, searchTime time.Duration) ([]peer.AddrInfo, error)
Query(who peer.AddrInfo) (*QueryResponse, error)
Initiate(who peer.AddrInfo, msg *SendKeysMessage, s common.SwapState) error
Initiate(who peer.AddrInfo, msg *SendKeysMessage, s common.SwapStateNet) error
MessageSender
}
type swap struct {
swapState SwapState
stream libp2pnetwork.Stream
}
type host struct {
ctx context.Context
cancel context.CancelFunc
@@ -51,9 +56,8 @@ type host struct {
handler Handler
// swap instance info
swapMu sync.Mutex
swapState SwapState
swapStream libp2pnetwork.Stream
swapMu sync.Mutex
swaps map[types.Hash]*swap
queryMu sync.Mutex
queryBuf []byte
@@ -149,7 +153,8 @@ func NewHost(cfg *Config) (*host, error) { //nolint:revive
h: h,
handler: cfg.Handler,
bootnodes: bns,
queryBuf: make([]byte, 2048),
queryBuf: make([]byte, 1024*5),
swaps: make(map[types.Hash]*swap),
}
hst.discovery, err = newDiscovery(ourCtx, h, hst.getBootnodes)
@@ -233,15 +238,16 @@ func (h *host) Discover(provides types.ProvidesCoin, searchTime time.Duration) (
}
// SendSwapMessage sends a message to the peer who we're currently doing a swap with.
func (h *host) SendSwapMessage(msg Message) error {
func (h *host) SendSwapMessage(msg Message, id types.Hash) error {
h.swapMu.Lock()
defer h.swapMu.Unlock()
if h.swapStream == nil {
swap, has := h.swaps[id]
if !has {
return errNoOngoingSwap
}
return h.writeToStream(h.swapStream, msg)
return h.writeToStream(swap.stream, msg)
}
func (h *host) getBootnodes() []peer.AddrInfo {

View File

@@ -3,27 +3,51 @@ package net
import (
"context"
"fmt"
"os"
"testing"
"github.com/noot/atomic-swap/common"
"github.com/noot/atomic-swap/common/types"
logging "github.com/ipfs/go-log"
"github.com/stretchr/testify/require"
)
var defaultPort uint16 = 5001
func TestMain(m *testing.M) {
logging.SetLogLevel("net", "debug")
m.Run()
os.Exit(0)
}
type mockHandler struct{}
var defaultPort uint16 = 5001
var testID = types.Hash{99}
type mockHandler struct {
id types.Hash
}
func (h *mockHandler) GetOffers() []*types.Offer {
return []*types.Offer{}
}
func (h *mockHandler) HandleInitiateMessage(msg *SendKeysMessage) (s SwapState, resp Message, err error) {
if (h.id != types.Hash{}) {
return &mockSwapState{h.id}, &SendKeysMessage{}, nil
}
return &mockSwapState{}, &SendKeysMessage{}, nil
}
type mockSwapState struct{}
type mockSwapState struct {
id types.Hash
}
func (s *mockSwapState) ID() types.Hash {
if (s.id != types.Hash{}) {
return s.id
}
return testID
}
func (s *mockSwapState) HandleProtocolMessage(msg Message) (resp Message, done bool, err error) {
return nil, false, nil

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/noot/atomic-swap/common"
"github.com/noot/atomic-swap/common/types"
"github.com/noot/atomic-swap/net/message"
libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
@@ -18,17 +19,20 @@ const (
protocolTimeout = time.Second * 5
)
func (h *host) Initiate(who peer.AddrInfo, msg *SendKeysMessage, s common.SwapState) error {
func (h *host) Initiate(who peer.AddrInfo, msg *SendKeysMessage, s common.SwapStateNet) error {
h.swapMu.Lock()
defer h.swapMu.Unlock()
if h.swapState != nil {
id := s.ID()
if h.swaps[id] != nil {
return errSwapAlreadyInProgress
}
ctx, cancel := context.WithTimeout(h.ctx, protocolTimeout)
defer cancel()
// TODO: check if already connected
if err := h.h.Connect(ctx, who); err != nil {
return err
}
@@ -47,41 +51,88 @@ func (h *host) Initiate(who peer.AddrInfo, msg *SendKeysMessage, s common.SwapSt
return err
}
h.swapState = s
h.swapStream = stream
go h.handleProtocolStreamInner(stream)
h.swaps[id] = &swap{
swapState: s,
stream: stream,
}
go h.handleProtocolStreamInner(stream, s)
return nil
}
// handleProtocolStream is called when there is an incoming protocol stream.
func (h *host) handleProtocolStream(stream libp2pnetwork.Stream) {
if h.handler == nil {
_ = stream.Close()
return
}
// TODO: don't allocate this twice
msgBytes := make([]byte, 1<<17)
tot, err := readStream(stream, msgBytes[:])
if err != nil {
log.Debug("peer closed stream with us, protocol exited")
_ = stream.Close()
return
}
// decode message based on message type
msg, err := message.DecodeMessage(msgBytes[:tot])
if err != nil {
log.Debug("failed to decode message from peer, id=", stream.ID(), " protocol=", stream.Protocol(), " err=", err)
_ = stream.Close()
return
}
log.Debug(
"received message from peer, peer=", stream.Conn().RemotePeer(), " type=", msg.Type(),
)
im, ok := msg.(*SendKeysMessage)
if !ok {
log.Warnf("failed to handle protocol message: message was not SendKeysMessage")
_ = stream.Close()
return
}
var s SwapState
s, resp, err := h.handler.HandleInitiateMessage(im)
if err != nil {
log.Warnf("failed to handle protocol message: err=%s", err)
_ = stream.Close()
return
}
if err := h.writeToStream(stream, resp); err != nil {
log.Warnf("failed to send response to peer: err=%s", err)
_ = s.Exit()
_ = stream.Close()
return
}
h.swapMu.Lock()
defer h.swapMu.Unlock()
if h.swapState != nil {
log.Debug("failed to handling incoming swap stream, already have ongoing swap")
h.swaps[s.ID()] = &swap{
swapState: s,
stream: stream,
}
h.swapMu.Unlock()
h.swapStream = stream
h.handleProtocolStreamInner(stream)
h.handleProtocolStreamInner(stream, s)
}
// handleProtocolStreamInner is called to handle a protocol stream, in both ingoing and outgoing cases.
func (h *host) handleProtocolStreamInner(stream libp2pnetwork.Stream) {
func (h *host) handleProtocolStreamInner(stream libp2pnetwork.Stream, s SwapState) {
defer func() {
log.Debugf("closing stream: peer=%s protocol=%s", stream.Conn().RemotePeer(), stream.Protocol())
_ = stream.Close()
if h.swapState != nil {
log.Debugf("exiting swap...")
if err := h.swapState.Exit(); err != nil {
log.Errorf("failed to exit protocol: err=%s", err)
}
h.swapState = nil
log.Debugf("exiting swap...")
if err := s.Exit(); err != nil {
log.Errorf("failed to exit protocol: err=%s", err)
}
h.swapMu.Lock()
delete(h.swaps, s.ID())
h.swapMu.Unlock()
}()
msgBytes := make([]byte, 1<<17)
@@ -104,32 +155,10 @@ func (h *host) handleProtocolStreamInner(stream libp2pnetwork.Stream) {
"received message from peer, peer=", stream.Conn().RemotePeer(), " type=", msg.Type(),
)
var (
resp Message
done bool
)
if h.swapState == nil {
im, ok := msg.(*SendKeysMessage)
if !ok {
log.Warnf("failed to handle protocol message: message was not SendKeysMessage")
return
}
var s SwapState
s, resp, err = h.handler.HandleInitiateMessage(im)
if err != nil {
log.Warnf("failed to handle protocol message: err=%s", err)
return
}
h.swapState = s
} else {
resp, done, err = h.swapState.HandleProtocolMessage(msg)
if err != nil {
log.Warnf("failed to handle protocol message: err=%s", err)
return
}
resp, done, err := s.HandleProtocolMessage(msg)
if err != nil {
log.Warnf("failed to handle protocol message: err=%s", err)
return
}
if resp == nil {
@@ -149,8 +178,14 @@ func (h *host) handleProtocolStreamInner(stream libp2pnetwork.Stream) {
}
// CloseProtocolStream closes the current swap protocol stream.
func (h *host) CloseProtocolStream() {
stream := h.swapStream
log.Debugf("closing stream: peer=%s protocol=%s", stream.Conn().RemotePeer(), stream.Protocol())
_ = stream.Close()
func (h *host) CloseProtocolStream(id types.Hash) {
swap, has := h.swaps[id]
if !has {
return
}
log.Debugf("closing stream: peer=%s protocol=%s",
swap.stream.Conn().RemotePeer(), swap.stream.Protocol(),
)
_ = swap.stream.Close()
}

View File

@@ -4,6 +4,8 @@ import (
"testing"
"time"
"github.com/noot/atomic-swap/common/types"
"github.com/stretchr/testify/require"
)
@@ -23,11 +25,42 @@ func TestHost_Initiate(t *testing.T) {
err = ha.h.Connect(ha.ctx, hb.addrInfo())
require.NoError(t, err)
err = ha.Initiate(hb.addrInfo(), &SendKeysMessage{}, nil)
err = ha.Initiate(hb.addrInfo(), &SendKeysMessage{}, new(mockSwapState))
require.NoError(t, err)
time.Sleep(time.Millisecond * 500)
require.NotNil(t, ha.swapStream)
require.NotNil(t, hb.swapStream)
require.NotNil(t, ha.swapState)
require.NotNil(t, hb.swapState)
require.NotNil(t, ha.swaps[testID])
require.NotNil(t, hb.swaps[testID])
}
func TestHost_ConcurrentSwaps(t *testing.T) {
ha := newHost(t, defaultPort)
err := ha.Start()
require.NoError(t, err)
hb := newHost(t, defaultPort+1)
err = hb.Start()
require.NoError(t, err)
testID2 := types.Hash{98}
defer func() {
_ = ha.Stop()
_ = hb.Stop()
}()
err = ha.h.Connect(ha.ctx, hb.addrInfo())
require.NoError(t, err)
err = ha.Initiate(hb.addrInfo(), &SendKeysMessage{}, new(mockSwapState))
require.NoError(t, err)
time.Sleep(time.Millisecond * 500)
require.NotNil(t, ha.swaps[testID])
require.NotNil(t, hb.swaps[testID])
hb.handler.(*mockHandler).id = testID2
err = ha.Initiate(hb.addrInfo(), &SendKeysMessage{}, &mockSwapState{testID2})
require.NoError(t, err)
time.Sleep(time.Millisecond * 1500)
require.NotNil(t, ha.swaps[testID2])
require.NotNil(t, hb.swaps[testID2])
}

View File

@@ -18,7 +18,7 @@ type (
// MessageSender is implemented by a Host
type MessageSender interface {
SendSwapMessage(Message) error
SendSwapMessage(Message, types.Hash) error
}
// Handler handles swap initiation messages.

View File

@@ -4,15 +4,17 @@ import (
"context"
"crypto/ecdsa"
"math/big"
"sync"
"time"
eth "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
ethtypes "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/noot/atomic-swap/common"
"github.com/noot/atomic-swap/common/types"
mcrypto "github.com/noot/atomic-swap/crypto/monero"
"github.com/noot/atomic-swap/monero"
"github.com/noot/atomic-swap/net"
@@ -45,11 +47,11 @@ type Backend interface {
// ethclient methods
BalanceAt(ctx context.Context, account ethcommon.Address, blockNumber *big.Int) (*big.Int, error)
CodeAt(ctx context.Context, account ethcommon.Address, blockNumber *big.Int) ([]byte, error)
FilterLogs(ctx context.Context, q eth.FilterQuery) ([]types.Log, error)
TransactionReceipt(ctx context.Context, txHash ethcommon.Hash) (*types.Receipt, error)
FilterLogs(ctx context.Context, q eth.FilterQuery) ([]ethtypes.Log, error)
TransactionReceipt(ctx context.Context, txHash ethcommon.Hash) (*ethtypes.Receipt, error)
// helpers
WaitForReceipt(ctx context.Context, txHash ethcommon.Hash) (*types.Receipt, error)
WaitForReceipt(ctx context.Context, txHash ethcommon.Hash) (*ethtypes.Receipt, error)
NewSwapFactory(addr ethcommon.Address) (*swapfactory.SwapFactory, error)
// getters
@@ -65,13 +67,14 @@ type Backend interface {
Net() net.MessageSender
SwapTimeout() time.Duration
ExternalSender() *txsender.ExternalSender
XMRDepositAddress() mcrypto.Address
XMRDepositAddress(id *types.Hash) (mcrypto.Address, error)
// setters
SetSwapTimeout(timeout time.Duration)
SetGasPrice(uint64)
SetEthAddress(ethcommon.Address)
SetXMRDepositAddress(mcrypto.Address)
SetXMRDepositAddress(mcrypto.Address, types.Hash)
SetBaseXMRDepositAddress(mcrypto.Address)
SetContract(*swapfactory.SwapFactory)
SetContractAddress(ethcommon.Address)
}
@@ -86,7 +89,9 @@ type backend struct {
monero.DaemonClient
// monero deposit address (used if xmrtaker has transferBack set to true)
xmrDepositAddr mcrypto.Address
sync.RWMutex
baseXMRDepositAddr *mcrypto.Address
xmrDepositAddrs map[types.Hash]mcrypto.Address
// ethereum endpoint and variables
ethClient *ethclient.Client
@@ -185,16 +190,17 @@ func NewBackend(cfg *Config) (Backend, error) {
From: addr,
Context: cfg.Ctx,
},
Sender: sender,
ethAddress: addr,
chainID: cfg.ChainID,
gasPrice: cfg.GasPrice,
gasLimit: cfg.GasLimit,
contract: cfg.SwapContract,
contractAddr: cfg.SwapContractAddress,
swapManager: cfg.SwapManager,
swapTimeout: defaultTimeoutDuration,
MessageSender: cfg.Net,
Sender: sender,
ethAddress: addr,
chainID: cfg.ChainID,
gasPrice: cfg.GasPrice,
gasLimit: cfg.GasLimit,
contract: cfg.SwapContract,
contractAddr: cfg.SwapContractAddress,
swapManager: cfg.SwapManager,
swapTimeout: defaultTimeoutDuration,
MessageSender: cfg.Net,
xmrDepositAddrs: make(map[types.Hash]mcrypto.Address),
}, nil
}
@@ -270,11 +276,11 @@ func (b *backend) CodeAt(ctx context.Context, account ethcommon.Address, blockNu
return b.ethClient.CodeAt(ctx, account, blockNumber)
}
func (b *backend) FilterLogs(ctx context.Context, q eth.FilterQuery) ([]types.Log, error) {
func (b *backend) FilterLogs(ctx context.Context, q eth.FilterQuery) ([]ethtypes.Log, error) {
return b.ethClient.FilterLogs(ctx, q)
}
func (b *backend) TransactionReceipt(ctx context.Context, txHash ethcommon.Hash) (*types.Receipt, error) {
func (b *backend) TransactionReceipt(ctx context.Context, txHash ethcommon.Hash) (*ethtypes.Receipt, error) {
return b.ethClient.TransactionReceipt(ctx, txHash)
}
@@ -289,12 +295,28 @@ func (b *backend) TxOpts() (*bind.TransactOpts, error) {
return txOpts, nil
}
func (b *backend) XMRDepositAddress() mcrypto.Address {
return b.xmrDepositAddr
func (b *backend) XMRDepositAddress(id *types.Hash) (mcrypto.Address, error) {
b.RLock()
defer b.RUnlock()
if id == nil && b.baseXMRDepositAddr == nil {
return "", errNoXMRDepositAddress
} else if id == nil {
return *b.baseXMRDepositAddr, nil
}
addr, has := b.xmrDepositAddrs[*id]
if !has && b.baseXMRDepositAddr == nil {
return "", errNoXMRDepositAddress
} else if !has {
return *b.baseXMRDepositAddr, nil
}
return addr, nil
}
// WaitForReceipt waits for the receipt for the given transaction to be available and returns it.
func (b *backend) WaitForReceipt(ctx context.Context, txHash ethcommon.Hash) (*types.Receipt, error) {
func (b *backend) WaitForReceipt(ctx context.Context, txHash ethcommon.Hash) (*ethtypes.Receipt, error) {
for i := 0; i < maxRetries; i++ {
receipt, err := b.ethClient.TransactionReceipt(ctx, txHash)
if err != nil {
@@ -327,10 +349,19 @@ func (b *backend) SetEthAddress(addr ethcommon.Address) {
b.ethAddress = addr
}
func (b *backend) SetXMRDepositAddress(addr mcrypto.Address) {
b.xmrDepositAddr = addr
func (b *backend) SetBaseXMRDepositAddress(addr mcrypto.Address) {
b.baseXMRDepositAddr = &addr
}
func (b *backend) SetXMRDepositAddress(addr mcrypto.Address, id types.Hash) {
b.Lock()
defer b.Unlock()
// TODO: clear this out when swap is done, memory leak!!!
b.xmrDepositAddrs[id] = addr
}
// TODO: these are kinda sus, maybe remove them? forces everyone to use
// the same contract though
func (b *backend) SetContract(contract *swapfactory.SwapFactory) {
b.contract = contract
b.Sender.SetContract(contract)

View File

@@ -8,4 +8,5 @@ var (
errMustProvideDaemonEndpoint = errors.New("environment is development, must provide monero daemon endpoint")
errNilSwapContractOrAddress = errors.New("must provide swap contract and address")
errReceiptTimeOut = errors.New("failed to get receipt, timed out")
errNoXMRDepositAddress = errors.New("no xmr deposit address for given id")
)

View File

@@ -1,9 +0,0 @@
package swap
import (
"errors"
)
var (
errHaveOngoingSwap = errors.New("already have ongoing swap")
)

View File

@@ -6,15 +6,13 @@ import (
"github.com/noot/atomic-swap/common/types"
)
var nextID uint64
type (
Status = types.Status //nolint:revive
)
// Info contains the details of the swap as well as its status.
type Info struct {
id uint64 // ID number of the swap (not the swap offer ID!)
id types.Hash // swap offer ID
provides types.ProvidesCoin
providedAmount float64
receivedAmount float64
@@ -24,9 +22,9 @@ type Info struct {
}
// ID returns the swap ID.
func (i *Info) ID() uint64 {
func (i *Info) ID() types.Hash {
if i == nil {
return 0
return types.Hash{} // TODO: does this ever happen??
}
return i.id
@@ -80,10 +78,10 @@ func (i *Info) SetStatus(s Status) {
}
// NewInfo ...
func NewInfo(provides types.ProvidesCoin, providedAmount, receivedAmount float64,
func NewInfo(id types.Hash, provides types.ProvidesCoin, providedAmount, receivedAmount float64,
exchangeRate types.ExchangeRate, status Status, statusCh <-chan types.Status) *Info {
info := &Info{
id: nextID,
id: id,
provides: provides,
providedAmount: providedAmount,
receivedAmount: receivedAmount,
@@ -91,31 +89,29 @@ func NewInfo(provides types.ProvidesCoin, providedAmount, receivedAmount float64
status: status,
statusCh: statusCh,
}
nextID++
return info
}
// Manager tracks current and past swaps.
type Manager interface {
AddSwap(info *Info) error
GetPastIDs() []uint64
GetPastSwap(id uint64) *Info
GetOngoingSwap() *Info
CompleteOngoingSwap()
GetPastIDs() []types.Hash
GetPastSwap(types.Hash) *Info
GetOngoingSwap(types.Hash) *Info
CompleteOngoingSwap(types.Hash)
}
type manager struct {
sync.RWMutex
ongoing *Info
past map[uint64]*Info
offersTaken map[string]uint64 // map of offerID -> swapID
ongoing map[types.Hash]*Info
past map[types.Hash]*Info
}
// NewManager ...
func NewManager() Manager {
return &manager{
past: make(map[uint64]*Info),
offersTaken: make(map[string]uint64),
ongoing: make(map[types.Hash]*Info),
past: make(map[types.Hash]*Info),
}
}
@@ -126,11 +122,7 @@ func (m *manager) AddSwap(info *Info) error {
switch info.status.IsOngoing() {
case true:
if m.ongoing != nil {
return errHaveOngoingSwap
}
m.ongoing = info
m.ongoing[info.id] = info
default:
m.past[info.id] = info
}
@@ -139,10 +131,10 @@ func (m *manager) AddSwap(info *Info) error {
}
// GetPastIDs returns all past swap IDs.
func (m *manager) GetPastIDs() []uint64 {
func (m *manager) GetPastIDs() []types.Hash {
m.RLock()
defer m.RUnlock()
ids := make([]uint64, len(m.past))
ids := make([]types.Hash, len(m.past))
i := 0
for id := range m.past {
ids[i] = id
@@ -152,25 +144,28 @@ func (m *manager) GetPastIDs() []uint64 {
}
// GetPastSwap returns a swap's *Info given its ID.
func (m *manager) GetPastSwap(id uint64) *Info {
func (m *manager) GetPastSwap(id types.Hash) *Info {
m.RLock()
defer m.RUnlock()
return m.past[id]
}
// GetOngoingSwap returns the ongoing swap's *Info, if there is one.
func (m *manager) GetOngoingSwap() *Info {
return m.ongoing
func (m *manager) GetOngoingSwap(id types.Hash) *Info {
m.RLock()
defer m.RUnlock()
return m.ongoing[id]
}
// CompleteOngoingSwap marks the current ongoing swap as completed.
func (m *manager) CompleteOngoingSwap() {
func (m *manager) CompleteOngoingSwap(id types.Hash) {
m.Lock()
defer m.Unlock()
if m.ongoing == nil {
s, has := m.ongoing[id]
if !has {
return
}
m.past[m.ongoing.id] = m.ongoing
m.ongoing = nil
m.past[id] = s
delete(m.ongoing, id)
}

View File

@@ -10,43 +10,42 @@ import (
func TestManager_AddSwap_Ongoing(t *testing.T) {
m := NewManager().(*manager)
info := NewInfo(types.ProvidesXMR, 1, 1, 0.1, types.ExpectingKeys, nil)
info := NewInfo(types.Hash{}, types.ProvidesXMR, 1, 1, 0.1, types.ExpectingKeys, nil)
err := m.AddSwap(info)
require.NoError(t, err)
err = m.AddSwap(info)
require.Equal(t, errHaveOngoingSwap, err)
require.Equal(t, info, m.GetOngoingSwap())
require.NoError(t, err)
require.Equal(t, info, m.GetOngoingSwap(types.Hash{}))
require.NotNil(t, m.ongoing)
m.CompleteOngoingSwap()
require.Nil(t, m.ongoing)
require.Equal(t, []uint64{0}, m.GetPastIDs())
require.Equal(t, uint64(1), nextID)
m.CompleteOngoingSwap(types.Hash{})
require.Equal(t, 0, len(m.ongoing))
require.Equal(t, []types.Hash{{}}, m.GetPastIDs())
m.CompleteOngoingSwap()
m.CompleteOngoingSwap(types.Hash{})
}
func TestManager_AddSwap_Past(t *testing.T) {
m := NewManager().(*manager)
info := &Info{
id: 1,
id: types.Hash{1},
status: types.CompletedSuccess,
}
err := m.AddSwap(info)
require.NoError(t, err)
require.NotNil(t, m.GetPastSwap(1))
require.NotNil(t, m.GetPastSwap(types.Hash{1}))
info = &Info{
id: 2,
id: types.Hash{2},
status: types.CompletedSuccess,
}
err = m.AddSwap(info)
require.NoError(t, err)
require.NotNil(t, m.GetPastSwap(2))
require.NotNil(t, m.GetPastSwap(types.Hash{2}))
ids := m.GetPastIDs()
require.Equal(t, 2, len(ids))

View File

@@ -5,9 +5,11 @@ import (
"errors"
"fmt"
"math/big"
"sync"
"time"
"github.com/noot/atomic-swap/common"
"github.com/noot/atomic-swap/common/types"
"github.com/noot/atomic-swap/swapfactory"
"github.com/ethereum/go-ethereum/accounts/abi"
@@ -18,6 +20,7 @@ import (
var (
errTransactionTimeout = errors.New("timed out waiting for transaction to be signed")
errNoSwapWithID = errors.New("no swap with given id")
transactionTimeout = time.Minute * 2 // arbitrary, TODO vary this based on env
)
@@ -29,6 +32,13 @@ type Transaction struct {
Value string
}
type swapChs struct {
// outgoing encoded txs to be signed
out chan *Transaction
// incoming tx hashes
in chan ethcommon.Hash
}
// ExternalSender represents a transaction signer and sender that is external to the daemon (ie. a front-end)
type ExternalSender struct {
ctx context.Context
@@ -36,11 +46,9 @@ type ExternalSender struct {
abi *abi.ABI
contractAddr ethcommon.Address
// outgoing encoded txs to be signed
out chan *Transaction
sync.RWMutex
// incoming tx hashes
in chan ethcommon.Hash
swaps map[types.Hash]*swapChs
}
// NewExternalSender returns a new ExternalSender
@@ -56,8 +64,7 @@ func NewExternalSender(ctx context.Context, ec *ethclient.Client,
ec: ec,
abi: abi,
contractAddr: contractAddr,
out: make(chan *Transaction),
in: make(chan ethcommon.Hash),
swaps: make(map[types.Hash]*swapChs),
}, nil
}
@@ -70,18 +77,54 @@ func (s *ExternalSender) SetContractAddress(addr ethcommon.Address) {
}
// OngoingCh returns the channel of outgoing transactions to be signed and submitted
func (s *ExternalSender) OngoingCh() <-chan *Transaction {
return s.out
func (s *ExternalSender) OngoingCh(id types.Hash) (<-chan *Transaction, error) {
s.RLock()
defer s.RUnlock()
chs, has := s.swaps[id]
if !has {
return nil, errNoSwapWithID
}
return chs.out, nil
}
// IncomingCh returns the channel of incoming transaction hashes that have been signed and submitted
func (s *ExternalSender) IncomingCh() chan<- ethcommon.Hash {
return s.in
func (s *ExternalSender) IncomingCh(id types.Hash) (chan<- ethcommon.Hash, error) {
s.RLock()
defer s.RUnlock()
chs, has := s.swaps[id]
if !has {
return nil, errNoSwapWithID
}
return chs.in, nil
}
// AddID initialises the sender with a swap w/ the given ID
func (s *ExternalSender) AddID(id types.Hash) {
s.Lock()
defer s.Unlock()
_, has := s.swaps[id]
if !has {
return
}
s.swaps[id] = &swapChs{
out: make(chan *Transaction),
in: make(chan ethcommon.Hash),
}
}
// DeleteID deletes the swap w/ the given ID from the sender
func (s *ExternalSender) DeleteID(id types.Hash) {
s.Lock()
defer s.Unlock()
delete(s.swaps, id)
}
// NewSwap prompts the external sender to sign a new_swap transaction
func (s *ExternalSender) NewSwap(_pubKeyClaim [32]byte, _pubKeyRefund [32]byte, _claimer ethcommon.Address,
_timeoutDuration *big.Int, _nonce *big.Int, value *big.Int) (ethcommon.Hash, *ethtypes.Receipt, error) {
func (s *ExternalSender) NewSwap(id types.Hash, _pubKeyClaim [32]byte, _pubKeyRefund [32]byte,
_claimer ethcommon.Address, _timeoutDuration *big.Int, _nonce *big.Int,
value *big.Int) (ethcommon.Hash, *ethtypes.Receipt, error) {
input, err := s.abi.Pack("new_swap", _pubKeyClaim, _pubKeyRefund, _claimer, _timeoutDuration, _nonce)
if err != nil {
return ethcommon.Hash{}, nil, err
@@ -93,12 +136,19 @@ func (s *ExternalSender) NewSwap(_pubKeyClaim [32]byte, _pubKeyRefund [32]byte,
Value: fmt.Sprintf("%v", common.EtherAmount(*value).AsEther()),
}
s.out <- tx
s.RLock()
defer s.RUnlock()
chs, has := s.swaps[id]
if !has {
return ethcommon.Hash{}, nil, errNoSwapWithID
}
chs.out <- tx
var txHash ethcommon.Hash
select {
case <-time.After(transactionTimeout):
return ethcommon.Hash{}, nil, errTransactionTimeout
case txHash = <-s.in:
case txHash = <-chs.in:
}
receipt, err := waitForReceipt(s.ctx, s.ec, txHash)
@@ -110,49 +160,58 @@ func (s *ExternalSender) NewSwap(_pubKeyClaim [32]byte, _pubKeyRefund [32]byte,
}
// SetReady prompts the external sender to sign a set_ready transaction
func (s *ExternalSender) SetReady(_swap swapfactory.SwapFactorySwap) (ethcommon.Hash, *ethtypes.Receipt, error) {
func (s *ExternalSender) SetReady(id types.Hash,
_swap swapfactory.SwapFactorySwap) (ethcommon.Hash, *ethtypes.Receipt, error) {
input, err := s.abi.Pack("set_ready", _swap)
if err != nil {
return ethcommon.Hash{}, nil, err
}
return s.sendAndReceive(input)
return s.sendAndReceive(id, input)
}
// Claim prompts the external sender to sign a claim transaction
func (s *ExternalSender) Claim(_swap swapfactory.SwapFactorySwap,
func (s *ExternalSender) Claim(id types.Hash, _swap swapfactory.SwapFactorySwap,
_s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error) {
input, err := s.abi.Pack("claim", _swap, _s)
if err != nil {
return ethcommon.Hash{}, nil, err
}
return s.sendAndReceive(input)
return s.sendAndReceive(id, input)
}
// Refund prompts the external sender to sign a refund transaction
func (s *ExternalSender) Refund(_swap swapfactory.SwapFactorySwap,
func (s *ExternalSender) Refund(id types.Hash, _swap swapfactory.SwapFactorySwap,
_s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error) {
input, err := s.abi.Pack("refund", _swap, _s)
if err != nil {
return ethcommon.Hash{}, nil, err
}
return s.sendAndReceive(input)
return s.sendAndReceive(id, input)
}
func (s *ExternalSender) sendAndReceive(input []byte) (ethcommon.Hash, *ethtypes.Receipt, error) {
func (s *ExternalSender) sendAndReceive(id types.Hash,
input []byte) (ethcommon.Hash, *ethtypes.Receipt, error) {
tx := &Transaction{
To: s.contractAddr,
Data: fmt.Sprintf("0x%x", input),
}
s.out <- tx
s.RLock()
defer s.RUnlock()
chs, has := s.swaps[id]
if !has {
return ethcommon.Hash{}, nil, errNoSwapWithID
}
chs.out <- tx
var txHash ethcommon.Hash
select {
case <-time.After(transactionTimeout):
return ethcommon.Hash{}, nil, errTransactionTimeout
case txHash = <-s.in:
case txHash = <-chs.in:
}
receipt, err := waitForReceipt(s.ctx, s.ec, txHash)

View File

@@ -6,6 +6,7 @@ import (
"math/big"
"time"
"github.com/noot/atomic-swap/common/types"
"github.com/noot/atomic-swap/swapfactory"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
@@ -27,11 +28,13 @@ var (
type Sender interface {
SetContract(*swapfactory.SwapFactory)
SetContractAddress(ethcommon.Address)
NewSwap(_pubKeyClaim [32]byte, _pubKeyRefund [32]byte, _claimer ethcommon.Address, _timeoutDuration *big.Int,
_nonce *big.Int, amount *big.Int) (ethcommon.Hash, *ethtypes.Receipt, error)
SetReady(_swap swapfactory.SwapFactorySwap) (ethcommon.Hash, *ethtypes.Receipt, error)
Claim(_swap swapfactory.SwapFactorySwap, _s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error)
Refund(_swap swapfactory.SwapFactorySwap, _s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error)
NewSwap(id types.Hash, _pubKeyClaim [32]byte, _pubKeyRefund [32]byte, _claimer ethcommon.Address,
_timeoutDuration *big.Int, _nonce *big.Int, amount *big.Int) (ethcommon.Hash, *ethtypes.Receipt, error)
SetReady(id types.Hash, _swap swapfactory.SwapFactorySwap) (ethcommon.Hash, *ethtypes.Receipt, error)
Claim(id types.Hash, _swap swapfactory.SwapFactorySwap,
_s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error)
Refund(id types.Hash, _swap swapfactory.SwapFactorySwap,
_s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error)
}
type privateKeySender struct {
@@ -58,8 +61,9 @@ func (s *privateKeySender) SetContract(contract *swapfactory.SwapFactory) {
func (s *privateKeySender) SetContractAddress(_ ethcommon.Address) {}
func (s *privateKeySender) NewSwap(_pubKeyClaim [32]byte, _pubKeyRefund [32]byte, _claimer ethcommon.Address,
_timeoutDuration *big.Int, _nonce *big.Int, value *big.Int) (ethcommon.Hash, *ethtypes.Receipt, error) {
func (s *privateKeySender) NewSwap(_ types.Hash, _pubKeyClaim [32]byte, _pubKeyRefund [32]byte,
_claimer ethcommon.Address, _timeoutDuration *big.Int, _nonce *big.Int,
value *big.Int) (ethcommon.Hash, *ethtypes.Receipt, error) {
s.txOpts.Value = value
defer func() {
s.txOpts.Value = nil
@@ -78,7 +82,8 @@ func (s *privateKeySender) NewSwap(_pubKeyClaim [32]byte, _pubKeyRefund [32]byte
return tx.Hash(), receipt, nil
}
func (s *privateKeySender) SetReady(_swap swapfactory.SwapFactorySwap) (ethcommon.Hash, *ethtypes.Receipt, error) {
func (s *privateKeySender) SetReady(_ types.Hash,
_swap swapfactory.SwapFactorySwap) (ethcommon.Hash, *ethtypes.Receipt, error) {
tx, err := s.contract.SetReady(s.txOpts, _swap)
if err != nil {
return ethcommon.Hash{}, nil, err
@@ -92,7 +97,7 @@ func (s *privateKeySender) SetReady(_swap swapfactory.SwapFactorySwap) (ethcommo
return tx.Hash(), receipt, nil
}
func (s *privateKeySender) Claim(_swap swapfactory.SwapFactorySwap,
func (s *privateKeySender) Claim(_ types.Hash, _swap swapfactory.SwapFactorySwap,
_s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error) {
tx, err := s.contract.Claim(s.txOpts, _swap, _s)
if err != nil {
@@ -107,7 +112,7 @@ func (s *privateKeySender) Claim(_swap swapfactory.SwapFactorySwap,
return tx.Hash(), receipt, nil
}
func (s *privateKeySender) Refund(_swap swapfactory.SwapFactorySwap,
func (s *privateKeySender) Refund(_ types.Hash, _swap swapfactory.SwapFactorySwap,
_s [32]byte) (ethcommon.Hash, *ethtypes.Receipt, error) {
tx, err := s.contract.Refund(s.txOpts, _swap, _s)
if err != nil {

View File

@@ -4,6 +4,7 @@ import (
"sync"
"github.com/noot/atomic-swap/common"
"github.com/noot/atomic-swap/common/types"
"github.com/noot/atomic-swap/protocol/backend"
logging "github.com/ipfs/go-log"
@@ -23,8 +24,8 @@ type Instance struct {
offerManager *offerManager
swapMu sync.Mutex
swapState *swapState
swapMu sync.Mutex
swapStates map[types.Hash]*swapState
}
// Config contains the configuration values for a new XMRMaker instance.
@@ -51,6 +52,7 @@ func NewInstance(cfg *Config) (*Instance, error) {
walletFile: cfg.WalletFile,
walletPassword: cfg.WalletPassword,
offerManager: newOfferManager(cfg.Basepath),
swapStates: make(map[types.Hash]*swapState),
}, nil
}
@@ -65,6 +67,6 @@ func (b *Instance) openWallet() error { //nolint
}
// GetOngoingSwapState ...
func (b *Instance) GetOngoingSwapState() common.SwapState {
return b.swapState
func (b *Instance) GetOngoingSwapState(id types.Hash) common.SwapState {
return b.swapStates[id]
}

View File

@@ -93,6 +93,10 @@ func (s *swapState) setNextExpectedMessage(msg net.Message) {
return
}
if msg == nil || s.nextExpectedMessage == nil {
return
}
if msg.Type() == s.nextExpectedMessage.Type() {
return
}
@@ -210,7 +214,7 @@ func (s *swapState) handleNotifyETHLocked(msg *message.NotifyETHLocked) (net.Mes
// send *message.NotifyClaimed
if err := s.SendSwapMessage(&message.NotifyClaimed{
TxHash: txHash.String(),
}); err != nil {
}, s.ID()); err != nil {
log.Errorf("failed to send NotifyClaimed message: err=%s", err)
}
case <-s.readyCh:

View File

@@ -16,6 +16,7 @@ import (
types "github.com/ethereum/go-ethereum/core/types"
gomock "github.com/golang/mock/gomock"
common0 "github.com/noot/atomic-swap/common"
types0 "github.com/noot/atomic-swap/common/types"
mcrypto "github.com/noot/atomic-swap/crypto/monero"
monero "github.com/noot/atomic-swap/monero"
net "github.com/noot/atomic-swap/net"
@@ -92,9 +93,9 @@ func (mr *MockBackendMockRecorder) ChainID() *gomock.Call {
}
// Claim mocks base method.
func (m *MockBackend) Claim(arg0 swapfactory.SwapFactorySwap, arg1 [32]byte) (common.Hash, *types.Receipt, error) {
func (m *MockBackend) Claim(arg0 types0.Hash, arg1 swapfactory.SwapFactorySwap, arg2 [32]byte) (common.Hash, *types.Receipt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Claim", arg0, arg1)
ret := m.ctrl.Call(m, "Claim", arg0, arg1, arg2)
ret0, _ := ret[0].(common.Hash)
ret1, _ := ret[1].(*types.Receipt)
ret2, _ := ret[2].(error)
@@ -102,9 +103,9 @@ func (m *MockBackend) Claim(arg0 swapfactory.SwapFactorySwap, arg1 [32]byte) (co
}
// Claim indicates an expected call of Claim.
func (mr *MockBackendMockRecorder) Claim(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockBackendMockRecorder) Claim(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Claim", reflect.TypeOf((*MockBackend)(nil).Claim), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Claim", reflect.TypeOf((*MockBackend)(nil).Claim), arg0, arg1, arg2)
}
// CloseWallet mocks base method.
@@ -351,6 +352,18 @@ func (mr *MockBackendMockRecorder) GetHeight() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeight", reflect.TypeOf((*MockBackend)(nil).GetHeight))
}
// LockClient mocks base method.
func (m *MockBackend) LockClient() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "LockClient")
}
// LockClient indicates an expected call of LockClient.
func (mr *MockBackendMockRecorder) LockClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockClient", reflect.TypeOf((*MockBackend)(nil).LockClient))
}
// Net mocks base method.
func (m *MockBackend) Net() net.MessageSender {
m.ctrl.T.Helper()
@@ -366,9 +379,9 @@ func (mr *MockBackendMockRecorder) Net() *gomock.Call {
}
// NewSwap mocks base method.
func (m *MockBackend) NewSwap(arg0, arg1 [32]byte, arg2 common.Address, arg3, arg4, arg5 *big.Int) (common.Hash, *types.Receipt, error) {
func (m *MockBackend) NewSwap(arg0 types0.Hash, arg1, arg2 [32]byte, arg3 common.Address, arg4, arg5, arg6 *big.Int) (common.Hash, *types.Receipt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewSwap", arg0, arg1, arg2, arg3, arg4, arg5)
ret := m.ctrl.Call(m, "NewSwap", arg0, arg1, arg2, arg3, arg4, arg5, arg6)
ret0, _ := ret[0].(common.Hash)
ret1, _ := ret[1].(*types.Receipt)
ret2, _ := ret[2].(error)
@@ -376,9 +389,9 @@ func (m *MockBackend) NewSwap(arg0, arg1 [32]byte, arg2 common.Address, arg3, ar
}
// NewSwap indicates an expected call of NewSwap.
func (mr *MockBackendMockRecorder) NewSwap(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
func (mr *MockBackendMockRecorder) NewSwap(arg0, arg1, arg2, arg3, arg4, arg5, arg6 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSwap", reflect.TypeOf((*MockBackend)(nil).NewSwap), arg0, arg1, arg2, arg3, arg4, arg5)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSwap", reflect.TypeOf((*MockBackend)(nil).NewSwap), arg0, arg1, arg2, arg3, arg4, arg5, arg6)
}
// NewSwapFactory mocks base method.
@@ -425,9 +438,9 @@ func (mr *MockBackendMockRecorder) Refresh() *gomock.Call {
}
// Refund mocks base method.
func (m *MockBackend) Refund(arg0 swapfactory.SwapFactorySwap, arg1 [32]byte) (common.Hash, *types.Receipt, error) {
func (m *MockBackend) Refund(arg0 types0.Hash, arg1 swapfactory.SwapFactorySwap, arg2 [32]byte) (common.Hash, *types.Receipt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Refund", arg0, arg1)
ret := m.ctrl.Call(m, "Refund", arg0, arg1, arg2)
ret0, _ := ret[0].(common.Hash)
ret1, _ := ret[1].(*types.Receipt)
ret2, _ := ret[2].(error)
@@ -435,23 +448,35 @@ func (m *MockBackend) Refund(arg0 swapfactory.SwapFactorySwap, arg1 [32]byte) (c
}
// Refund indicates an expected call of Refund.
func (mr *MockBackendMockRecorder) Refund(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockBackendMockRecorder) Refund(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refund", reflect.TypeOf((*MockBackend)(nil).Refund), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Refund", reflect.TypeOf((*MockBackend)(nil).Refund), arg0, arg1, arg2)
}
// SendSwapMessage mocks base method.
func (m *MockBackend) SendSwapMessage(arg0 message.Message) error {
func (m *MockBackend) SendSwapMessage(arg0 message.Message, arg1 types0.Hash) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendSwapMessage", arg0)
ret := m.ctrl.Call(m, "SendSwapMessage", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SendSwapMessage indicates an expected call of SendSwapMessage.
func (mr *MockBackendMockRecorder) SendSwapMessage(arg0 interface{}) *gomock.Call {
func (mr *MockBackendMockRecorder) SendSwapMessage(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSwapMessage", reflect.TypeOf((*MockBackend)(nil).SendSwapMessage), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSwapMessage", reflect.TypeOf((*MockBackend)(nil).SendSwapMessage), arg0, arg1)
}
// SetBaseXMRDepositAddress mocks base method.
func (m *MockBackend) SetBaseXMRDepositAddress(arg0 mcrypto.Address) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBaseXMRDepositAddress", arg0)
}
// SetBaseXMRDepositAddress indicates an expected call of SetBaseXMRDepositAddress.
func (mr *MockBackendMockRecorder) SetBaseXMRDepositAddress(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBaseXMRDepositAddress", reflect.TypeOf((*MockBackend)(nil).SetBaseXMRDepositAddress), arg0)
}
// SetContract mocks base method.
@@ -503,9 +528,9 @@ func (mr *MockBackendMockRecorder) SetGasPrice(arg0 interface{}) *gomock.Call {
}
// SetReady mocks base method.
func (m *MockBackend) SetReady(arg0 swapfactory.SwapFactorySwap) (common.Hash, *types.Receipt, error) {
func (m *MockBackend) SetReady(arg0 types0.Hash, arg1 swapfactory.SwapFactorySwap) (common.Hash, *types.Receipt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetReady", arg0)
ret := m.ctrl.Call(m, "SetReady", arg0, arg1)
ret0, _ := ret[0].(common.Hash)
ret1, _ := ret[1].(*types.Receipt)
ret2, _ := ret[2].(error)
@@ -513,9 +538,9 @@ func (m *MockBackend) SetReady(arg0 swapfactory.SwapFactorySwap) (common.Hash, *
}
// SetReady indicates an expected call of SetReady.
func (mr *MockBackendMockRecorder) SetReady(arg0 interface{}) *gomock.Call {
func (mr *MockBackendMockRecorder) SetReady(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReady", reflect.TypeOf((*MockBackend)(nil).SetReady), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReady", reflect.TypeOf((*MockBackend)(nil).SetReady), arg0, arg1)
}
// SetSwapTimeout mocks base method.
@@ -531,15 +556,15 @@ func (mr *MockBackendMockRecorder) SetSwapTimeout(arg0 interface{}) *gomock.Call
}
// SetXMRDepositAddress mocks base method.
func (m *MockBackend) SetXMRDepositAddress(arg0 mcrypto.Address) {
func (m *MockBackend) SetXMRDepositAddress(arg0 mcrypto.Address, arg1 types0.Hash) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetXMRDepositAddress", arg0)
m.ctrl.Call(m, "SetXMRDepositAddress", arg0, arg1)
}
// SetXMRDepositAddress indicates an expected call of SetXMRDepositAddress.
func (mr *MockBackendMockRecorder) SetXMRDepositAddress(arg0 interface{}) *gomock.Call {
func (mr *MockBackendMockRecorder) SetXMRDepositAddress(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetXMRDepositAddress", reflect.TypeOf((*MockBackend)(nil).SetXMRDepositAddress), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetXMRDepositAddress", reflect.TypeOf((*MockBackend)(nil).SetXMRDepositAddress), arg0, arg1)
}
// SwapManager mocks base method.
@@ -630,6 +655,18 @@ func (mr *MockBackendMockRecorder) TxOpts() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxOpts", reflect.TypeOf((*MockBackend)(nil).TxOpts))
}
// UnlockClient mocks base method.
func (m *MockBackend) UnlockClient() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UnlockClient")
}
// UnlockClient indicates an expected call of UnlockClient.
func (mr *MockBackendMockRecorder) UnlockClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnlockClient", reflect.TypeOf((*MockBackend)(nil).UnlockClient))
}
// WaitForReceipt mocks base method.
func (m *MockBackend) WaitForReceipt(arg0 context.Context, arg1 common.Hash) (*types.Receipt, error) {
m.ctrl.T.Helper()
@@ -646,15 +683,16 @@ func (mr *MockBackendMockRecorder) WaitForReceipt(arg0, arg1 interface{}) *gomoc
}
// XMRDepositAddress mocks base method.
func (m *MockBackend) XMRDepositAddress() mcrypto.Address {
func (m *MockBackend) XMRDepositAddress(arg0 *types0.Hash) (mcrypto.Address, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "XMRDepositAddress")
ret := m.ctrl.Call(m, "XMRDepositAddress", arg0)
ret0, _ := ret[0].(mcrypto.Address)
return ret0
ret1, _ := ret[1].(error)
return ret0, ret1
}
// XMRDepositAddress indicates an expected call of XMRDepositAddress.
func (mr *MockBackendMockRecorder) XMRDepositAddress() *gomock.Call {
func (mr *MockBackendMockRecorder) XMRDepositAddress(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XMRDepositAddress", reflect.TypeOf((*MockBackend)(nil).XMRDepositAddress))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XMRDepositAddress", reflect.TypeOf((*MockBackend)(nil).XMRDepositAddress), arg0)
}

View File

@@ -19,7 +19,7 @@ func (b *Instance) initiate(offer *types.Offer, offerExtra *types.OfferExtra, pr
b.swapMu.Lock()
defer b.swapMu.Unlock()
if b.swapState != nil {
if b.swapStates[offer.GetID()] != nil {
return errProtocolAlreadyInProgress
}
@@ -33,23 +33,24 @@ func (b *Instance) initiate(offer *types.Offer, offerExtra *types.OfferExtra, pr
return errBalanceTooLow
}
b.swapState, err = newSwapState(b.backend, offer, b.offerManager, offerExtra.StatusCh,
s, err := newSwapState(b.backend, offer, b.offerManager, offerExtra.StatusCh,
offerExtra.InfoFile, providesAmount, desiredAmount)
if err != nil {
return err
}
go func() {
<-b.swapState.done
b.swapState = nil
<-s.done
delete(b.swapStates, offer.GetID())
}()
log.Info(color.New(color.Bold).Sprintf("**initiated swap with ID=%d**", b.swapState.ID()))
log.Info(color.New(color.Bold).Sprintf("**initiated swap with ID=%s**", s.ID()))
log.Info(color.New(color.Bold).Sprint("DO NOT EXIT THIS PROCESS OR FUNDS MAY BE LOST!"))
log.Infof(color.New(color.Bold).Sprintf("receiving %v ETH for %v XMR",
b.swapState.info.ReceivedAmount(),
b.swapState.info.ProvidedAmount()),
s.info.ReceivedAmount(),
s.info.ProvidedAmount()),
)
b.swapStates[offer.GetID()] = s
return nil
}
@@ -86,18 +87,27 @@ func (b *Instance) HandleInitiateMessage(msg *net.SendKeysMessage) (net.SwapStat
return nil, nil, err
}
offerExtra.IDCh <- b.swapState.info.ID()
close(offerExtra.IDCh)
if err = b.swapState.handleSendKeysMessage(msg); err != nil {
return nil, nil, err
}
resp, err := b.swapState.SendKeysMessage()
offerID, err := types.HexToHash(msg.OfferID)
if err != nil {
return nil, nil, err
}
defer b.swapState.setNextExpectedMessage(&message.NotifyETHLocked{})
return b.swapState, resp, nil
s, has := b.swapStates[offerID]
if !has {
panic("did not store swap state in Instance map")
}
if err = s.handleSendKeysMessage(msg); err != nil {
return nil, nil, err
}
resp, err := s.SendKeysMessage()
if err != nil {
return nil, nil, err
}
defer func() {
s.setNextExpectedMessage(&message.NotifyETHLocked{})
}()
return s, resp, nil
}

View File

@@ -18,11 +18,8 @@ func TestXMRMaker_HandleInitiateMessage(t *testing.T) {
MaximumAmount: 0.002,
ExchangeRate: 0.1,
}
extra, err := b.MakeOffer(offer)
_, err := b.MakeOffer(offer)
require.NoError(t, err)
go func() {
<-extra.IDCh
}()
msg, _ := newTestXMRTakerSendKeysMessage(t)
msg.OfferID = offer.GetID().String()
@@ -31,5 +28,5 @@ func TestXMRMaker_HandleInitiateMessage(t *testing.T) {
_, resp, err := b.HandleInitiateMessage(msg)
require.NoError(t, err)
require.Equal(t, message.SendKeysType, resp.Type())
require.NotNil(t, b.swapState)
require.NotNil(t, b.swapStates[offer.GetID()])
}

View File

@@ -30,7 +30,7 @@ func (om *offerManager) putOffer(o *types.Offer) *types.OfferExtra {
}
extra := &types.OfferExtra{
IDCh: make(chan uint64, 1),
//IDCh: make(chan uint64, 1),
StatusCh: make(chan types.Status, 7),
InfoFile: pcommon.GetSwapInfoFilepath(om.basepath),
}
@@ -56,6 +56,9 @@ func (om *offerManager) getAndDeleteOffer(id types.Hash) (*types.Offer, *types.O
// MakeOffer makes a new swap offer.
func (b *Instance) MakeOffer(o *types.Offer) (*types.OfferExtra, error) {
b.backend.LockClient()
defer b.backend.UnlockClient()
balance, err := b.backend.GetBalance(0)
if err != nil {
return nil, err

View File

@@ -22,11 +22,6 @@ type recoveryState struct {
func NewRecoveryState(b backend.Backend, basepath string, secret *mcrypto.PrivateSpendKey,
contractAddr ethcommon.Address,
contractSwapID [32]byte, contractSwap swapfactory.SwapFactorySwap) (*recoveryState, error) { //nolint:revive
txOpts, err := b.TxOpts()
if err != nil {
return nil, err
}
kp, err := secret.AsPrivateKeyPair()
if err != nil {
return nil, err
@@ -34,9 +29,6 @@ func NewRecoveryState(b backend.Backend, basepath string, secret *mcrypto.Privat
pubkp := kp.PublicKeyPair()
// txOpts.GasPrice = b.gasPrice
// txOpts.GasLimit = b.gasLimit
var sc [32]byte
copy(sc[:], secret.Bytes())
@@ -45,7 +37,6 @@ func NewRecoveryState(b backend.Backend, basepath string, secret *mcrypto.Privat
ctx: ctx,
cancel: cancel,
Backend: b,
txOpts: txOpts,
privkeys: kp,
pubkeys: pubkp,
dleqProof: dleq.NewProofWithSecret(sc),

View File

@@ -33,9 +33,11 @@ func newTestRecoveryState(t *testing.T) *recoveryState {
func TestClaimOrRecover_Claim(t *testing.T) {
// test case where XMRMaker is able to claim ether from the contract
rs := newTestRecoveryState(t)
txOpts, err := rs.ss.TxOpts()
require.NoError(t, err)
// set contract to Ready
_, err := rs.ss.Contract().SetReady(rs.ss.txOpts, rs.ss.contractSwap)
_, err = rs.ss.Contract().SetReady(txOpts, rs.ss.contractSwap)
require.NoError(t, err)
// assert we can claim ether
@@ -51,6 +53,8 @@ func TestClaimOrRecover_Recover(t *testing.T) {
// test case where XMRMaker is able to reclaim his monero, after XMRTaker refunds
rs := newTestRecoveryState(t)
txOpts, err := rs.ss.TxOpts()
require.NoError(t, err)
daemonClient := monero.NewClient(common.DefaultMoneroDaemonEndpoint)
addr, err := rs.ss.GetAddress(0)
@@ -64,7 +68,7 @@ func TestClaimOrRecover_Recover(t *testing.T) {
// call refund w/ XMRTaker's spend key
sc := rs.ss.getSecret()
_, err = rs.ss.Contract().Refund(rs.ss.txOpts, rs.ss.contractSwap, sc)
_, err = rs.ss.Contract().Refund(txOpts, rs.ss.contractSwap, sc)
require.NoError(t, err)
// assert XMRMaker can reclaim his monero

View File

@@ -11,7 +11,6 @@ import (
"time"
eth "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
ethcommon "github.com/ethereum/go-ethereum/common"
ethtypes "github.com/ethereum/go-ethereum/core/types"
"github.com/fatih/color" //nolint:misspell
@@ -61,7 +60,6 @@ type swapState struct {
contractSwapID [32]byte
contractSwap swapfactory.SwapFactorySwap
t0, t1 time.Time
txOpts *bind.TransactOpts
// XMRTaker's keys for this session
xmrtakerPublicKeys *mcrypto.PublicKeyPair
@@ -81,18 +79,13 @@ type swapState struct {
func newSwapState(b backend.Backend, offer *types.Offer, om *offerManager, statusCh chan types.Status, infofile string,
providesAmount common.MoneroAmount, desiredAmount common.EtherAmount) (*swapState, error) {
txOpts, err := b.TxOpts()
if err != nil {
return nil, err
}
exchangeRate := types.ExchangeRate(providesAmount.AsMonero() / desiredAmount.AsEther())
stage := types.ExpectingKeys
if statusCh == nil {
statusCh = make(chan types.Status, 7)
}
statusCh <- stage
info := pswap.NewInfo(types.ProvidesXMR, providesAmount.AsMonero(), desiredAmount.AsEther(),
info := pswap.NewInfo(offer.GetID(), types.ProvidesXMR, providesAmount.AsMonero(), desiredAmount.AsEther(),
exchangeRate, stage, statusCh)
if err := b.SwapManager().AddSwap(info); err != nil {
return nil, err
@@ -108,7 +101,6 @@ func newSwapState(b backend.Backend, offer *types.Offer, om *offerManager, statu
infofile: infofile,
nextExpectedMessage: &net.SendKeysMessage{},
readyCh: make(chan struct{}),
txOpts: txOpts,
info: info,
statusCh: statusCh,
done: make(chan struct{}),
@@ -144,7 +136,7 @@ func (s *swapState) ReceivedAmount() float64 {
}
// ID returns the ID of the swap
func (s *swapState) ID() uint64 {
func (s *swapState) ID() types.Hash {
return s.info.ID()
}
@@ -177,7 +169,7 @@ func (s *swapState) exit() error {
defer func() {
// stop all running goroutines
s.cancel()
s.SwapManager().CompleteOngoingSwap()
s.SwapManager().CompleteOngoingSwap(s.offer.GetID())
if s.info.Status() != types.CompletedSuccess {
// re-add offer, as it wasn't taken successfully
@@ -188,13 +180,13 @@ func (s *swapState) exit() error {
}()
if s.info.Status() == types.CompletedSuccess {
str := color.New(color.Bold).Sprintf("**swap completed successfully: id=%d**", s.ID())
str := color.New(color.Bold).Sprintf("**swap completed successfully: id=%s**", s.ID())
log.Info(str)
return nil
}
if s.info.Status() == types.CompletedRefund {
str := color.New(color.Bold).Sprintf("**swap refunded successfully: id=%d**", s.ID())
str := color.New(color.Bold).Sprintf("**swap refunded successfully: id=%s**", s.ID())
log.Info(str)
return nil
}
@@ -274,6 +266,8 @@ func (s *swapState) reclaimMonero(skA *mcrypto.PrivateSpendKey) (mcrypto.Address
}
// TODO: check balance
s.LockClient()
defer s.UnlockClient()
return monero.CreateMoneroWallet("xmrmaker-swap-wallet", s.Env(), s, kpAB)
}
@@ -461,6 +455,9 @@ func (s *swapState) lockFunds(amount common.MoneroAmount) (mcrypto.Address, erro
kp := mcrypto.SumSpendAndViewKeys(s.xmrtakerPublicKeys, s.pubkeys)
log.Infof("going to lock XMR funds, amount(piconero)=%d", amount)
s.LockClient()
defer s.UnlockClient()
balance, err := s.GetBalance(0)
if err != nil {
return "", err
@@ -516,7 +513,7 @@ func (s *swapState) claimFunds() (ethcommon.Hash, error) {
// call swap.Swap.Claim() w/ b.privkeys.sk, revealing XMRMaker's secret spend key
sc := s.getSecret()
txHash, _, err := s.Claim(s.contractSwap, sc)
txHash, _, err := s.Claim(s.ID(), s.contractSwap, sc)
if err != nil {
return ethcommon.Hash{}, err
}

View File

@@ -37,7 +37,7 @@ type mockNet struct {
msg net.Message
}
func (n *mockNet) SendSwapMessage(msg net.Message) error {
func (n *mockNet) SendSwapMessage(msg net.Message, _ types.Hash) error {
n.msg = msg
return nil
}
@@ -175,6 +175,10 @@ func TestSwapState_GenerateAndSetKeys(t *testing.T) {
}
func TestSwapState_ClaimFunds(t *testing.T) {
if testing.Short() {
t.Skip() // TODO: randomly fails on CI with "no contract code at given address"
}
_, swapState := newTestInstance(t)
err := swapState.generateAndSetKeys()
require.NoError(t, err)
@@ -183,7 +187,9 @@ func TestSwapState_ClaimFunds(t *testing.T) {
newSwap(t, swapState, claimKey,
[32]byte{}, big.NewInt(33), defaultTimeoutDuration)
_, err = swapState.Contract().SetReady(swapState.txOpts, swapState.contractSwap)
txOpts, err := swapState.TxOpts()
require.NoError(t, err)
_, err = swapState.Contract().SetReady(txOpts, swapState.contractSwap)
require.NoError(t, err)
txHash, err := swapState.claimFunds()
@@ -316,7 +322,9 @@ func TestSwapState_HandleProtocolMessage_NotifyReady(t *testing.T) {
require.NoError(t, err)
newSwap(t, s, [32]byte{}, [32]byte{}, desiredAmount.BigInt(), duration)
_, err = s.Contract().SetReady(s.txOpts, s.contractSwap)
txOpts, err := s.TxOpts()
require.NoError(t, err)
_, err = s.Contract().SetReady(txOpts, s.contractSwap)
require.NoError(t, err)
msg := &message.NotifyReady{}
@@ -354,7 +362,9 @@ func TestSwapState_handleRefund(t *testing.T) {
var sc [32]byte
copy(sc[:], common.Reverse(secret))
tx, err := s.Contract().Refund(s.txOpts, s.contractSwap, sc)
txOpts, err := s.TxOpts()
require.NoError(t, err)
tx, err := s.Contract().Refund(txOpts, s.contractSwap, sc)
require.NoError(t, err)
addr, err := s.handleRefund(tx.Hash().String())
@@ -387,7 +397,9 @@ func TestSwapState_HandleProtocolMessage_NotifyRefund(t *testing.T) {
var sc [32]byte
copy(sc[:], common.Reverse(secret[:]))
tx, err := s.Contract().Refund(s.txOpts, s.contractSwap, sc)
txOpts, err := s.TxOpts()
require.NoError(t, err)
tx, err := s.Contract().Refund(txOpts, s.contractSwap, sc)
require.NoError(t, err)
msg := &message.NotifyRefund{
@@ -427,7 +439,9 @@ func TestSwapState_Exit_Reclaim(t *testing.T) {
var sc [32]byte
copy(sc[:], common.Reverse(secret[:]))
tx, err := s.Contract().Refund(s.txOpts, s.contractSwap, sc)
txOpts, err := s.TxOpts()
require.NoError(t, err)
tx, err := s.Contract().Refund(txOpts, s.contractSwap, sc)
require.NoError(t, err)
receipt, err := s.TransactionReceipt(s.ctx, tx.Hash())

View File

@@ -6,7 +6,7 @@ import (
var (
// various instance and swap errors
errNoOngoingSwap = errors.New("no ongoing swap")
errNoOngoingSwap = errors.New("no ongoing swap with given offer ID")
errUnexpectedMessageType = errors.New("unexpected message type")
errMissingKeys = errors.New("did not receive XMRMaker's public spend or private view key")
errMissingAddress = errors.New("did not receive XMRMaker's address")

View File

@@ -7,6 +7,7 @@ import (
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/noot/atomic-swap/common"
"github.com/noot/atomic-swap/common/types"
mcrypto "github.com/noot/atomic-swap/crypto/monero"
"github.com/noot/atomic-swap/monero"
"github.com/noot/atomic-swap/protocol/backend"
@@ -32,8 +33,9 @@ type Instance struct {
transferBack bool // transfer xmr back to original account
// non-nil if a swap is currently happening, nil otherwise
swapMu sync.Mutex
swapState *swapState
// map of offer IDs -> ongoing swaps
swapStates map[types.Hash]*swapState
swapMu sync.Mutex // lock for above map
}
// Config contains the configuration values for a new XMRTaker instance.
@@ -58,7 +60,7 @@ func NewInstance(cfg *Config) (*Instance, error) {
if err != nil {
return nil, err
}
cfg.Backend.SetXMRDepositAddress(address)
cfg.Backend.SetBaseXMRDepositAddress(address)
}
// TODO: check that XMRTaker's monero-wallet-cli endpoint has wallet-dir configured
@@ -67,6 +69,7 @@ func NewInstance(cfg *Config) (*Instance, error) {
basepath: cfg.Basepath,
walletFile: cfg.MoneroWalletFile,
walletPassword: cfg.MoneroWalletPassword,
swapStates: make(map[types.Hash]*swapState),
}, nil
}
@@ -103,18 +106,19 @@ func getAddress(walletClient monero.Client, file, password string) (mcrypto.Addr
// Refund is called by the RPC function swap_refund.
// If it's possible to refund the ongoing swap, it does that, then notifies the counterparty.
func (a *Instance) Refund() (ethcommon.Hash, error) {
func (a *Instance) Refund(offerID types.Hash) (ethcommon.Hash, error) {
a.swapMu.Lock()
defer a.swapMu.Unlock()
if a.swapState == nil {
s, has := a.swapStates[offerID]
if !has {
return ethcommon.Hash{}, errNoOngoingSwap
}
return a.swapState.doRefund()
return s.doRefund()
}
// GetOngoingSwapState ...
func (a *Instance) GetOngoingSwapState() common.SwapState {
return a.swapState
func (a *Instance) GetOngoingSwapState(offerID types.Hash) common.SwapState {
return a.swapStates[offerID]
}

View File

@@ -178,7 +178,7 @@ func (s *swapState) handleSendKeysMessage(msg *net.SendKeysMessage) (net.Message
// send NotifyRefund msg
if err := s.SendSwapMessage(&message.NotifyRefund{
TxHash: txhash.String(),
}); err != nil {
}, s.ID()); err != nil {
log.Errorf("failed to send refund message: err=%s", err)
}
case <-s.xmrLockedCh:
@@ -213,6 +213,9 @@ func (s *swapState) handleNotifyXMRLock(msg *message.NotifyXMRLock) (net.Message
return nil, fmt.Errorf("address received in message does not match expected address")
}
s.LockClient()
defer s.UnlockClient()
t := time.Now().Format("2006-Jan-2-15:04:05")
walletName := fmt.Sprintf("xmrtaker-viewonly-wallet-%s", t)
if err := s.GenerateViewOnlyWalletFromKeys(vk, kp.Address(s.Env()), walletName, ""); err != nil {
@@ -314,7 +317,7 @@ func (s *swapState) handleNotifyXMRLock(msg *message.NotifyXMRLock) (net.Message
// send NotifyRefund msg
if err = s.SendSwapMessage(&message.NotifyRefund{
TxHash: txhash.String(),
}); err != nil {
}, s.ID()); err != nil {
log.Errorf("failed to send refund message: err=%s", err)
}

View File

@@ -18,20 +18,20 @@ func (a *Instance) Provides() types.ProvidesCoin {
func (a *Instance) InitiateProtocol(providesAmount float64, offer *types.Offer) (common.SwapState, error) {
receivedAmount := offer.ExchangeRate.ToXMR(providesAmount)
err := a.initiate(common.EtherToWei(providesAmount), common.MoneroToPiconero(receivedAmount),
offer.ExchangeRate)
offer.ExchangeRate, offer.GetID())
if err != nil {
return nil, err
}
return a.swapState, nil
return a.swapStates[offer.GetID()], nil
}
func (a *Instance) initiate(providesAmount common.EtherAmount, receivedAmount common.MoneroAmount,
exchangeRate types.ExchangeRate) error {
exchangeRate types.ExchangeRate, offerID types.Hash) error {
a.swapMu.Lock()
defer a.swapMu.Unlock()
if a.swapState != nil {
if a.swapStates[offerID] != nil {
return errProtocolAlreadyInProgress
}
@@ -45,18 +45,19 @@ func (a *Instance) initiate(providesAmount common.EtherAmount, receivedAmount co
return errBalanceTooLow
}
a.swapState, err = newSwapState(a.backend, pcommon.GetSwapInfoFilepath(a.basepath), a.transferBack,
s, err := newSwapState(a.backend, offerID, pcommon.GetSwapInfoFilepath(a.basepath), a.transferBack,
providesAmount, receivedAmount, exchangeRate)
if err != nil {
return err
}
go func() {
<-a.swapState.done
a.swapState = nil
<-s.done
delete(a.swapStates, offerID)
}()
log.Info(color.New(color.Bold).Sprintf("**initiated swap with ID=%d**", a.swapState.info.ID()))
log.Info(color.New(color.Bold).Sprintf("**initiated swap with ID=%s**", s.info.ID()))
log.Info(color.New(color.Bold).Sprint("DO NOT EXIT THIS PROCESS OR FUNDS MAY BE LOST!"))
a.swapStates[offerID] = s
return nil
}

View File

@@ -22,9 +22,10 @@ func newTestXMRTaker(t *testing.T) *Instance {
func TestXMRTaker_InitiateProtocol(t *testing.T) {
a := newTestXMRTaker(t)
s, err := a.InitiateProtocol(3.33, &types.Offer{
offer := &types.Offer{
ExchangeRate: 1,
})
}
s, err := a.InitiateProtocol(3.33, offer)
require.NoError(t, err)
require.Equal(t, a.swapState, s)
require.Equal(t, a.swapStates[offer.GetID()], s)
}

View File

@@ -69,21 +69,22 @@ type swapState struct {
exited bool
}
func newSwapState(b backend.Backend, infofile string, transferBack bool,
func newSwapState(b backend.Backend, offerID types.Hash, infofile string, transferBack bool,
providesAmount common.EtherAmount, receivedAmount common.MoneroAmount,
exchangeRate types.ExchangeRate) (*swapState, error) {
if b.Contract() == nil {
return nil, errNoSwapContractSet
}
if transferBack && b.XMRDepositAddress() == "" {
_, err := b.XMRDepositAddress(nil)
if transferBack && err != nil {
return nil, errMustProvideWalletAddress
}
stage := types.ExpectingKeys
statusCh := make(chan types.Status, 16)
statusCh <- stage
info := pswap.NewInfo(types.ProvidesETH, providesAmount.AsEther(), receivedAmount.AsMonero(),
info := pswap.NewInfo(offerID, types.ProvidesETH, providesAmount.AsEther(), receivedAmount.AsMonero(),
exchangeRate, stage, statusCh)
if err := b.SwapManager().AddSwap(info); err != nil {
return nil, err
@@ -167,7 +168,7 @@ func (s *swapState) receivedAmountInPiconero() common.MoneroAmount {
}
// ID returns the ID of the swap
func (s *swapState) ID() uint64 {
func (s *swapState) ID() types.Hash {
return s.info.ID()
}
@@ -187,17 +188,17 @@ func (s *swapState) Exit() error {
defer func() {
// stop all running goroutines
s.cancel()
s.SwapManager().CompleteOngoingSwap()
s.SwapManager().CompleteOngoingSwap(s.info.ID())
close(s.done)
if s.info.Status() == types.CompletedSuccess {
str := color.New(color.Bold).Sprintf("**swap completed successfully: id=%d**", s.info.ID())
str := color.New(color.Bold).Sprintf("**swap completed successfully: id=%s**", s.info.ID())
log.Info(str)
return
}
if s.info.Status() == types.CompletedRefund {
str := color.New(color.Bold).Sprintf("**swap refunded successfully! id=%d**", s.info.ID())
str := color.New(color.Bold).Sprintf("**swap refunded successfully! id=%s**", s.info.ID())
log.Info(str)
return
}
@@ -293,7 +294,7 @@ func (s *swapState) doRefund() (ethcommon.Hash, error) {
// send NotifyRefund msg
if err = s.SendSwapMessage(&message.NotifyRefund{
TxHash: txHash.String(),
}); err != nil {
}, s.ID()); err != nil {
return ethcommon.Hash{}, fmt.Errorf("failed to send refund message: err=%w", err)
}
@@ -383,7 +384,7 @@ func (s *swapState) lockETH(amount common.EtherAmount) (ethcommon.Hash, error) {
cmtXMRMaker := s.xmrmakerSecp256k1PublicKey.Keccak256()
nonce := generateNonce()
txHash, receipt, err := s.NewSwap(cmtXMRMaker, cmtXMRTaker,
txHash, receipt, err := s.NewSwap(s.ID(), cmtXMRMaker, cmtXMRTaker,
s.xmrmakerAddress, big.NewInt(int64(s.SwapTimeout().Seconds())), nonce, amount.BigInt())
if err != nil {
return ethcommon.Hash{}, fmt.Errorf("failed to instantiate swap on-chain: %w", err)
@@ -429,7 +430,7 @@ func (s *swapState) lockETH(amount common.EtherAmount) (ethcommon.Hash, error) {
// call Claim(). Ready() should only be called once XMRTaker sees XMRMaker lock his XMR.
// If time t_0 has passed, there is no point of calling Ready().
func (s *swapState) ready() error {
_, _, err := s.SetReady(s.contractSwap)
_, _, err := s.SetReady(s.ID(), s.contractSwap)
if err != nil {
if strings.Contains(err.Error(), revertSwapCompleted) && !s.info.Status().IsOngoing() {
return nil
@@ -452,7 +453,7 @@ func (s *swapState) refund() (ethcommon.Hash, error) {
sc := s.getSecret()
log.Infof("attempting to call Refund()...")
txHash, _, err := s.Refund(s.contractSwap, sc)
txHash, _, err := s.Refund(s.ID(), s.contractSwap, sc)
if err != nil {
return ethcommon.Hash{}, err
}
@@ -475,6 +476,9 @@ func (s *swapState) claimMonero(skB *mcrypto.PrivateSpendKey) (mcrypto.Address,
return "", err
}
s.LockClient()
defer s.UnlockClient()
addr, err := monero.CreateMoneroWallet("xmrtaker-swap-wallet", s.Env(), s.Backend, kpAB)
if err != nil {
return "", err
@@ -485,21 +489,27 @@ func (s *swapState) claimMonero(skB *mcrypto.PrivateSpendKey) (mcrypto.Address,
return addr, nil
}
log.Infof("monero claimed in account %s; transferring to original account %s",
addr, s.XMRDepositAddress())
id := s.ID()
depositAddr, err := s.XMRDepositAddress(&id)
if err != nil {
return "", err
}
err = mcrypto.ValidateAddress(string(s.XMRDepositAddress()))
log.Infof("monero claimed in account %s; transferring to original account %s",
addr, depositAddr)
err = mcrypto.ValidateAddress(string(depositAddr))
if err != nil {
log.Errorf("failed to transfer to original account, address %s is invalid", addr)
return addr, nil
}
err = s.waitUntilBalanceUnlocks()
err = s.waitUntilBalanceUnlocks(depositAddr)
if err != nil {
return "", fmt.Errorf("failed to wait for balance to unlock: %w", err)
}
res, err := s.SweepAll(s.XMRDepositAddress(), 0)
res, err := s.SweepAll(depositAddr, 0)
if err != nil {
return "", fmt.Errorf("failed to send funds to original account: %w", err)
}
@@ -511,14 +521,14 @@ func (s *swapState) claimMonero(skB *mcrypto.PrivateSpendKey) (mcrypto.Address,
amount := res.AmountList[0]
log.Infof("transferred %v XMR to %s",
common.MoneroAmount(amount).AsMonero(),
s.XMRDepositAddress(),
depositAddr,
)
close(s.claimedCh)
return addr, nil
}
func (s *swapState) waitUntilBalanceUnlocks() error {
func (s *swapState) waitUntilBalanceUnlocks(depositAddr mcrypto.Address) error {
for {
if s.ctx.Err() != nil {
return s.ctx.Err()
@@ -527,7 +537,7 @@ func (s *swapState) waitUntilBalanceUnlocks() error {
log.Infof("checking if balance unlocked...")
if s.Env() == common.Development {
_ = s.GenerateBlocks(string(s.XMRDepositAddress()), 64)
_ = s.GenerateBlocks(string(depositAddr), 64)
_ = s.Refresh()
}

View File

@@ -34,7 +34,7 @@ type mockNet struct {
msg net.Message
}
func (n *mockNet) SendSwapMessage(msg net.Message) error {
func (n *mockNet) SendSwapMessage(msg net.Message, _ types.Hash) error {
n.msg = msg
return nil
}
@@ -103,7 +103,7 @@ func newXMRMakerBackend(t *testing.T) backend.Backend {
func newTestInstance(t *testing.T) *swapState {
b := newBackend(t)
swapState, err := newSwapState(b, infofile, false,
swapState, err := newSwapState(b, types.Hash{}, infofile, false,
common.NewEtherAmount(1), common.MoneroAmount(0), 1)
require.NoError(t, err)
return swapState

View File

@@ -21,8 +21,8 @@ type Net interface {
Advertise()
Discover(provides types.ProvidesCoin, searchTime time.Duration) ([]peer.AddrInfo, error)
Query(who peer.AddrInfo) (*net.QueryResponse, error)
Initiate(who peer.AddrInfo, msg *net.SendKeysMessage, s common.SwapState) error
CloseProtocolStream()
Initiate(who peer.AddrInfo, msg *net.SendKeysMessage, s common.SwapStateNet) error
CloseProtocolStream(types.Hash)
}
// NetService is the RPC service prefixed by net_.
@@ -106,26 +106,25 @@ func (s *NetService) QueryPeer(_ *http.Request, req *rpctypes.QueryPeerRequest,
// TakeOffer initiates a swap with the given peer by taking an offer they've made.
func (s *NetService) TakeOffer(_ *http.Request, req *rpctypes.TakeOfferRequest,
resp *rpctypes.TakeOfferResponse) error {
id, _, infofile, err := s.takeOffer(req.Multiaddr, req.OfferID, req.ProvidesAmount)
_, infofile, err := s.takeOffer(req.Multiaddr, req.OfferID, req.ProvidesAmount)
if err != nil {
return err
}
resp.ID = id
resp.InfoFile = infofile
return nil
}
func (s *NetService) takeOffer(multiaddr, offerID string,
providesAmount float64) (uint64, <-chan types.Status, string, error) {
providesAmount float64) (<-chan types.Status, string, error) {
who, err := net.StringToAddrInfo(multiaddr)
if err != nil {
return 0, nil, "", err
return nil, "", err
}
queryResp, err := s.net.Query(who)
if err != nil {
return 0, nil, "", err
return nil, "", err
}
var (
@@ -141,17 +140,17 @@ func (s *NetService) takeOffer(multiaddr, offerID string,
}
if !found {
return 0, nil, "", errNoOfferWithID
return nil, "", errNoOfferWithID
}
swapState, err := s.xmrtaker.InitiateProtocol(providesAmount, offer)
if err != nil {
return 0, nil, "", err
return nil, "", err
}
skm, err := swapState.SendKeysMessage()
if err != nil {
return 0, nil, "", err
return nil, "", err
}
skm.OfferID = offerID
@@ -159,20 +158,24 @@ func (s *NetService) takeOffer(multiaddr, offerID string,
if err = s.net.Initiate(who, skm, swapState); err != nil {
_ = swapState.Exit()
return 0, nil, "", err
return nil, "", err
}
info := s.sm.GetOngoingSwap()
id, err := offerIDStringToHash(offerID)
if err != nil {
return nil, "", err
}
info := s.sm.GetOngoingSwap(id)
if info == nil {
return 0, nil, "", errFailedToGetSwapInfo
return nil, "", errFailedToGetSwapInfo
}
return swapState.ID(), info.StatusCh(), swapState.InfoFile(), nil
return info.StatusCh(), swapState.InfoFile(), nil
}
// TakeOfferSyncResponse ...
type TakeOfferSyncResponse struct {
ID uint64 `json:"id"`
InfoFile string `json:"infoFile"`
Status string `json:"status"`
}
@@ -181,12 +184,16 @@ type TakeOfferSyncResponse struct {
// It synchronously waits until the swap is completed before returning its status.
func (s *NetService) TakeOfferSync(_ *http.Request, req *rpctypes.TakeOfferRequest,
resp *TakeOfferSyncResponse) error {
id, _, infofile, err := s.takeOffer(req.Multiaddr, req.OfferID, req.ProvidesAmount)
offerID, err := offerIDStringToHash(req.OfferID)
if err != nil {
return err
}
_, infofile, err := s.takeOffer(req.Multiaddr, req.OfferID, req.ProvidesAmount)
if err != nil {
return err
}
resp.ID = id
resp.InfoFile = infofile
const checkSwapSleepDuration = time.Millisecond * 100
@@ -194,7 +201,7 @@ func (s *NetService) TakeOfferSync(_ *http.Request, req *rpctypes.TakeOfferReque
for {
time.Sleep(checkSwapSleepDuration)
info := s.sm.GetPastSwap(resp.ID)
info := s.sm.GetPastSwap(offerID)
if info == nil {
continue
}

View File

@@ -4,7 +4,6 @@ import (
"testing"
"github.com/noot/atomic-swap/common/rpctypes"
"github.com/noot/atomic-swap/common/types"
"github.com/stretchr/testify/require"
)
@@ -40,11 +39,9 @@ func TestNet_Query(t *testing.T) {
func TestNet_TakeOffer(t *testing.T) {
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager))
offer := &types.Offer{}
req := &rpctypes.TakeOfferRequest{
Multiaddr: "/ip4/127.0.0.1/tcp/9900/p2p/12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",
OfferID: offer.GetID().String(),
OfferID: testSwapID.String(),
ProvidesAmount: 1,
}
@@ -52,17 +49,14 @@ func TestNet_TakeOffer(t *testing.T) {
err := ns.TakeOffer(nil, req, resp)
require.NoError(t, err)
require.Equal(t, testSwapID, resp.ID)
}
func TestNet_TakeOfferSync(t *testing.T) {
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager))
offer := &types.Offer{}
req := &rpctypes.TakeOfferRequest{
Multiaddr: "/ip4/127.0.0.1/tcp/9900/p2p/12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",
OfferID: offer.GetID().String(),
OfferID: testSwapID.String(),
ProvidesAmount: 1,
}
@@ -70,5 +64,4 @@ func TestNet_TakeOfferSync(t *testing.T) {
err := ns.TakeOfferSync(nil, req, resp)
require.NoError(t, err)
require.Equal(t, testSwapID, resp.ID)
}

View File

@@ -109,7 +109,7 @@ func (s *Server) Start() <-chan error {
// Protocol represents the functions required by the rpc service into the protocol handler.
type Protocol interface {
Provides() types.ProvidesCoin
GetOngoingSwapState() common.SwapState
GetOngoingSwapState(types.Hash) common.SwapState
}
// ProtocolBackend represents protocol/backend.Backend
@@ -119,14 +119,14 @@ type ProtocolBackend interface {
SwapManager() swap.Manager
ExternalSender() *txsender.ExternalSender
SetEthAddress(ethcommon.Address)
SetXMRDepositAddress(mcrypto.Address)
SetXMRDepositAddress(mcrypto.Address, types.Hash)
}
// XMRTaker ...
type XMRTaker interface {
Protocol
InitiateProtocol(providesAmount float64, offer *types.Offer) (common.SwapState, error)
Refund() (ethcommon.Hash, error)
Refund(types.Hash) (ethcommon.Hash, error)
}
// XMRMaker ...

View File

@@ -1,6 +1,7 @@
package rpc
import (
"encoding/hex"
"fmt"
"net/http"
@@ -28,18 +29,22 @@ func NewSwapService(sm SwapManager, xmrtaker XMRTaker, xmrmaker XMRMaker, net Ne
// GetPastIDsResponse ...
type GetPastIDsResponse struct {
IDs []uint64 `json:"ids"`
IDs []string `json:"ids"`
}
// GetPastIDs returns all past swap IDs
func (s *SwapService) GetPastIDs(_ *http.Request, _ *interface{}, resp *GetPastIDsResponse) error {
resp.IDs = s.sm.GetPastIDs()
ids := s.sm.GetPastIDs()
resp.IDs = make([]string, len(ids))
for i := range resp.IDs {
resp.IDs[i] = ids[i].String()
}
return nil
}
// GetPastRequest ...
type GetPastRequest struct {
ID uint64 `json:"id"`
OfferID string `json:"offerID"`
}
// GetPastResponse ...
@@ -53,7 +58,12 @@ type GetPastResponse struct {
// GetPast returns information about a past swap, given its ID.
func (s *SwapService) GetPast(_ *http.Request, req *GetPastRequest, resp *GetPastResponse) error {
info := s.sm.GetPastSwap(req.ID)
offerID, err := offerIDStringToHash(req.OfferID)
if err != nil {
return err
}
info := s.sm.GetPastSwap(offerID)
if info == nil {
return errNoSwapWithID
}
@@ -68,7 +78,6 @@ func (s *SwapService) GetPast(_ *http.Request, req *GetPastRequest, resp *GetPas
// GetOngoingResponse ...
type GetOngoingResponse struct {
ID uint64 `json:"id"`
Provided types.ProvidesCoin `json:"provided"`
ProvidedAmount float64 `json:"providedAmount"`
ReceivedAmount float64 `json:"receivedAmount"`
@@ -76,14 +85,23 @@ type GetOngoingResponse struct {
Status string `json:"status"`
}
// GetOngoingRequest ...
type GetOngoingRequest struct {
OfferID string `json:"id"`
}
// GetOngoing returns information about the ongoing swap, if there is one.
func (s *SwapService) GetOngoing(_ *http.Request, _ *interface{}, resp *GetOngoingResponse) error {
info := s.sm.GetOngoingSwap()
func (s *SwapService) GetOngoing(_ *http.Request, req *GetOngoingRequest, resp *GetOngoingResponse) error {
offerID, err := offerIDStringToHash(req.OfferID)
if err != nil {
return err
}
info := s.sm.GetOngoingSwap(offerID)
if info == nil {
return errNoOngoingSwap
}
resp.ID = info.ID()
resp.Provided = info.Provides()
resp.ProvidedAmount = info.ProvidedAmount()
resp.ReceivedAmount = info.ReceivedAmount()
@@ -92,6 +110,11 @@ func (s *SwapService) GetOngoing(_ *http.Request, _ *interface{}, resp *GetOngoi
return nil
}
// RefundRequest ...
type RefundRequest struct {
OfferID string `json:"id"`
}
// RefundResponse ...
type RefundResponse struct {
TxHash string `json:"transactionHash"`
@@ -99,8 +122,13 @@ type RefundResponse struct {
// Refund refunds the ongoing swap if we are the ETH provider.
// TODO: remove in favour of swap_cancel?
func (s *SwapService) Refund(_ *http.Request, _ *interface{}, resp *RefundResponse) error {
info := s.sm.GetOngoingSwap()
func (s *SwapService) Refund(_ *http.Request, req *RefundRequest, resp *RefundResponse) error {
offerID, err := offerIDStringToHash(req.OfferID)
if err != nil {
return err
}
info := s.sm.GetOngoingSwap(offerID)
if info == nil {
return errNoOngoingSwap
}
@@ -109,7 +137,7 @@ func (s *SwapService) Refund(_ *http.Request, _ *interface{}, resp *RefundRespon
return errCannotRefund
}
txHash, err := s.xmrtaker.Refund()
txHash, err := s.xmrtaker.Refund(offerID)
if err != nil {
return fmt.Errorf("failed to refund: %w", err)
}
@@ -118,6 +146,11 @@ func (s *SwapService) Refund(_ *http.Request, _ *interface{}, resp *RefundRespon
return nil
}
// GetStageRequest ...
type GetStageRequest struct {
OfferID string `json:"id"`
}
// GetStageResponse ...
type GetStageResponse struct {
Stage string `json:"stage"`
@@ -125,8 +158,13 @@ type GetStageResponse struct {
}
// GetStage returns the stage of the ongoing swap, if there is one.
func (s *SwapService) GetStage(_ *http.Request, _ *interface{}, resp *GetStageResponse) error {
info := s.sm.GetOngoingSwap()
func (s *SwapService) GetStage(_ *http.Request, req *GetStageRequest, resp *GetStageResponse) error {
offerID, err := offerIDStringToHash(req.OfferID)
if err != nil {
return err
}
info := s.sm.GetOngoingSwap(offerID)
if info == nil {
return errNoOngoingSwap
}
@@ -147,14 +185,24 @@ func (s *SwapService) GetOffers(_ *http.Request, _ *interface{}, resp *GetOffers
return nil
}
// CancelRequest ...
type CancelRequest struct {
OfferID string `json:"id"`
}
// CancelResponse ...
type CancelResponse struct {
Status types.Status `json:"status"`
}
// Cancel attempts to cancel the currently ongoing swap, if there is one.
func (s *SwapService) Cancel(_ *http.Request, _ *interface{}, resp *CancelResponse) error {
info := s.sm.GetOngoingSwap()
func (s *SwapService) Cancel(_ *http.Request, req *CancelRequest, resp *CancelResponse) error {
offerID, err := offerIDStringToHash(req.OfferID)
if err != nil {
return err
}
info := s.sm.GetOngoingSwap(offerID)
if info == nil {
return errNoOngoingSwap
}
@@ -162,17 +210,27 @@ func (s *SwapService) Cancel(_ *http.Request, _ *interface{}, resp *CancelRespon
var ss common.SwapState
switch info.Provides() {
case types.ProvidesETH:
ss = s.xmrtaker.GetOngoingSwapState()
ss = s.xmrtaker.GetOngoingSwapState(offerID)
case types.ProvidesXMR:
ss = s.xmrmaker.GetOngoingSwapState()
ss = s.xmrmaker.GetOngoingSwapState(offerID)
}
if err := ss.Exit(); err != nil {
return err
}
s.net.CloseProtocolStream()
s.net.CloseProtocolStream(offerID)
info = s.sm.GetPastSwap(info.ID())
resp.Status = info.Status()
return nil
}
func offerIDStringToHash(s string) (types.Hash, error) {
offerIDBytes, err := hex.DecodeString(s)
if err != nil {
return types.Hash{}, err
}
var offerID types.Hash
copy(offerID[:], offerIDBytes)
return offerID, nil
}

View File

@@ -32,12 +32,11 @@ func checkOriginFunc(r *http.Request) bool {
}
type wsServer struct {
ctx context.Context
sm SwapManager
ns *NetService
backend ProtocolBackend
txsOutCh <-chan *txsender.Transaction
txsInCh chan<- ethcommon.Hash
ctx context.Context
sm SwapManager
ns *NetService
backend ProtocolBackend
signer *txsender.ExternalSender
}
func newWsServer(ctx context.Context, sm SwapManager, ns *NetService, backend ProtocolBackend,
@@ -47,11 +46,7 @@ func newWsServer(ctx context.Context, sm SwapManager, ns *NetService, backend Pr
sm: sm,
ns: ns,
backend: backend,
}
if signer != nil {
s.txsOutCh = signer.OngoingCh()
s.txsInCh = signer.IncomingCh()
signer: signer,
}
return s
@@ -139,12 +134,12 @@ func (s *wsServer) handleRequest(conn *websocket.Conn, req *rpctypes.Request) er
return fmt.Errorf("failed to unmarshal parameters: %w", err)
}
id, ch, infofile, err := s.ns.takeOffer(params.Multiaddr, params.OfferID, params.ProvidesAmount)
ch, infofile, err := s.ns.takeOffer(params.Multiaddr, params.OfferID, params.ProvidesAmount)
if err != nil {
return err
}
return s.subscribeTakeOffer(s.ctx, conn, id, ch, infofile)
return s.subscribeTakeOffer(s.ctx, conn, ch, infofile)
case subscribeMakeOffer:
var params *rpctypes.MakeOfferRequest
if err := json.Unmarshal(req.Params, &params); err != nil {
@@ -163,8 +158,9 @@ func (s *wsServer) handleRequest(conn *websocket.Conn, req *rpctypes.Request) er
}
}
func (s *wsServer) handleSigner(ctx context.Context, conn *websocket.Conn, offerID, ethAddress, xmrAddr string) error {
if s.txsOutCh == nil {
func (s *wsServer) handleSigner(ctx context.Context, conn *websocket.Conn, offerIDStr, ethAddress,
xmrAddr string) error {
if s.signer == nil {
return errSignerNotRequired
}
@@ -173,16 +169,35 @@ func (s *wsServer) handleSigner(ctx context.Context, conn *websocket.Conn, offer
}
s.backend.SetEthAddress(ethcommon.HexToAddress(ethAddress))
s.backend.SetXMRDepositAddress(mcrypto.Address(xmrAddr))
offerID, err := offerIDStringToHash(offerIDStr)
if err != nil {
return err
}
s.backend.SetXMRDepositAddress(mcrypto.Address(xmrAddr), offerID)
s.signer.AddID(offerID)
defer s.signer.DeleteID(offerID)
txsOutCh, err := s.signer.OngoingCh(offerID)
if err != nil {
return err
}
txsInCh, err := s.signer.IncomingCh(offerID)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
return nil
case tx := <-s.txsOutCh:
case tx := <-txsOutCh:
log.Debugf("outbound tx: %v", tx)
resp := &rpctypes.SignerResponse{
OfferID: offerID,
OfferID: offerIDStr,
To: tx.To.String(),
Data: tx.Data,
Value: tx.Value,
@@ -203,19 +218,18 @@ func (s *wsServer) handleSigner(ctx context.Context, conn *websocket.Conn, offer
return fmt.Errorf("failed to unmarshal parameters: %w", err)
}
if params.OfferID != offerID {
if params.OfferID != offerIDStr {
return fmt.Errorf("got unexpected offerID %s, expected %s", params.OfferID, offerID)
}
s.txsInCh <- ethcommon.HexToHash(params.TxHash)
txsInCh <- ethcommon.HexToHash(params.TxHash)
}
}
}
func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn,
id uint64, statusCh <-chan types.Status, infofile string) error {
statusCh <-chan types.Status, infofile string) error {
resp := &rpctypes.TakeOfferResponse{
ID: id,
InfoFile: infofile,
}
@@ -258,30 +272,6 @@ func (s *wsServer) subscribeMakeOffer(ctx context.Context, conn *websocket.Conn,
return err
}
// then check for swap ID to be sent when swap is initiated
var taken bool
for {
if taken {
break
}
select {
case id := <-offerExtra.IDCh:
idMsg := map[string]uint64{
"id": id,
}
if err := writeResponse(conn, idMsg); err != nil {
return err
}
taken = true
case <-ctx.Done():
return nil
}
}
// finally, read the swap's status
for {
select {
case status, ok := <-offerExtra.StatusCh:
@@ -309,8 +299,8 @@ func (s *wsServer) subscribeMakeOffer(ctx context.Context, conn *websocket.Conn,
// subscribeSwapStatus writes the swap's stage to the connection every time it updates.
// when the swap completes, it writes the final status then closes the connection.
// example: `{"jsonrpc":"2.0", "method":"swap_subscribeStatus", "params": {"id": 0}, "id": 0}`
func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn, id uint64) error {
info := s.sm.GetOngoingSwap()
func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn, id types.Hash) error {
info := s.sm.GetOngoingSwap(id)
if info == nil {
return s.writeSwapExitStatus(conn, id)
}
@@ -340,7 +330,7 @@ func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn
}
}
func (s *wsServer) writeSwapExitStatus(conn *websocket.Conn, id uint64) error {
func (s *wsServer) writeSwapExitStatus(conn *websocket.Conn, id types.Hash) error {
info := s.sm.GetPastSwap(id)
if info == nil {
return errNoSwapWithID

View File

@@ -22,11 +22,11 @@ import (
)
const (
testSwapID uint64 = 77
testMultiaddr = "/ip4/192.168.0.102/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2"
testMultiaddr = "/ip4/192.168.0.102/tcp/9933/p2p/12D3KooWAYn1T8Lu122Pav4zAogjpeU61usLTNZpLRNh9gCqY6X2"
)
var (
testSwapID = types.Hash{99}
testTImeout = time.Second * 5
defaultRPCPort uint16 = 3001
defaultWSPort uint16 = 4002
@@ -48,28 +48,29 @@ func (*mockNet) Discover(provides types.ProvidesCoin, searchTime time.Duration)
func (*mockNet) Query(who peer.AddrInfo) (*net.QueryResponse, error) {
return &net.QueryResponse{
Offers: []*types.Offer{
{},
{ID: testSwapID},
},
}, nil
}
func (*mockNet) Initiate(who peer.AddrInfo, msg *net.SendKeysMessage, s common.SwapState) error {
func (*mockNet) Initiate(who peer.AddrInfo, msg *net.SendKeysMessage, s common.SwapStateNet) error {
return nil
}
func (*mockNet) CloseProtocolStream() {}
func (*mockNet) CloseProtocolStream(types.Hash) {}
type mockSwapManager struct{}
func (*mockSwapManager) GetPastIDs() []uint64 {
return []uint64{}
func (*mockSwapManager) GetPastIDs() []types.Hash {
return []types.Hash{}
}
func (*mockSwapManager) GetPastSwap(id uint64) *swap.Info {
func (*mockSwapManager) GetPastSwap(id types.Hash) *swap.Info {
return &swap.Info{}
}
func (*mockSwapManager) GetOngoingSwap() *swap.Info {
func (*mockSwapManager) GetOngoingSwap(id types.Hash) *swap.Info {
statusCh := make(chan types.Status, 1)
statusCh <- types.CompletedSuccess
return swap.NewInfo(
id,
types.ProvidesETH,
1,
1,
@@ -81,7 +82,7 @@ func (*mockSwapManager) GetOngoingSwap() *swap.Info {
func (*mockSwapManager) AddSwap(*swap.Info) error {
return nil
}
func (*mockSwapManager) CompleteOngoingSwap() {}
func (*mockSwapManager) CompleteOngoingSwap(types.Hash) {}
type mockXMRTaker struct{}
@@ -89,13 +90,13 @@ func (*mockXMRTaker) Provides() types.ProvidesCoin {
return types.ProvidesETH
}
func (*mockXMRTaker) SetGasPrice(gasPrice uint64) {}
func (*mockXMRTaker) GetOngoingSwapState() common.SwapState {
func (*mockXMRTaker) GetOngoingSwapState(types.Hash) common.SwapState {
return new(mockSwapState)
}
func (*mockXMRTaker) InitiateProtocol(providesAmount float64, _ *types.Offer) (common.SwapState, error) {
return new(mockSwapState), nil
}
func (*mockXMRTaker) Refund() (ethcommon.Hash, error) {
func (*mockXMRTaker) Refund(types.Hash) (ethcommon.Hash, error) {
return ethcommon.Hash{}, nil
}
func (*mockXMRTaker) SetSwapTimeout(_ time.Duration) {}
@@ -111,7 +112,7 @@ func (*mockSwapState) Exit() error {
func (*mockSwapState) SendKeysMessage() (*message.SendKeysMessage, error) {
return &message.SendKeysMessage{}, nil
}
func (*mockSwapState) ID() uint64 {
func (*mockSwapState) ID() types.Hash {
return testSwapID
}
func (*mockSwapState) InfoFile() string {
@@ -136,8 +137,8 @@ func (b *mockProtocolBackend) SwapManager() swap.Manager {
func (*mockProtocolBackend) ExternalSender() *txsender.ExternalSender {
return nil
}
func (*mockProtocolBackend) SetEthAddress(ethcommon.Address) {}
func (*mockProtocolBackend) SetXMRDepositAddress(mcrypto.Address) {}
func (*mockProtocolBackend) SetEthAddress(ethcommon.Address) {}
func (*mockProtocolBackend) SetXMRDepositAddress(mcrypto.Address, types.Hash) {}
func newServer(t *testing.T) *Server {
ctx, cancel := context.WithCancel(context.Background())
@@ -222,11 +223,9 @@ func TestSubscribeTakeOffer(t *testing.T) {
c, err := wsclient.NewWsClient(ctx, defaultWSEndpoint())
require.NoError(t, err)
offerID := (&types.Offer{}).GetID()
id, ch, err := c.TakeOfferAndSubscribe(testMultiaddr, offerID.String(), 1)
ch, err := c.TakeOfferAndSubscribe(testMultiaddr, testSwapID.String(), 1)
require.NoError(t, err)
require.Equal(t, id, testSwapID)
select {
case status := <-ch:
require.Equal(t, types.CompletedSuccess, status)

View File

@@ -9,12 +9,21 @@ import (
)
// Cancel calls swap_cancel.
func (c *Client) Cancel() (types.Status, error) {
func (c *Client) Cancel(id string) (types.Status, error) {
const (
method = "swap_cancel"
)
resp, err := rpctypes.PostRPC(c.endpoint, method, "{}")
req := &rpc.CancelRequest{
OfferID: id,
}
params, err := json.Marshal(req)
if err != nil {
return 0, err
}
resp, err := rpctypes.PostRPC(c.endpoint, method, string(params))
if err != nil {
return 0, err
}

View File

@@ -9,7 +9,7 @@ import (
)
// GetPastSwapIDs calls swap_getPastIDs
func (c *Client) GetPastSwapIDs() ([]uint64, error) {
func (c *Client) GetPastSwapIDs() ([]string, error) {
const (
method = "swap_getPastIDs"
)
@@ -32,12 +32,21 @@ func (c *Client) GetPastSwapIDs() ([]uint64, error) {
}
// GetOngoingSwap calls swap_getOngoing
func (c *Client) GetOngoingSwap() (*rpc.GetOngoingResponse, error) {
func (c *Client) GetOngoingSwap(id string) (*rpc.GetOngoingResponse, error) {
const (
method = "swap_getOngoing"
)
resp, err := rpctypes.PostRPC(c.endpoint, method, "{}")
req := &rpc.GetOngoingRequest{
OfferID: id,
}
params, err := json.Marshal(req)
if err != nil {
return nil, err
}
resp, err := rpctypes.PostRPC(c.endpoint, method, string(params))
if err != nil {
return nil, err
}
@@ -55,13 +64,13 @@ func (c *Client) GetOngoingSwap() (*rpc.GetOngoingResponse, error) {
}
// GetPastSwap calls swap_getPast
func (c *Client) GetPastSwap(id uint64) (*rpc.GetPastResponse, error) {
func (c *Client) GetPastSwap(id string) (*rpc.GetPastResponse, error) {
const (
method = "swap_getPast"
)
req := &rpc.GetPastRequest{
ID: id,
OfferID: id,
}
params, err := json.Marshal(req)
@@ -87,12 +96,21 @@ func (c *Client) GetPastSwap(id uint64) (*rpc.GetPastResponse, error) {
}
// Refund calls swap_refund
func (c *Client) Refund() (*rpc.RefundResponse, error) {
func (c *Client) Refund(id string) (*rpc.RefundResponse, error) {
const (
method = "swap_refund"
)
resp, err := rpctypes.PostRPC(c.endpoint, method, "{}")
req := &rpc.RefundRequest{
OfferID: id,
}
params, err := json.Marshal(req)
if err != nil {
return nil, err
}
resp, err := rpctypes.PostRPC(c.endpoint, method, string(params))
if err != nil {
return nil, err
}
@@ -110,12 +128,21 @@ func (c *Client) Refund() (*rpc.RefundResponse, error) {
}
// GetStage calls swap_getStage
func (c *Client) GetStage() (*rpc.GetStageResponse, error) {
func (c *Client) GetStage(id string) (*rpc.GetStageResponse, error) {
const (
method = "swap_getStage"
)
resp, err := rpctypes.PostRPC(c.endpoint, method, "{}")
req := &rpc.GetStageRequest{
OfferID: id,
}
params, err := json.Marshal(req)
if err != nil {
return nil, err
}
resp, err := rpctypes.PostRPC(c.endpoint, method, string(params))
if err != nil {
return nil, err
}

View File

@@ -8,7 +8,7 @@ import (
)
// TakeOffer calls net_takeOffer.
func (c *Client) TakeOffer(maddr string, offerID string, providesAmount float64) (uint64, error) {
func (c *Client) TakeOffer(maddr string, offerID string, providesAmount float64) error {
const (
method = "net_takeOffer"
)
@@ -21,22 +21,22 @@ func (c *Client) TakeOffer(maddr string, offerID string, providesAmount float64)
params, err := json.Marshal(req)
if err != nil {
return 0, err
return err
}
resp, err := rpctypes.PostRPC(c.endpoint, method, string(params))
if err != nil {
return 0, err
return err
}
if resp.Error != nil {
return 0, fmt.Errorf("failed to call %s: %w", method, resp.Error)
return fmt.Errorf("failed to call %s: %w", method, resp.Error)
}
var res *rpctypes.TakeOfferResponse
if err = json.Unmarshal(resp.Result, &res); err != nil {
return 0, err
return err
}
return res.ID, nil
return nil
}

View File

@@ -24,7 +24,7 @@ type WsClient interface {
TakeOfferAndSubscribe(multiaddr, offerID string,
providesAmount float64) (id uint64, ch <-chan types.Status, err error)
MakeOfferAndSubscribe(min, max float64,
exchangeRate types.ExchangeRate) (string, <-chan *MakeOfferTakenResponse, <-chan types.Status, error)
exchangeRate types.ExchangeRate) (string, <-chan types.Status, error)
}
type wsClient struct {
@@ -164,7 +164,7 @@ func (c *wsClient) Query(maddr string) (*rpctypes.QueryPeerResponse, error) {
// SubscribeSwapStatus returns a channel that is written to each time the swap's status updates.
// If there is no swap with the given ID, it returns an error.
func (c *wsClient) SubscribeSwapStatus(id uint64) (<-chan types.Status, error) {
func (c *wsClient) SubscribeSwapStatus(id types.Hash) (<-chan types.Status, error) {
params := &rpctypes.SubscribeSwapStatusRequest{
ID: id,
}
@@ -228,7 +228,7 @@ func (c *wsClient) SubscribeSwapStatus(id uint64) (<-chan types.Status, error) {
}
func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string,
providesAmount float64) (id uint64, ch <-chan types.Status, err error) {
providesAmount float64) (ch <-chan types.Status, err error) {
params := &rpctypes.TakeOfferRequest{
Multiaddr: multiaddr,
OfferID: offerID,
@@ -237,7 +237,7 @@ func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string,
bz, err := json.Marshal(params)
if err != nil {
return 0, nil, err
return nil, err
}
req := &rpctypes.Request{
@@ -248,29 +248,29 @@ func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string,
}
if err = c.writeJSON(req); err != nil {
return 0, nil, err
return nil, err
}
// read ID from connection
message, err := c.read()
if err != nil {
return 0, nil, fmt.Errorf("failed to read websockets message: %s", err)
return nil, fmt.Errorf("failed to read websockets message: %s", err)
}
var resp *rpctypes.Response
err = json.Unmarshal(message, &resp)
if err != nil {
return 0, nil, fmt.Errorf("failed to unmarshal response: %w", err)
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if resp.Error != nil {
return 0, nil, fmt.Errorf("websocket server returned error: %w", resp.Error)
return nil, fmt.Errorf("websocket server returned error: %w", resp.Error)
}
log.Debugf("received message over websockets: %s", message)
var idResp *rpctypes.TakeOfferResponse
if err := json.Unmarshal(resp.Result, &idResp); err != nil {
return 0, nil, fmt.Errorf("failed to unmarshal swap ID response: %s", err)
return nil, fmt.Errorf("failed to unmarshal swap ID response: %s", err)
}
respCh := make(chan types.Status)
@@ -312,16 +312,11 @@ func (c *wsClient) TakeOfferAndSubscribe(multiaddr, offerID string,
}
}()
return idResp.ID, respCh, nil
}
// MakeOfferTakenResponse contains the swap ID
type MakeOfferTakenResponse struct {
ID uint64 `json:"id"`
return respCh, nil
}
func (c *wsClient) MakeOfferAndSubscribe(min, max float64,
exchangeRate types.ExchangeRate) (string, <-chan *MakeOfferTakenResponse, <-chan types.Status, error) {
exchangeRate types.ExchangeRate) (string, <-chan types.Status, error) {
params := &rpctypes.MakeOfferRequest{
MinimumAmount: min,
MaximumAmount: max,
@@ -330,7 +325,7 @@ func (c *wsClient) MakeOfferAndSubscribe(min, max float64,
bz, err := json.Marshal(params)
if err != nil {
return "", nil, nil, err
return "", nil, err
}
req := &rpctypes.Request{
@@ -340,73 +335,36 @@ func (c *wsClient) MakeOfferAndSubscribe(min, max float64,
ID: 0,
}
log.Debug("writing net_makeOfferAndSubscribe")
if err = c.writeJSON(req); err != nil {
return "", nil, nil, err
return "", nil, err
}
log.Debugf("wrote")
// read ID from connection
message, err := c.read()
if err != nil {
return "", nil, nil, fmt.Errorf("failed to read websockets message: %s", err)
return "", nil, fmt.Errorf("failed to read websockets message: %s", err)
}
log.Debugf("got response")
var resp *rpctypes.Response
err = json.Unmarshal(message, &resp)
if err != nil {
return "", nil, nil, fmt.Errorf("failed to unmarshal response: %w", err)
return "", nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if resp.Error != nil {
return "", nil, nil, fmt.Errorf("websocket server returned error: %w", resp.Error)
return "", nil, fmt.Errorf("websocket server returned error: %w", resp.Error)
}
// read synchronous response (offer ID and infofile)
log.Debugf("received message over websockets: %s", message)
var respData *rpctypes.MakeOfferResponse
if err := json.Unmarshal(resp.Result, &respData); err != nil {
return "", nil, nil, fmt.Errorf("failed to unmarshal response: %s", err)
return "", nil, fmt.Errorf("failed to unmarshal response: %s", err)
}
takenCh := make(chan *MakeOfferTakenResponse)
respCh := make(chan types.Status)
go func() {
defer close(respCh)
defer close(takenCh)
// read if swap was taken
message, err := c.read()
if err != nil {
log.Warnf("failed to read websockets message: %s", err)
return
}
var resp *rpctypes.Response
err = json.Unmarshal(message, &resp)
if err != nil {
log.Warnf("failed to unmarshal response: %s", err)
return
}
if resp.Error != nil {
log.Warnf("websocket server returned error: %s", resp.Error)
return
}
log.Debugf("received message over websockets: %s", message)
var taken *MakeOfferTakenResponse
if err := json.Unmarshal(resp.Result, &taken); err != nil {
log.Warnf("failed to unmarshal response: %s", err)
return
}
takenCh <- taken
for {
message, err := c.read()
@@ -442,5 +400,5 @@ func (c *wsClient) MakeOfferAndSubscribe(min, max float64,
}
}()
return respData.ID, takenCh, respCh, nil
return respData.ID, respCh, nil
}

View File

@@ -133,7 +133,7 @@ func TestXMRTaker_Query(t *testing.T) {
}
func TestSuccess(t *testing.T) {
const testTimeout = time.Second * 5
const testTimeout = time.Second * 60
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -141,7 +141,7 @@ func TestSuccess(t *testing.T) {
bwsc, err := wsclient.NewWsClient(ctx, defaultXMRMakerDaemonWSEndpoint)
require.NoError(t, err)
offerID, takenCh, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
offerID, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
types.ExchangeRate(exchangeRate))
require.NoError(t, err)
@@ -149,7 +149,6 @@ func TestSuccess(t *testing.T) {
offersBefore, err := bc.GetOffers()
require.NoError(t, err)
xmrmakerIDCh := make(chan uint64, 1)
errCh := make(chan error, 2)
var wg sync.WaitGroup
@@ -158,26 +157,22 @@ func TestSuccess(t *testing.T) {
go func() {
defer wg.Done()
select {
case taken := <-takenCh:
require.NotNil(t, taken)
t.Log("swap ID:", taken.ID)
xmrmakerIDCh <- taken.ID
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
for {
select {
case status := <-statusCh:
fmt.Println("> XMRMaker got status:", status)
if status.IsOngoing() {
continue
}
for status := range statusCh {
fmt.Println("> XMRMaker got status:", status)
if status.IsOngoing() {
continue
if status != types.CompletedSuccess {
errCh <- fmt.Errorf("swap did not complete successfully: got %s", status)
}
return
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
if status != types.CompletedSuccess {
errCh <- fmt.Errorf("swap did not complete successfully: got %s", status)
}
return
}
}()
@@ -191,7 +186,7 @@ func TestSuccess(t *testing.T) {
require.Equal(t, 1, len(providers))
require.GreaterOrEqual(t, len(providers[0]), 2)
id, takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
require.NoError(t, err)
go func() {
@@ -218,9 +213,6 @@ func TestSuccess(t *testing.T) {
default:
}
xmrmakerSwapID := <-xmrmakerIDCh
require.Equal(t, id, xmrmakerSwapID)
offersAfter, err := bc.GetOffers()
require.NoError(t, err)
require.Equal(t, 1, len(offersBefore)-len(offersAfter))
@@ -232,7 +224,7 @@ func TestRefund_XMRTakerCancels(t *testing.T) {
}
const (
testTimeout = time.Second * 5
testTimeout = time.Second * 60
swapTimeout = 5 // 5s
)
@@ -242,7 +234,7 @@ func TestRefund_XMRTakerCancels(t *testing.T) {
bwsc, err := wsclient.NewWsClient(ctx, defaultXMRMakerDaemonWSEndpoint)
require.NoError(t, err)
offerID, takenCh, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
offerID, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
types.ExchangeRate(exchangeRate))
require.NoError(t, err)
@@ -250,7 +242,6 @@ func TestRefund_XMRTakerCancels(t *testing.T) {
offersBefore, err := bc.GetOffers()
require.NoError(t, err)
xmrmakerIDCh := make(chan uint64, 1)
errCh := make(chan error, 2)
var wg sync.WaitGroup
@@ -259,26 +250,22 @@ func TestRefund_XMRTakerCancels(t *testing.T) {
go func() {
defer wg.Done()
select {
case taken := <-takenCh:
require.NotNil(t, taken)
t.Log("swap ID:", taken.ID)
xmrmakerIDCh <- taken.ID
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
for {
select {
case status := <-statusCh:
fmt.Println("> XMRMaker got status:", status)
if status.IsOngoing() {
continue
}
for status := range statusCh {
fmt.Println("> XMRMaker got status:", status)
if status.IsOngoing() {
continue
if status != types.CompletedRefund {
errCh <- fmt.Errorf("swap did not refund successfully for XMRMaker: got %s", status)
}
return
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
if status != types.CompletedRefund {
errCh <- fmt.Errorf("swap did not refund successfully for XMRMaker: exit status was %s", status)
}
return
}
}()
@@ -294,7 +281,7 @@ func TestRefund_XMRTakerCancels(t *testing.T) {
require.Equal(t, 1, len(providers))
require.GreaterOrEqual(t, len(providers[0]), 2)
id, takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
require.NoError(t, err)
go func() {
@@ -306,7 +293,7 @@ func TestRefund_XMRTakerCancels(t *testing.T) {
}
fmt.Println("> XMRTaker cancelled swap!")
exitStatus, err := c.Cancel() //nolint:govet
exitStatus, err := c.Cancel(offerID) //nolint:govet
if err != nil {
t.Log("XMRTaker got error", err)
errCh <- err
@@ -329,9 +316,6 @@ func TestRefund_XMRTakerCancels(t *testing.T) {
default:
}
xmrmakerSwapID := <-xmrmakerIDCh
require.Equal(t, id, xmrmakerSwapID)
offersAfter, err := bc.GetOffers()
require.NoError(t, err)
require.Equal(t, len(offersBefore), len(offersAfter))
@@ -358,7 +342,7 @@ func testRefundXMRMakerCancels(t *testing.T, swapTimeout uint64, expectedExitSta
generateBlocks(64)
}
const testTimeout = time.Second * 5
const testTimeout = time.Second * 60
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -367,14 +351,13 @@ func testRefundXMRMakerCancels(t *testing.T, swapTimeout uint64, expectedExitSta
bwsc, err := wsclient.NewWsClient(ctx, defaultXMRMakerDaemonWSEndpoint)
require.NoError(t, err)
offerID, takenCh, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
offerID, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
types.ExchangeRate(exchangeRate))
require.NoError(t, err)
offersBefore, err := bcli.GetOffers()
require.NoError(t, err)
xmrmakerIDCh := make(chan uint64, 1)
errCh := make(chan error, 2)
var wg sync.WaitGroup
@@ -383,35 +366,31 @@ func testRefundXMRMakerCancels(t *testing.T, swapTimeout uint64, expectedExitSta
go func() {
defer wg.Done()
select {
case taken := <-takenCh:
require.NotNil(t, taken)
t.Log("swap ID:", taken.ID)
xmrmakerIDCh <- taken.ID
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
for {
select {
case status := <-statusCh:
fmt.Println("> XMRMaker got status:", status)
if status != types.XMRLocked {
continue
}
for status := range statusCh {
fmt.Println("> XMRMaker got status:", status)
if status != types.XMRLocked {
continue
}
fmt.Println("> XMRMaker cancelled swap!")
exitStatus, err := bcli.Cancel(offerID) //nolint:govet
if err != nil {
errCh <- err
return
}
fmt.Println("> XMRMaker cancelled swap!")
exitStatus, err := bcli.Cancel() //nolint:govet
if err != nil {
errCh <- err
if exitStatus != expectedExitStatus {
errCh <- fmt.Errorf("did not get expected exit status for XMRMaker: got %s, expected %s", exitStatus, expectedExitStatus) //nolint:lll
return
}
fmt.Println("> XMRMaker refunded successfully")
return
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
if exitStatus != expectedExitStatus {
errCh <- fmt.Errorf("did not get expected exit status for XMRMaker: got %s, expected %s", exitStatus, expectedExitStatus) //nolint:lll
return
}
fmt.Println("> XMRMaker refunded successfully")
return
}
}()
@@ -427,7 +406,7 @@ func testRefundXMRMakerCancels(t *testing.T, swapTimeout uint64, expectedExitSta
require.Equal(t, 1, len(providers))
require.GreaterOrEqual(t, len(providers[0]), 2)
id, takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
require.NoError(t, err)
go func() {
@@ -457,9 +436,6 @@ func testRefundXMRMakerCancels(t *testing.T, swapTimeout uint64, expectedExitSta
default:
}
xmrmakerSwapID := <-xmrmakerIDCh
require.Equal(t, id, xmrmakerSwapID)
offersAfter, err := bcli.GetOffers()
require.NoError(t, err)
if expectedExitStatus != types.CompletedSuccess {
@@ -476,7 +452,7 @@ func TestAbort_XMRTakerCancels(t *testing.T) {
generateBlocks(64)
}
const testTimeout = time.Second * 5
const testTimeout = time.Second * 60
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -484,7 +460,7 @@ func TestAbort_XMRTakerCancels(t *testing.T) {
bwsc, err := wsclient.NewWsClient(ctx, defaultXMRMakerDaemonWSEndpoint)
require.NoError(t, err)
offerID, takenCh, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
offerID, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
types.ExchangeRate(exchangeRate))
require.NoError(t, err)
@@ -492,7 +468,6 @@ func TestAbort_XMRTakerCancels(t *testing.T) {
offersBefore, err := bc.GetOffers()
require.NoError(t, err)
xmrmakerIDCh := make(chan uint64, 1)
errCh := make(chan error, 2)
var wg sync.WaitGroup
@@ -501,26 +476,22 @@ func TestAbort_XMRTakerCancels(t *testing.T) {
go func() {
defer wg.Done()
select {
case taken := <-takenCh:
require.NotNil(t, taken)
t.Log("swap ID:", taken.ID)
xmrmakerIDCh <- taken.ID
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
for {
select {
case status := <-statusCh:
fmt.Println("> XMRMaker got status:", status)
if status.IsOngoing() {
continue
}
for status := range statusCh {
fmt.Println("> XMRMaker got status:", status)
if status.IsOngoing() {
continue
if status != types.CompletedAbort {
errCh <- fmt.Errorf("swap did not exit successfully for XMRMaker: got %s", status)
}
return
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
if status != types.CompletedAbort {
errCh <- fmt.Errorf("swap did not exit successfully: got %s", status)
}
return
}
}()
@@ -533,7 +504,7 @@ func TestAbort_XMRTakerCancels(t *testing.T) {
require.Equal(t, 1, len(providers))
require.GreaterOrEqual(t, len(providers[0]), 2)
id, takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
require.NoError(t, err)
go func() {
@@ -545,7 +516,7 @@ func TestAbort_XMRTakerCancels(t *testing.T) {
}
fmt.Println("> XMRTaker cancelled swap!")
exitStatus, err := c.Cancel() //nolint:govet
exitStatus, err := c.Cancel(offerID) //nolint:govet
if err != nil {
errCh <- err
return
@@ -567,9 +538,6 @@ func TestAbort_XMRTakerCancels(t *testing.T) {
default:
}
xmrmakerSwapID := <-xmrmakerIDCh
require.Equal(t, id, xmrmakerSwapID)
offersAfter, err := bc.GetOffers()
require.NoError(t, err)
require.Equal(t, len(offersBefore), len(offersAfter))
@@ -584,7 +552,7 @@ func TestAbort_XMRMakerCancels(t *testing.T) {
generateBlocks(64)
}
const testTimeout = time.Second * 5
const testTimeout = time.Second * 60
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -593,7 +561,7 @@ func TestAbort_XMRMakerCancels(t *testing.T) {
bwsc, err := wsclient.NewWsClient(ctx, defaultXMRMakerDaemonWSEndpoint)
require.NoError(t, err)
offerID, takenCh, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
offerID, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
types.ExchangeRate(exchangeRate))
require.NoError(t, err)
@@ -601,7 +569,6 @@ func TestAbort_XMRMakerCancels(t *testing.T) {
offersBefore, err := bc.GetOffers()
require.NoError(t, err)
xmrmakerIDCh := make(chan uint64, 1)
errCh := make(chan error, 2)
var wg sync.WaitGroup
@@ -610,35 +577,31 @@ func TestAbort_XMRMakerCancels(t *testing.T) {
go func() {
defer wg.Done()
select {
case taken := <-takenCh:
require.NotNil(t, taken)
t.Log("swap ID:", taken.ID)
xmrmakerIDCh <- taken.ID
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
for {
select {
case status := <-statusCh:
fmt.Println("> XMRMaker got status:", status)
if status != types.KeysExchanged {
continue
}
for status := range statusCh {
fmt.Println("> XMRMaker got status:", status)
if status != types.KeysExchanged {
continue
}
fmt.Println("> XMRMaker cancelled swap!")
exitStatus, err := bcli.Cancel(offerID) //nolint:govet
if err != nil {
errCh <- err
return
}
fmt.Println("> XMRMaker cancelled swap!")
exitStatus, err := bcli.Cancel() //nolint:govet
if err != nil {
errCh <- err
if exitStatus != types.CompletedAbort {
errCh <- fmt.Errorf("did not abort successfully: exit status was %s", exitStatus)
return
}
fmt.Println("> XMRMaker exited successfully")
return
case <-time.After(testTimeout):
errCh <- errors.New("make offer subscription timed out")
}
if exitStatus != types.CompletedAbort {
errCh <- fmt.Errorf("did not abort successfully: exit status was %s", exitStatus)
return
}
fmt.Println("> XMRMaker exited successfully")
return
}
}()
@@ -651,7 +614,7 @@ func TestAbort_XMRMakerCancels(t *testing.T) {
require.Equal(t, 1, len(providers))
require.GreaterOrEqual(t, len(providers[0]), 2)
id, takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
require.NoError(t, err)
go func() {
@@ -681,9 +644,6 @@ func TestAbort_XMRMakerCancels(t *testing.T) {
default:
}
xmrmakerSwapID := <-xmrmakerIDCh
require.Equal(t, id, xmrmakerSwapID)
offersAfter, err := bc.GetOffers()
require.NoError(t, err)
require.Equal(t, len(offersBefore), len(offersAfter))
@@ -692,7 +652,7 @@ func TestAbort_XMRMakerCancels(t *testing.T) {
// TestError_ShouldOnlyTakeOfferOnce tests the case where two takers try to take the same offer concurrently.
// Only one should succeed, the other should return an error or Abort status.
func TestError_ShouldOnlyTakeOfferOnce(t *testing.T) {
const testTimeout = time.Second * 30
const testTimeout = time.Second * 60
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -714,7 +674,7 @@ func TestError_ShouldOnlyTakeOfferOnce(t *testing.T) {
wsc, err := wsclient.NewWsClient(ctx, defaultXMRTakerDaemonWSEndpoint)
require.NoError(t, err)
_, takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
if err != nil {
errCh <- err
return
@@ -740,7 +700,7 @@ func TestError_ShouldOnlyTakeOfferOnce(t *testing.T) {
wsc, err := wsclient.NewWsClient(ctx, defaultCharlieDaemonWSEndpoint)
require.NoError(t, err)
_, takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
if err != nil {
errCh <- err
return
@@ -776,3 +736,137 @@ func TestError_ShouldOnlyTakeOfferOnce(t *testing.T) {
default:
}
}
func TestSuccess_ConcurrentSwaps(t *testing.T) {
const testTimeout = time.Second * 180
const numConcurrentSwaps = 10
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
type makerTest struct {
offerID string
statusCh <-chan types.Status
errCh chan error
}
makerTests := make([]*makerTest, numConcurrentSwaps)
for i := 0; i < numConcurrentSwaps; i++ {
bwsc, err := wsclient.NewWsClient(ctx, defaultXMRMakerDaemonWSEndpoint)
require.NoError(t, err)
offerID, statusCh, err := bwsc.MakeOfferAndSubscribe(0.1, xmrmakerProvideAmount,
types.ExchangeRate(exchangeRate))
require.NoError(t, err)
fmt.Println("maker made offer ", offerID)
makerTests[i] = &makerTest{
offerID: offerID,
statusCh: statusCh,
errCh: make(chan error, 2),
}
}
bc := rpcclient.NewClient(defaultXMRMakerDaemonEndpoint)
offersBefore, err := bc.GetOffers()
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(2 * numConcurrentSwaps)
for _, tc := range makerTests {
go func(tc *makerTest) {
defer wg.Done()
for {
select {
case status := <-tc.statusCh:
fmt.Println("> XMRMaker got status:", status)
if status.IsOngoing() {
continue
}
if status != types.CompletedSuccess {
tc.errCh <- fmt.Errorf("swap did not complete successfully: got %s", status)
}
return
case <-time.After(testTimeout):
tc.errCh <- errors.New("make offer subscription timed out")
}
}
}(tc)
}
type takerTest struct {
statusCh <-chan types.Status
errCh chan error
}
takerTests := make([]*takerTest, numConcurrentSwaps)
for i := 0; i < numConcurrentSwaps; i++ {
c := rpcclient.NewClient(defaultXMRTakerDaemonEndpoint)
wsc, err := wsclient.NewWsClient(ctx, defaultXMRTakerDaemonWSEndpoint) //nolint:govet
require.NoError(t, err)
// TODO: implement discovery over websockets
providers, err := c.Discover(types.ProvidesXMR, defaultDiscoverTimeout)
require.NoError(t, err)
require.Equal(t, 1, len(providers))
require.GreaterOrEqual(t, len(providers[0]), 2)
offerID := makerTests[i].offerID
takerStatusCh, err := wsc.TakeOfferAndSubscribe(providers[0][0], offerID, 0.05)
require.NoError(t, err)
fmt.Println("taker took offer ", offerID)
takerTests[i] = &takerTest{
statusCh: takerStatusCh,
errCh: make(chan error, 2),
}
}
for _, tc := range takerTests {
go func(tc *takerTest) {
defer wg.Done()
for status := range tc.statusCh {
fmt.Println("> XMRTaker got status:", status)
if status.IsOngoing() {
continue
}
if status != types.CompletedSuccess {
tc.errCh <- fmt.Errorf("swap did not complete successfully: got %s", status)
}
return
}
}(tc)
}
wg.Wait()
for _, tc := range makerTests {
select {
case err = <-tc.errCh:
require.NoError(t, err)
default:
}
}
for _, tc := range takerTests {
select {
case err = <-tc.errCh:
require.NoError(t, err)
default:
}
}
offersAfter, err := bc.GetOffers()
require.NoError(t, err)
require.Equal(t, numConcurrentSwaps, len(offersBefore)-len(offersAfter))
}