ensure that the status channel is always initialized (#471)

This commit is contained in:
Dmitry Holodov
2023-05-20 15:04:26 -05:00
committed by GitHub
parent 53a5c97c4f
commit f41c9af432
35 changed files with 605 additions and 271 deletions

View File

@@ -174,12 +174,6 @@ func (o *Offer) validate() error {
return nil
}
// OfferExtra represents extra data that is passed when an offer is made.
type OfferExtra struct {
StatusCh chan Status `json:"-"`
UseRelayer bool `json:"useRelayer,omitempty"`
}
// UnmarshalOffer deserializes a JSON offer, checking the version for compatibility before
// attempting to deserialize the whole blob.
func UnmarshalOffer(jsonData []byte) (*Offer, error) {
@@ -228,3 +222,19 @@ func (o *Offer) UnmarshalJSON(data []byte) error {
}
return o.validate()
}
// OfferExtra represents extra data that is passed when an offer is made.
type OfferExtra struct {
// UseRelayer forces the XMR maker to claim using the relayer even when he
// has enough funds to make the claim himself. Setting it to false will not
// prevent the relayer from being used if there are insufficient ETH funds
// to claim.
UseRelayer bool `json:"useRelayer,omitempty"`
}
// NewOfferExtra creates an OfferExtra instance
func NewOfferExtra(forceUseRelayer bool) *OfferExtra {
return &OfferExtra{
UseRelayer: forceUseRelayer,
}
}

View File

@@ -4,6 +4,7 @@
package types
import (
"encoding/json"
"fmt"
"testing"
@@ -250,3 +251,17 @@ func TestUnmarshalOffer_VersionTooNew(t *testing.T) {
_, err := UnmarshalOffer([]byte(offerJSON))
require.ErrorContains(t, err, fmt.Sprintf("offer version %q not supported", unsupportedVersion))
}
func TestOfferExtra_JSON(t *testing.T) {
// Marshal test
extra := NewOfferExtra(true)
data, err := vjson.MarshalStruct(extra)
require.NoError(t, err)
require.JSONEq(t, `{"useRelayer":true}`, string(data))
// Unmarshal test
extra = new(OfferExtra)
err = json.Unmarshal(data, extra)
require.NoError(t, err)
require.True(t, extra.UseRelayer)
}

View File

@@ -0,0 +1,141 @@
package daemon
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/athanorlabs/atomic-swap/coins"
"github.com/athanorlabs/atomic-swap/common/types"
"github.com/athanorlabs/atomic-swap/monero"
"github.com/athanorlabs/atomic-swap/rpcclient/wsclient"
"github.com/athanorlabs/atomic-swap/tests"
)
func TestXMRNotLockedAndETHRefundedAfterAliceRestarts(t *testing.T) {
minXMR := coins.StrToDecimal("1")
maxXMR := minXMR
exRate := coins.StrToExchangeRate("0.1")
providesAmt, err := exRate.ToETH(minXMR)
require.NoError(t, err)
bobConf := CreateTestConf(t, tests.GetMakerTestKey(t))
monero.MineMinXMRBalance(t, bobConf.MoneroClient, coins.MoneroToPiconero(maxXMR))
aliceConf := CreateTestConf(t, tests.GetTakerTestKey(t))
timeout := 7 * time.Minute
ctx, cancel := LaunchDaemons(t, timeout, aliceConf, bobConf)
// clients use a separate context and will work across server restarts
clientCtx := context.Background()
bc, err := wsclient.NewWsClient(clientCtx, fmt.Sprintf("ws://127.0.0.1:%d/ws", bobConf.RPCPort))
require.NoError(t, err)
ac, err := wsclient.NewWsClient(clientCtx, fmt.Sprintf("ws://127.0.0.1:%d/ws", aliceConf.RPCPort))
require.NoError(t, err)
// Bob makes an offer
makeResp, bobStatusCh, err := bc.MakeOfferAndSubscribe(minXMR, maxXMR, exRate, types.EthAssetETH, false)
require.NoError(t, err)
// Alice takes the offer
aliceStatusCh, err := ac.TakeOfferAndSubscribe(makeResp.PeerID, makeResp.OfferID, providesAmt)
require.NoError(t, err)
var statusWG sync.WaitGroup
statusWG.Add(2)
// Alice shuts down both servers as soon as she locks her ETH
go func() {
defer statusWG.Done()
for {
select {
case status := <-aliceStatusCh:
t.Log("> Alice got status:", status)
switch status {
case types.ExpectingKeys:
continue
case types.ETHLocked:
cancel() // stop both Alice's and Bob's daemons
default:
cancel()
t.Errorf("Alice should not have reached status=%s", status)
}
case <-ctx.Done():
t.Logf("Alice's context cancelled (expected)")
return
}
}
}()
// Bob is not playing a significant role in this test. His swapd instance is
// shut down before he can lock any XMR and we don't bring it back online.
go func() {
defer statusWG.Done()
for {
select {
case status := <-bobStatusCh:
t.Log("> Bob got status:", status)
switch status {
case types.KeysExchanged:
continue
default:
cancel()
t.Errorf("Bob should not have reached status=%s", status)
}
case <-ctx.Done():
t.Logf("Bob's context cancelled (expected)")
return
}
}
}()
statusWG.Wait()
if t.Failed() {
return
}
// Make sure both servers had time to fully shut down
time.Sleep(3 * time.Second)
// relaunch Alice's daemon
t.Logf("daemons stopped, now re-launching Alice's daemon in isolation")
ctx, cancel = LaunchDaemons(t, 3*time.Minute, aliceConf)
// This is a bug that we need to recreate Alice's websocket client here. Remove this
// code when we fix https://github.com/AthanorLabs/atomic-swap/issues/353.
ac, err = wsclient.NewWsClient(clientCtx, fmt.Sprintf("ws://127.0.0.1:%d/ws", aliceConf.RPCPort))
require.NoError(t, err)
aliceStatusCh, err = ac.SubscribeSwapStatus(makeResp.OfferID)
require.NoError(t, err)
// Ensure Alice completes the swap with a refund
statusWG.Add(1)
go func() {
defer statusWG.Done()
for {
select {
case status := <-aliceStatusCh:
t.Log("> Alice, after restart, got status:", status)
if !status.IsOngoing() {
assert.Equal(t, types.CompletedRefund.String(), status.String())
return
}
case <-ctx.Done():
// Alice's context has a deadline. If we get here, the context
// expired before we got any Refund status update.
t.Errorf("Alice's context cancelled before she completed the swap")
return
}
}
}()
statusWG.Wait()
// TODO: Add some additional checks here when the rest of the test is working
}

View File

@@ -11,6 +11,7 @@ import (
"fmt"
"math/big"
"net"
"path"
"sync"
"syscall"
"testing"
@@ -87,13 +88,15 @@ func CreateTestBootnode(t *testing.T) (uint16, string) {
wg.Wait()
})
dataDir := t.TempDir()
conf := &bootnode.Config{
Env: common.Development,
DataDir: t.TempDir(),
Bootnodes: nil,
P2PListenIP: "127.0.0.1",
Libp2pPort: 0,
Libp2pKeyFile: common.DefaultLibp2pKeyFileName,
Libp2pKeyFile: path.Join(dataDir, common.DefaultLibp2pKeyFileName),
RPCPort: uint16(rpcPort),
}

View File

@@ -331,11 +331,10 @@ func TestDatabase_SwapTable_Update(t *testing.T) {
// infoB mostly the same as infoA (same ID, importantly), but with
// a couple updated fields.
infoB := new(swap.Info)
*infoB = *infoA
infoB, err := infoA.DeepCopy()
require.NoError(t, err)
infoB.Status = types.CompletedSuccess
endTime := time.Now()
infoB.EndTime = &endTime
infoB.MarkSwapComplete()
err = db.PutSwap(infoB)
require.NoError(t, err)

View File

@@ -83,16 +83,14 @@ func TestRecoveryDB_SwapRelayerInfo(t *testing.T) {
rdb := newTestRecoveryDB(t)
offerID := types.Hash{5, 6, 7, 8}
info := &types.OfferExtra{
UseRelayer: true,
}
extra := types.NewOfferExtra(true)
err := rdb.PutSwapRelayerInfo(offerID, info)
err := rdb.PutSwapRelayerInfo(offerID, extra)
require.NoError(t, err)
res, err := rdb.GetSwapRelayerInfo(offerID)
require.NoError(t, err)
require.Equal(t, info, res)
require.Equal(t, extra, res)
}
func TestRecoveryDB_SwapPrivateKey(t *testing.T) {
@@ -165,13 +163,11 @@ func TestRecoveryDB_DeleteSwap(t *testing.T) {
SwapCreatorAddr: ethcommon.HexToAddress("0xd2b5d6252d0645e4cf4bb547e82a485f527befb7"),
}
info := &types.OfferExtra{
UseRelayer: true,
}
extra := types.NewOfferExtra(true)
err = rdb.PutContractSwapInfo(offerID, si)
require.NoError(t, err)
err = rdb.PutSwapRelayerInfo(offerID, info)
err = rdb.PutSwapRelayerInfo(offerID, extra)
require.NoError(t, err)
err = rdb.PutSwapPrivateKey(offerID, kp.SpendKey())
require.NoError(t, err)

View File

@@ -262,7 +262,7 @@ func (b *backend) ClearXMRDepositAddress(offerID types.Hash) {
// HasOngoingSwapAsTaker returns nil if we have an ongoing swap with the given peer where
// we're the xmrtaker, otherwise returns an error.
func (b *backend) HasOngoingSwapAsTaker(remotePeer peer.ID) error {
swaps, err := b.swapManager.GetOngoingSwaps()
swaps, err := b.swapManager.GetOngoingSwapsSnapshot()
if err != nil {
return err
}
@@ -293,7 +293,7 @@ func (b *backend) HandleRelayClaimRequest(
return nil, fmt.Errorf("cannot relay taker-specific claim request; no ongoing swap for swap %s", *request.OfferID)
}
info, err := b.swapManager.GetOngoingSwap(*request.OfferID)
info, err := b.swapManager.GetOngoingSwapSnapshot(*request.OfferID)
if err != nil {
return nil, err
}

View File

@@ -24,6 +24,7 @@ var (
// SwapManager is the subset of the swap.Manager interface needed by ClaimMonero
type SwapManager interface {
WriteSwapToDB(info *swap.Info) error
PushNewStatus(offerID types.Hash, status types.Status)
}
// GetClaimKeypair returns the private key pair required for a monero claim.
@@ -103,6 +104,7 @@ func ClaimMonero(
// setSweepStatus sets the swap's status as `SweepingXMR` and writes it to the db.
func setSweepStatus(info *swap.Info, sm SwapManager) error {
info.SetStatus(types.SweepingXMR)
sm.PushNewStatus(info.OfferID, types.SweepingXMR)
err := sm.WriteSwapToDB(info)
if err != nil {
return fmt.Errorf("failed to write swap to db: %w", err)

View File

@@ -10,6 +10,7 @@ import (
"github.com/athanorlabs/atomic-swap/coins"
"github.com/athanorlabs/atomic-swap/common"
"github.com/athanorlabs/atomic-swap/common/types"
mcrypto "github.com/athanorlabs/atomic-swap/crypto/monero"
"github.com/athanorlabs/atomic-swap/monero"
"github.com/athanorlabs/atomic-swap/protocol/swap"
@@ -29,6 +30,9 @@ func (*mockSwapManager) WriteSwapToDB(info *swap.Info) error {
return nil
}
func (*mockSwapManager) PushNewStatus(_ types.Hash, _ types.Status) {
}
func TestClaimMonero_NoTransferBack(t *testing.T) {
env := common.Development

View File

@@ -0,0 +1,75 @@
package swap
import (
"sync"
"github.com/athanorlabs/atomic-swap/common/types"
)
// statusManager provides lookup for the status channels. Status channels are
// ephemeral between runs of swapd.
type statusManager struct {
mu sync.Mutex
statusChannels map[types.Hash]chan Status
}
func newStatusManager() *statusManager {
return &statusManager{
mu: sync.Mutex{},
statusChannels: make(map[types.Hash]chan Status),
}
}
// getStatusChan returns any existing status channel or a new status channel for
// reading or writing.
func (sm *statusManager) getStatusChan(offerID types.Hash) chan Status {
sm.mu.Lock()
defer sm.mu.Unlock()
_, ok := sm.statusChannels[offerID]
if !ok {
sm.statusChannels[offerID] = newStatusChannel()
}
return sm.statusChannels[offerID]
}
// GetStatusChan returns any existing status channel or a new status channel for
// reading only.
func (sm *statusManager) GetStatusChan(offerID types.Hash) <-chan Status {
return sm.getStatusChan(offerID)
}
// DeleteStatusChan deletes any status channel associated with the offer ID.
func (sm *statusManager) DeleteStatusChan(offerID types.Hash) {
sm.mu.Lock()
defer sm.mu.Unlock()
delete(sm.statusChannels, offerID)
}
// PushNewStatus adds a new status to the offer ID's channel
func (sm *statusManager) PushNewStatus(offerID types.Hash, status types.Status) {
ch := sm.getStatusChan(offerID)
ch <- status
// If the status is not ongoing, existing subscribers will get the status
// via the channel since they already have a reference to it. New
// subscribers will get the final status from the past swaps map.
if !status.IsOngoing() {
// We grabbed the status channel before calling IsOngoing to avoid a
// race condition where the status becomes complete after the check, but
// before we grab a reference to the channel. If the status was complete
// before we grabbed the channel, we created a new channel, which we
// remove below.
sm.DeleteStatusChan(offerID)
}
}
// newStatusChannel creates a status channel using the the correct size
func newStatusChannel() chan Status {
// The channel size should be large enough to handle the max number of
// stages a swap can potentially go through.
const statusChSize = 6
ch := make(chan Status, statusChSize)
return ch
}

View File

@@ -0,0 +1,26 @@
package swap
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/athanorlabs/atomic-swap/common/types"
)
func TestStatusManager(t *testing.T) {
offerID1 := types.Hash{0x1}
statusMgr := newStatusManager()
ch1 := statusMgr.GetStatusChan(offerID1)
ch2 := statusMgr.GetStatusChan(offerID1)
require.Equal(t, ch1, ch2)
statusMgr.PushNewStatus(offerID1, types.CompletedSuccess)
status := <-ch1
require.Equal(t, types.CompletedSuccess, status)
statusMgr.DeleteStatusChan(offerID1)
ch3 := statusMgr.GetStatusChan(offerID1)
require.NotEqual(t, ch1, ch3)
}

View File

@@ -15,7 +15,7 @@ import (
"github.com/ChainSafe/chaindb"
)
var errNoSwapWithID = errors.New("unable to find swap with given ID")
var errNoSwapWithOfferID = errors.New("unable to find swap with given offer ID")
// Manager tracks current and past swaps.
type Manager interface {
@@ -23,10 +23,15 @@ type Manager interface {
WriteSwapToDB(info *Info) error
GetPastIDs() ([]types.Hash, error)
GetPastSwap(types.Hash) (*Info, error)
GetOngoingSwap(types.Hash) (Info, error)
GetOngoingSwaps() ([]*Info, error)
GetOngoingSwap(hash types.Hash) (*Info, error)
GetOngoingSwapSnapshot(types.Hash) (*Info, error)
GetOngoingSwapOfferIDs() ([]*types.Hash, error)
GetOngoingSwapsSnapshot() ([]*Info, error)
CompleteOngoingSwap(info *Info) error
HasOngoingSwap(types.Hash) bool
GetStatusChan(offerID types.Hash) <-chan types.Status
DeleteStatusChan(offerID types.Hash)
PushNewStatus(offerID types.Hash, status types.Status)
}
// manager implements Manager.
@@ -38,6 +43,7 @@ type manager struct {
sync.RWMutex
ongoing map[types.Hash]*Info
past map[types.Hash]*Info
*statusManager
}
var _ Manager = (*manager)(nil)
@@ -62,9 +68,10 @@ func NewManager(db Database) (Manager, error) {
}
return &manager{
db: db,
ongoing: ongoing,
past: make(map[types.Hash]*Info),
db: db,
ongoing: ongoing,
past: make(map[types.Hash]*Info),
statusManager: newStatusManager(),
}, nil
}
@@ -140,30 +147,70 @@ func (m *manager) GetPastSwap(id types.Hash) (*Info, error) {
return s, nil
}
// GetOngoingSwap returns the ongoing swap's *Info, if there is one.
func (m *manager) GetOngoingSwap(id types.Hash) (Info, error) {
// GetOngoingSwap returns the ongoing swap's *Info, if there is one. The
// returned Info structure of an active swap can be modified as the swap's state
// changes and should only be read or written by a single go process.
func (m *manager) GetOngoingSwap(offerID types.Hash) (*Info, error) {
m.RLock()
defer m.RUnlock()
s, has := m.ongoing[id]
s, has := m.ongoing[offerID]
if !has {
return Info{}, errNoSwapWithID
return nil, errNoSwapWithOfferID
}
return *s, nil
return s, nil
}
// GetOngoingSwaps returns all ongoing swaps.
func (m *manager) GetOngoingSwaps() ([]*Info, error) {
// GetOngoingSwapSnapshot returns a copy of the ongoing swap's Info, if the
// offerID has an ongoing swap.
func (m *manager) GetOngoingSwapSnapshot(offerID types.Hash) (*Info, error) {
m.RLock()
defer m.RUnlock()
swaps := make([]*Info, len(m.ongoing))
i := 0
for _, s := range m.ongoing {
sCopy := new(Info)
*sCopy = *s
swaps[i] = sCopy
i++
s, has := m.ongoing[offerID]
if !has {
return nil, errNoSwapWithOfferID
}
sc, err := s.DeepCopy()
if err != nil {
return nil, err
}
return sc, nil
}
// GetOngoingSwapOfferIDs returns a list of the offer IDs of all ongoing
// swaps.
func (m *manager) GetOngoingSwapOfferIDs() ([]*types.Hash, error) {
m.RLock()
defer m.RUnlock()
offerIDs := make([]*types.Hash, 0, len(m.ongoing))
for _, s := range m.ongoing {
offerIDs = append(offerIDs, &s.OfferID)
}
return offerIDs, nil
}
// GetOngoingSwapsSnapshot returns a copy of all ongoing swaps. If you need to
// modify the result, call `GetOngoingSwap` on the offerID to get the "live"
// Info object.
func (m *manager) GetOngoingSwapsSnapshot() ([]*Info, error) {
m.RLock()
defer m.RUnlock()
swaps := make([]*Info, 0, len(m.ongoing))
for _, s := range m.ongoing {
sc, err := s.DeepCopy()
if err != nil {
return nil, err
}
swaps = append(swaps, sc)
}
return swaps, nil
}
@@ -171,9 +218,10 @@ func (m *manager) GetOngoingSwaps() ([]*Info, error) {
func (m *manager) CompleteOngoingSwap(info *Info) error {
m.Lock()
defer m.Unlock()
_, has := m.ongoing[info.OfferID]
if !has {
return errNoSwapWithID
return errNoSwapWithOfferID
}
now := time.Now()
@@ -190,6 +238,7 @@ func (m *manager) CompleteOngoingSwap(info *Info) error {
func (m *manager) HasOngoingSwap(id types.Hash) bool {
m.RLock()
defer m.RUnlock()
_, has := m.ongoing[id]
return has
}
@@ -197,7 +246,7 @@ func (m *manager) HasOngoingSwap(id types.Hash) bool {
func (m *manager) getSwapFromDB(id types.Hash) (*Info, error) {
s, err := m.db.GetSwap(id)
if errors.Is(chaindb.ErrKeyNotFound, err) {
return nil, errNoSwapWithID
return nil, errNoSwapWithOfferID
}
if err != nil {
return nil, err

View File

@@ -38,7 +38,6 @@ func TestNewManager(t *testing.T) {
types.EthAssetETH,
types.ExpectingKeys,
100,
nil,
)
db.EXPECT().PutSwap(infoA)
err = m.AddSwap(infoA)
@@ -54,7 +53,6 @@ func TestNewManager(t *testing.T) {
types.EthAssetETH,
types.CompletedSuccess,
100,
nil,
)
db.EXPECT().PutSwap(infoB)
err = m.AddSwap(infoB)
@@ -88,7 +86,6 @@ func TestManager_AddSwap_Ongoing(t *testing.T) {
types.EthAssetETH,
types.ExpectingKeys,
100,
nil,
)
db.EXPECT().PutSwap(info)
@@ -100,7 +97,7 @@ func TestManager_AddSwap_Ongoing(t *testing.T) {
s, err := m.GetOngoingSwap(types.Hash{})
require.NoError(t, err)
require.Equal(t, info, &s)
require.Equal(t, info, s)
require.NotNil(t, m.ongoing)
db.EXPECT().PutSwap(info)

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/Masterminds/semver/v3"
@@ -18,8 +19,6 @@ import (
"github.com/athanorlabs/atomic-swap/common/vjson"
)
const statusChSize = 6 // the max number of stages a swap can potentially go through
var (
// CurInfoVersion is the latest supported version of a serialised Info struct
CurInfoVersion, _ = semver.NewVersion("0.3.0")
@@ -63,8 +62,18 @@ type Info struct {
// (and after Timeout1), the ETH-taker is able to claim, but
// after this timeout, the ETH-taker can no longer claim, only
// the ETH-maker can refund.
Timeout2 *time.Time `json:"timeout2,omitempty"`
statusCh chan types.Status `json:"-"`
Timeout2 *time.Time `json:"timeout2,omitempty"`
// rwMu handles synchronization when LastStatusUpdateTime, Timeout1,
// Timeout2 and EndTime are updated. This Info struct is modified by the
// maker or taker's swapState go process as the state of the swap
// progresses. The swapState go process does not need synchronization when
// reading its own changes, but it needs to grab a write lock when modifying
// the the structure. Readers from other go-processes only get copies of
// this structure. They exclusively use the DeepCopy method to get their
// copy, which grabs the read lock ensuring that they always capture the
// up-to-date state of this Info struct.
rwMu sync.RWMutex
}
// NewInfo creates a new *Info from the given parameters.
@@ -78,7 +87,6 @@ func NewInfo(
ethAsset types.EthAsset,
status Status,
moneroStartHeight uint64,
statusCh chan types.Status,
) *Info {
info := &Info{
Version: CurInfoVersion,
@@ -92,26 +100,42 @@ func NewInfo(
Status: status,
LastStatusUpdateTime: time.Now(),
MoneroStartHeight: moneroStartHeight,
statusCh: statusCh,
StartTime: time.Now(),
EndTime: nil,
Timeout1: nil,
Timeout2: nil,
rwMu: sync.RWMutex{},
}
return info
}
// StatusCh returns the swap's status update channel.
func (i *Info) StatusCh() <-chan types.Status {
return i.statusCh
}
// SetStatus ...
// SetStatus updates the status and status modification timestamp
func (i *Info) SetStatus(s Status) {
i.rwMu.Lock()
defer i.rwMu.Unlock()
i.Status = s
i.LastStatusUpdateTime = time.Now()
if i.statusCh == nil {
// this case only happens in tests.
return
}
i.statusCh <- s
}
// SetTimeouts sets the 2 timeout fields, , grabbing the needed lock before
// modifying fields.
func (i *Info) SetTimeouts(t1 *time.Time, t2 *time.Time) {
i.rwMu.Lock()
defer i.rwMu.Unlock()
i.Timeout1 = t1
i.Timeout2 = t2
}
// MarkSwapComplete sets the EndTime field to the current wall time, grabbing the
// needed lock before modifying fields.
func (i *Info) MarkSwapComplete() {
i.rwMu.Lock()
defer i.rwMu.Unlock()
now := time.Now()
i.EndTime = &now
}
// IsTaker returns true if the node is the xmr-taker in the swap.
@@ -119,32 +143,66 @@ func (i *Info) IsTaker() bool {
return i.Provides == coins.ProvidesETH
}
// UnmarshalInfo deserializes a JSON Info struct, checking the version for compatibility
// before attempting to deserialize the whole blob.
// UnmarshalInfo unmarshalls the passed JSON into a freshly created Info object.
func UnmarshalInfo(jsonData []byte) (*Info, error) {
ov := struct {
info := new(Info)
if err := json.Unmarshal(jsonData, info); err != nil {
return nil, err
}
return info, nil
}
// UnmarshalJSON deserializes a JSON Info struct, checking the version for
// compatibility.
func (i *Info) UnmarshalJSON(jsonData []byte) error {
iv := struct {
Version *semver.Version `json:"version"`
}{}
if err := json.Unmarshal(jsonData, &ov); err != nil {
return nil, err
if err := json.Unmarshal(jsonData, &iv); err != nil {
return err
}
if ov.Version == nil {
return nil, errInfoVersionMissing
if iv.Version == nil {
return errInfoVersionMissing
}
if ov.Version.GreaterThan(CurInfoVersion) {
return nil, fmt.Errorf("info version %q not supported, latest is %q", ov.Version, CurInfoVersion)
if iv.Version.GreaterThan(CurInfoVersion) {
return fmt.Errorf("info version %q not supported, latest is %q", iv.Version, CurInfoVersion)
}
info := new(Info)
if err := vjson.UnmarshalStruct(jsonData, info); err != nil {
return nil, err
}
// Assuming any version less than the current version is forwards
// compatible. If that is not the case in the future, add code here to
// upgrade the older version to the current version when deserializing.
// (Or error if it is completely incompatible.)
info.statusCh = make(chan types.Status, statusChSize)
// Unmarshal without recursion
type _Info Info
if err := vjson.UnmarshalStruct(jsonData, (*_Info)(i)); err != nil {
return err
}
// TODO: Are there additional sanity checks we can perform on the Provided and Received amounts
// (or other fields) here when decoding the JSON?
return info, nil
return nil
}
// DeepCopy returns a deep copy of the Info data structure
func (i *Info) DeepCopy() (*Info, error) {
i.rwMu.RLock()
defer i.rwMu.RUnlock()
// This is not the most efficient means of getting a deep copy, but for our
// needs it is fast enough and least prone to human error, as the structure
// has numerous nested pointer types.
jsonData, err := json.Marshal(i)
if err != nil {
return nil, err
}
clone := new(Info)
if err = clone.UnmarshalJSON(jsonData); err != nil {
return nil, err
}
return clone, nil
}

View File

@@ -32,7 +32,6 @@ func Test_InfoMarshal(t *testing.T) {
types.EthAssetETH,
types.CompletedSuccess,
200,
make(chan types.Status),
)
err := info.StartTime.UnmarshalJSON([]byte("\"2023-02-20T17:29:43.471020297-05:00\""))
require.NoError(t, err)

View File

@@ -134,6 +134,5 @@ func (s *swapState) checkAndSetTimeouts(t1, t2 *big.Int) error {
func (s *swapState) setTimeouts(t1, t2 *big.Int) {
s.t1 = time.Unix(t1.Int64(), 0)
s.t2 = time.Unix(t2.Int64(), 0)
s.info.Timeout1 = &s.t1
s.info.Timeout2 = &s.t2
s.info.SetTimeouts(&s.t1, &s.t2)
}

View File

@@ -36,7 +36,7 @@ func TestSwapState_handleEvent_EventContractReady(t *testing.T) {
// runContractEventWatcher will trigger EventContractReady,
// which will then set the next expected event to EventExit.
for status := range s.info.StatusCh() {
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
if !status.IsOngoing() {
break
}

View File

@@ -84,12 +84,17 @@ func NewInstance(cfg *Config) (*Instance, error) {
}
func (inst *Instance) checkForOngoingSwaps() error {
swaps, err := inst.backend.SwapManager().GetOngoingSwaps()
ongoingIDs, err := inst.backend.SwapManager().GetOngoingSwapOfferIDs()
if err != nil {
return err
}
for _, s := range swaps {
for _, offerID := range ongoingIDs {
s, err := inst.backend.SwapManager().GetOngoingSwap(*offerID)
if err != nil {
return err
}
if s.Provides != coins.ProvidesXMR {
continue
}
@@ -178,7 +183,7 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
if err != nil {
// we can ignore the error; if the key doesn't exist,
// then no relayer was set for this swap.
relayerInfo = &types.OfferExtra{}
relayerInfo = types.NewOfferExtra(false)
}
ss, err := newSwapStateFromOngoing(

View File

@@ -52,7 +52,7 @@ func (s *swapState) HandleProtocolMessage(msg common.Message) error {
func (s *swapState) clearNextExpectedEvent(status types.Status) {
s.nextExpectedEvent = EventNoneType
s.info.SetStatus(status)
s.updateStatus(status)
}
func (s *swapState) setNextExpectedEvent(event EventType) error {
@@ -72,7 +72,7 @@ func (s *swapState) setNextExpectedEvent(event EventType) error {
panic("status corresponding to event cannot be UnknownStatus")
}
s.info.SetStatus(status)
s.updateStatus(status)
err := s.Backend.SwapManager().WriteSwapToDB(s.info)
if err != nil {
return err

View File

@@ -10,12 +10,10 @@ import (
"github.com/ChainSafe/chaindb"
"github.com/athanorlabs/atomic-swap/common/types"
logging "github.com/ipfs/go-log"
)
const statusChSize = 6 // the max number of stages a swap can potentially go through
"github.com/athanorlabs/atomic-swap/common/types"
)
var (
log = logging.Logger("offers")
@@ -49,13 +47,9 @@ func NewManager(dataDir string, db Database) (*Manager, error) {
offers := make(map[types.Hash]*offerWithExtra)
for _, offer := range savedOffers {
extra := &types.OfferExtra{
StatusCh: make(chan types.Status, statusChSize),
}
offers[offer.ID] = &offerWithExtra{
offer: offer,
extra: extra,
extra: types.NewOfferExtra(false),
}
log.Infof("loaded offer %s from database", offer.ID)
@@ -83,10 +77,7 @@ func (m *Manager) GetOffer(id types.Hash) (*types.Offer, *types.OfferExtra, erro
}
// AddOffer adds a new offer to the manager and returns its OffersExtra data
func (m *Manager) AddOffer(
offer *types.Offer,
useRelayer bool,
) (*types.OfferExtra, error) {
func (m *Manager) AddOffer(offer *types.Offer, useRelayer bool) (*types.OfferExtra, error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -101,10 +92,7 @@ func (m *Manager) AddOffer(
return nil, err
}
extra := &types.OfferExtra{
StatusCh: make(chan types.Status, statusChSize),
UseRelayer: useRelayer,
}
extra := types.NewOfferExtra(useRelayer)
m.offers[id] = &offerWithExtra{
offer: offer,

View File

@@ -112,9 +112,6 @@ func newSwapStateFromStart(
// and we'll send our own after this function returns.
// see HandleInitiateMessage().
stage := types.KeysExchanged
if offerExtra.StatusCh == nil {
offerExtra.StatusCh = make(chan types.Status, 7)
}
if offerExtra.UseRelayer {
if err := b.RecoveryDB().PutSwapRelayerInfo(offer.ID, offerExtra); err != nil {
@@ -146,7 +143,6 @@ func newSwapStateFromStart(
offer.EthAsset,
stage,
moneroStartHeight,
offerExtra.StatusCh,
)
if err = b.SwapManager().AddSwap(info); err != nil {
@@ -171,7 +167,8 @@ func newSwapStateFromStart(
return nil, err
}
offerExtra.StatusCh <- stage
s.SwapManager().PushNewStatus(offer.ID, stage)
return s, nil
}
@@ -247,6 +244,7 @@ func checkIfAlreadyClaimed(
func completeSwap(info *swap.Info, b backend.Backend, om *offers.Manager) error {
// set swap to completed
info.SetStatus(types.CompletedSuccess)
b.SwapManager().PushNewStatus(info.OfferID, types.CompletedSuccess)
err := b.SwapManager().CompleteOngoingSwap(info)
if err != nil {
return fmt.Errorf("failed to mark swap %s as completed: %s", info.OfferID, err)
@@ -452,6 +450,11 @@ func (s *swapState) SendKeysMessage() common.Message {
}
}
func (s *swapState) updateStatus(status types.Status) {
s.info.SetStatus(status)
s.SwapManager().PushNewStatus(s.OfferID(), status)
}
// ExpectedAmount returns the amount received, or expected to be received, at the end of the swap
func (s *swapState) ExpectedAmount() *apd.Decimal {
return s.info.ExpectedAmount

View File

@@ -44,7 +44,7 @@ func newTestSwapStateAndDB(t *testing.T) (*Instance, *swapState, *offers.MockDat
xmrmaker.backend,
testPeerID,
types.NewOffer("", new(apd.Decimal), new(apd.Decimal), new(coins.ExchangeRate), types.EthAssetETH),
&types.OfferExtra{},
types.NewOfferExtra(false),
xmrmaker.offerManager,
coins.MoneroToPiconero(coins.StrToDecimal("0.05")),
desiredAmount,
@@ -270,7 +270,7 @@ func TestSwapState_HandleProtocolMessage_NotifyETHLocked_timeout(t *testing.T) {
go s.runT1ExpirationHandler()
for status := range s.info.StatusCh() {
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
if status == types.CompletedSuccess {
break
} else if !status.IsOngoing() {
@@ -319,7 +319,8 @@ func TestSwapState_handleRefund(t *testing.T) {
// runContractEventWatcher will trigger EventETHRefunded,
// which will then set the next expected event to EventExit.
for status := range s.info.StatusCh() {
statusCh := s.SwapManager().GetStatusChan(s.info.OfferID)
for status := range statusCh {
if !status.IsOngoing() {
break
}
@@ -376,7 +377,7 @@ func TestSwapState_Exit_Reclaim(t *testing.T) {
// runContractEventWatcher will trigger EventETHRefunded,
// which will then set the next expected event to EventExit.
for status := range s.info.StatusCh() {
for status := range s.SwapManager().GetStatusChan(s.info.OfferID) {
if !status.IsOngoing() {
require.Equal(t, types.CompletedRefund.String(), status.String())
break
@@ -418,7 +419,7 @@ func TestSwapState_Exit_Success(t *testing.T) {
max := coins.StrToDecimal("0.2")
rate := coins.ToExchangeRate(coins.StrToDecimal("0.1"))
s.offer = types.NewOffer(coins.ProvidesXMR, min, max, rate, types.EthAssetETH)
s.info.SetStatus(types.CompletedSuccess)
s.updateStatus(types.CompletedSuccess)
err := s.Exit()
require.NoError(t, err)
@@ -441,7 +442,7 @@ func TestSwapState_Exit_Refunded(t *testing.T) {
_, err := b.MakeOffer(s.offer, false)
require.NoError(t, err)
s.info.SetStatus(types.CompletedRefund)
s.updateStatus(types.CompletedRefund)
err = s.Exit()
require.NoError(t, err)

View File

@@ -64,12 +64,17 @@ func NewInstance(cfg *Config) (*Instance, error) {
}
func (inst *Instance) checkForOngoingSwaps() error {
swaps, err := inst.backend.SwapManager().GetOngoingSwaps()
ongoingIDs, err := inst.backend.SwapManager().GetOngoingSwapOfferIDs()
if err != nil {
return err
}
for _, s := range swaps {
for _, offerID := range ongoingIDs {
s, err := inst.backend.SwapManager().GetOngoingSwap(*offerID)
if err != nil {
return err
}
if s.Provides != coins.ProvidesETH {
continue
}
@@ -171,7 +176,7 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
}
// completeSwap is called in the case where we find an ongoing swap in the db on startup,
// and the swap already has the counterpary's swap secret stored.
// and the swap already has the counterparty's swap secret stored.
// In this case, we simply claim the XMR, as we have both secrets required.
// It's unlikely for this case to ever be hit, unless the daemon was shut down in-between
// us finding the counterparty's secret and claiming the XMR.

View File

@@ -40,7 +40,7 @@ func (s *swapState) HandleProtocolMessage(msg common.Message) error {
func (s *swapState) clearNextExpectedEvent(status types.Status) {
s.nextExpectedEvent = EventNoneType
s.info.SetStatus(status)
s.updateStatus(status)
}
func (s *swapState) setNextExpectedEvent(event EventType) error {
@@ -61,7 +61,7 @@ func (s *swapState) setNextExpectedEvent(event EventType) error {
}
log.Debugf("setting status to %s", status)
s.info.SetStatus(status)
s.updateStatus(status)
return s.Backend.SwapManager().WriteSwapToDB(s.info)
}

View File

@@ -104,7 +104,6 @@ func newSwapStateFromStart(
ethAsset types.EthAsset,
) (*swapState, error) {
stage := types.ExpectingKeys
statusCh := make(chan types.Status, 16)
moneroStartNumber, err := b.XMRClient().GetHeight()
if err != nil {
@@ -136,7 +135,6 @@ func newSwapStateFromStart(
ethAsset,
stage,
moneroStartNumber,
statusCh,
)
if err = b.SwapManager().AddSwap(info); err != nil {
return nil, err
@@ -157,7 +155,8 @@ func newSwapStateFromStart(
return nil, err
}
statusCh <- stage
s.SwapManager().PushNewStatus(offerID, stage)
return s, nil
}
@@ -318,6 +317,11 @@ func (s *swapState) SendKeysMessage() common.Message {
}
}
func (s *swapState) updateStatus(status types.Status) {
s.info.SetStatus(status)
s.SwapManager().PushNewStatus(s.OfferID(), status)
}
// ExpectedAmount returns the amount received, or expected to be received, at the end of the swap
func (s *swapState) ExpectedAmount() *apd.Decimal {
return s.info.ExpectedAmount

View File

@@ -48,7 +48,7 @@ func setupSwapStateUntilETHLocked(t *testing.T) (*swapState, uint64) {
// shutdown swap state, re-create from ongoing
s.cancel()
rdb.EXPECT().GetCounterpartySwapKeys(s.info.OfferID).Return(
rdb.EXPECT().GetCounterpartySwapKeys(s.OfferID()).Return(
makerKeys.PublicKeyPair.SpendKey(),
makerKeys.PrivateKeyPair.ViewKey(),
nil,

View File

@@ -252,7 +252,7 @@ func TestSwapState_HandleProtocolMessage_SendKeysMessage_Refund(t *testing.T) {
require.Equal(t, xmrmakerKeysAndProof.PrivateKeyPair.ViewKey().String(), s.xmrmakerPrivateViewKey.String())
// ensure we refund before t1
for status := range s.info.StatusCh() {
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
if status == types.CompletedRefund {
// check this is before t1
// TODO: remove the 10-second buffer, this is needed for now
@@ -346,7 +346,7 @@ func TestSwapState_NotifyXMRLock_Refund(t *testing.T) {
require.NoError(t, err)
require.Equal(t, EventETHClaimedType, s.nextExpectedEvent)
for status := range s.info.StatusCh() {
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
if status == types.CompletedRefund {
// check this is after t2
require.Less(t, s.t2, time.Now())
@@ -368,7 +368,7 @@ func TestExit_afterSendKeysMessage(t *testing.T) {
s.nextExpectedEvent = EventKeysReceivedType
err := s.Exit()
require.NoError(t, err)
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
info, err := s.SwapManager().GetPastSwap(s.OfferID())
require.NoError(t, err)
require.Equal(t, types.CompletedAbort, info.Status)
}
@@ -395,7 +395,7 @@ func TestExit_afterNotifyXMRLock(t *testing.T) {
err = s.Exit()
require.NoError(t, err)
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
info, err := s.SwapManager().GetPastSwap(s.OfferID())
require.NoError(t, err)
require.Equal(t, types.CompletedRefund, info.Status)
}
@@ -422,7 +422,7 @@ func TestExit_afterNotifyClaimed(t *testing.T) {
err = s.Exit()
require.NoError(t, err)
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
info, err := s.SwapManager().GetPastSwap(s.OfferID())
require.NoError(t, err)
require.Equal(t, types.CompletedRefund, info.Status)
}
@@ -450,7 +450,7 @@ func TestExit_invalidNextMessageType(t *testing.T) {
err = s.Exit()
require.True(t, errors.Is(err, errUnexpectedEventType))
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
info, err := s.SwapManager().GetPastSwap(s.OfferID())
require.NoError(t, err)
require.Equal(t, types.CompletedAbort, info.Status)
}

View File

@@ -5,19 +5,23 @@ package rpc
import (
"context"
"testing"
"time"
"github.com/ChainSafe/chaindb"
"github.com/MarinX/monerorpc/wallet"
"github.com/cockroachdb/apd/v3"
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/libp2p/go-libp2p/core/peer"
libp2ptest "github.com/libp2p/go-libp2p/core/test"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
"github.com/athanorlabs/atomic-swap/coins"
"github.com/athanorlabs/atomic-swap/common"
"github.com/athanorlabs/atomic-swap/common/types"
mcrypto "github.com/athanorlabs/atomic-swap/crypto/monero"
"github.com/athanorlabs/atomic-swap/db"
"github.com/athanorlabs/atomic-swap/ethereum/extethclient"
"github.com/athanorlabs/atomic-swap/net/message"
"github.com/athanorlabs/atomic-swap/protocol/swap"
@@ -67,32 +71,21 @@ func (*mockNet) CloseProtocolStream(_ types.Hash) {
panic("not implemented")
}
type mockSwapManager struct{}
func mockSwapManager(t *testing.T) swap.Manager {
db, err := db.NewDatabase(&chaindb.Config{
DataDir: t.TempDir(),
InMemory: true,
})
require.NoError(t, err)
func (*mockSwapManager) WriteSwapToDB(_ *swap.Info) error {
return nil
}
func (*mockSwapManager) GetPastIDs() ([]types.Hash, error) {
panic("not implemented")
}
func (*mockSwapManager) GetPastSwap(_ types.Hash) (*swap.Info, error) {
return &swap.Info{}, nil
}
func (*mockSwapManager) GetOngoingSwaps() ([]*swap.Info, error) {
return nil, nil
}
func (*mockSwapManager) GetOngoingSwap(id types.Hash) (swap.Info, error) {
statusCh := make(chan types.Status, 1)
statusCh <- types.CompletedSuccess
sm, err := swap.NewManager(db)
require.NoError(t, err)
one := apd.New(1, 0)
return *swap.NewInfo(
sm.AddSwap(swap.NewInfo(
testPeerID,
id,
testSwapID,
coins.ProvidesETH,
one,
one,
@@ -100,20 +93,11 @@ func (*mockSwapManager) GetOngoingSwap(id types.Hash) (swap.Info, error) {
types.EthAssetETH,
types.CompletedSuccess,
1,
statusCh,
), nil
}
))
func (*mockSwapManager) AddSwap(_ *swap.Info) error {
panic("not implemented")
}
sm.PushNewStatus(testSwapID, types.CompletedSuccess)
func (*mockSwapManager) CompleteOngoingSwap(_ *swap.Info) error {
panic("not implemented")
}
func (*mockSwapManager) HasOngoingSwap(_ types.Hash) bool {
panic("not implemented")
return sm
}
type mockXMRTaker struct{}
@@ -153,10 +137,7 @@ func (m *mockXMRMaker) GetOngoingSwapState(_ types.Hash) common.SwapState {
}
func (*mockXMRMaker) MakeOffer(_ *types.Offer, _ bool) (*types.OfferExtra, error) {
offerExtra := &types.OfferExtra{
StatusCh: make(chan types.Status, 1),
}
offerExtra.StatusCh <- types.CompletedSuccess
offerExtra := types.NewOfferExtra(false)
return offerExtra, nil
}
@@ -193,12 +174,12 @@ func (*mockSwapState) OfferID() types.Hash {
}
type mockProtocolBackend struct {
sm *mockSwapManager
sm swap.Manager
}
func newMockProtocolBackend() *mockProtocolBackend {
func newMockProtocolBackend(t *testing.T) *mockProtocolBackend {
return &mockProtocolBackend{
sm: new(mockSwapManager),
sm: mockSwapManager(t),
}
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/athanorlabs/atomic-swap/common/rpctypes"
"github.com/athanorlabs/atomic-swap/common/types"
"github.com/athanorlabs/atomic-swap/net/message"
"github.com/athanorlabs/atomic-swap/protocol/swap"
)
const defaultSearchTime = time.Second * 12
@@ -37,12 +38,12 @@ type NetService struct {
net Net
xmrtaker XMRTaker
xmrmaker XMRMaker
sm SwapManager
sm swap.Manager
isBootnode bool
}
// NewNetService ...
func NewNetService(net Net, xmrtaker XMRTaker, xmrmaker XMRMaker, sm SwapManager, isBootnode bool) *NetService {
func NewNetService(net Net, xmrtaker XMRTaker, xmrmaker XMRMaker, sm swap.Manager, isBootnode bool) *NetService {
return &NetService{
net: net,
xmrtaker: xmrtaker,
@@ -160,7 +161,7 @@ func (s *NetService) TakeOffer(
return errUnsupportedForBootnode
}
_, err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount)
err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount)
if err != nil {
return err
}
@@ -168,13 +169,10 @@ func (s *NetService) TakeOffer(
return nil
}
func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, providesAmount *apd.Decimal) (
<-chan types.Status,
error,
) {
func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, providesAmount *apd.Decimal) error {
queryResp, err := s.net.Query(makerPeerID)
if err != nil {
return nil, err
return err
}
var offer *types.Offer
@@ -185,12 +183,12 @@ func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, provides
}
}
if offer == nil {
return nil, errNoOfferWithID
return errNoOfferWithID
}
swapState, err := s.xmrtaker.InitiateProtocol(makerPeerID, providesAmount, offer)
if err != nil {
return nil, fmt.Errorf("failed to initiate protocol: %w", err)
return fmt.Errorf("failed to initiate protocol: %w", err)
}
skm := swapState.SendKeysMessage().(*message.SendKeysMessage)
@@ -201,15 +199,10 @@ func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, provides
if err = swapState.Exit(); err != nil {
log.Warnf("Swap exit failure: %s", err)
}
return nil, err
return err
}
info, err := s.sm.GetOngoingSwap(offerID)
if err != nil {
return nil, err
}
return info.StatusCh(), nil
return nil
}
// TakeOfferSyncResponse ...
@@ -228,7 +221,7 @@ func (s *NetService) TakeOfferSync(
return errUnsupportedForBootnode
}
if _, err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount); err != nil {
if err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount); err != nil {
return err
}
@@ -263,7 +256,7 @@ func (s *NetService) MakeOffer(
return errUnsupportedForBootnode
}
offerResp, _, err := s.makeOffer(req)
offerResp, err := s.makeOffer(req)
if err != nil {
return err
}
@@ -271,7 +264,7 @@ func (s *NetService) MakeOffer(
return nil
}
func (s *NetService) makeOffer(req *rpctypes.MakeOfferRequest) (*rpctypes.MakeOfferResponse, *types.OfferExtra, error) {
func (s *NetService) makeOffer(req *rpctypes.MakeOfferRequest) (*rpctypes.MakeOfferResponse, error) {
offer := types.NewOffer(
coins.ProvidesXMR,
req.MinAmount,
@@ -280,13 +273,13 @@ func (s *NetService) makeOffer(req *rpctypes.MakeOfferRequest) (*rpctypes.MakeOf
req.EthAsset,
)
offerExtra, err := s.xmrmaker.MakeOffer(offer, req.UseRelayer)
_, err := s.xmrmaker.MakeOffer(offer, req.UseRelayer)
if err != nil {
return nil, nil, err
return nil, err
}
return &rpctypes.MakeOfferResponse{
PeerID: s.net.PeerID(),
OfferID: offer.ID,
}, offerExtra, nil
}, nil
}

View File

@@ -14,7 +14,7 @@ import (
)
func TestNet_Discover(t *testing.T) {
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
req := &rpctypes.DiscoverRequest{
Provides: "",
@@ -28,7 +28,7 @@ func TestNet_Discover(t *testing.T) {
}
func TestNet_Query(t *testing.T) {
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
req := &rpctypes.QueryPeerRequest{
PeerID: "12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",
@@ -42,7 +42,7 @@ func TestNet_Query(t *testing.T) {
}
func TestNet_TakeOffer(t *testing.T) {
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
req := &rpctypes.TakeOfferRequest{
PeerID: "12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",
@@ -55,7 +55,7 @@ func TestNet_TakeOffer(t *testing.T) {
}
func TestNet_TakeOfferSync(t *testing.T) {
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
req := &rpctypes.TakeOfferRequest{
PeerID: "12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",

View File

@@ -30,7 +30,7 @@ type Metrics struct {
func pastSwapsMetric(
factory promauto.Factory,
swapManager SwapManager,
swapManager swap.Manager,
status swap.Status,
statusLabel string,
) prometheus.GaugeFunc {
@@ -108,7 +108,7 @@ func SetupMetrics(
Help: "The number of ongoing swaps",
},
func() float64 {
swaps, err := swapManager.GetOngoingSwaps()
swaps, err := swapManager.GetOngoingSwapsSnapshot()
if err != nil {
return float64(-1)
}

View File

@@ -263,6 +263,3 @@ type XMRMaker interface {
ClearOffers([]types.Hash) error
GetMoneroBalance() (*mcrypto.Address, *wallet.GetBalanceResponse, error)
}
// SwapManager ...
type SwapManager = swap.Manager

View File

@@ -24,7 +24,7 @@ import (
// SwapService handles information about ongoing or past swaps.
type SwapService struct {
ctx context.Context
sm SwapManager
sm swap.Manager
xmrtaker XMRTaker
xmrmaker XMRMaker
net Net
@@ -35,7 +35,7 @@ type SwapService struct {
// NewSwapService ...
func NewSwapService(
ctx context.Context,
sm SwapManager,
sm swap.Manager,
xmrtaker XMRTaker,
xmrmaker XMRMaker,
net Net,
@@ -165,17 +165,17 @@ func (s *SwapService) GetOngoing(_ *http.Request, req *GetOngoingRequest, resp *
)
if req.OfferID == nil {
swaps, err = s.sm.GetOngoingSwaps()
swaps, err = s.sm.GetOngoingSwapsSnapshot()
if err != nil {
return err
}
} else {
info, err := s.sm.GetOngoingSwap(*req.OfferID) //nolint:govet
info, err := s.sm.GetOngoingSwapSnapshot(*req.OfferID) //nolint:govet
if err != nil {
return err
}
swaps = []*swap.Info{&info}
swaps = []*swap.Info{info}
}
resp.Swaps = make([]*OngoingSwap, len(swaps))
@@ -221,7 +221,7 @@ type GetStatusResponse struct {
// GetStatus returns the status of the ongoing swap, if there is one.
func (s *SwapService) GetStatus(_ *http.Request, req *GetStatusRequest, resp *GetStatusResponse) error {
info, err := s.sm.GetOngoingSwap(req.ID)
info, err := s.sm.GetOngoingSwapSnapshot(req.ID)
if err != nil {
return err
}
@@ -272,7 +272,7 @@ type CancelResponse struct {
// Cancel attempts to cancel the currently ongoing swap, if there is one.
func (s *SwapService) Cancel(_ *http.Request, req *CancelRequest, resp *CancelResponse) error {
info, err := s.sm.GetOngoingSwap(req.OfferID)
info, err := s.sm.GetOngoingSwapSnapshot(req.OfferID)
if err != nil {
return fmt.Errorf("failed to get ongoing swap: %w", err)
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/athanorlabs/atomic-swap/common/types"
"github.com/athanorlabs/atomic-swap/common/vjson"
mcrypto "github.com/athanorlabs/atomic-swap/crypto/monero"
"github.com/athanorlabs/atomic-swap/protocol/swap"
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/gorilla/websocket"
@@ -29,13 +30,13 @@ func checkOriginFunc(_ *http.Request) bool {
type wsServer struct {
ctx context.Context
sm SwapManager
sm swap.Manager
ns *NetService
backend ProtocolBackend
taker XMRTaker
}
func newWsServer(ctx context.Context, sm SwapManager, ns *NetService, backend ProtocolBackend,
func newWsServer(ctx context.Context, sm swap.Manager, ns *NetService, backend ProtocolBackend,
taker XMRTaker) *wsServer {
s := &wsServer{
ctx: ctx,
@@ -142,12 +143,12 @@ func (s *wsServer) handleRequest(conn *websocket.Conn, req *rpctypes.Request) er
return fmt.Errorf("failed to unmarshal parameters: %w", err)
}
ch, err := s.ns.takeOffer(params.PeerID, params.OfferID, params.ProvidesAmount)
err := s.ns.takeOffer(params.PeerID, params.OfferID, params.ProvidesAmount)
if err != nil {
return err
}
return s.subscribeTakeOffer(s.ctx, conn, ch)
return s.subscribeSwapStatus(s.ctx, conn, params.OfferID)
case rpctypes.SubscribeMakeOffer:
if s.ns == nil {
return errNamespaceNotEnabled
@@ -158,12 +159,12 @@ func (s *wsServer) handleRequest(conn *websocket.Conn, req *rpctypes.Request) er
return fmt.Errorf("failed to unmarshal parameters: %w", err)
}
offerResp, offerExtra, err := s.ns.makeOffer(params)
offerResp, err := s.ns.makeOffer(params)
if err != nil {
return err
}
return s.subscribeMakeOffer(s.ctx, conn, offerResp.OfferID, offerExtra)
return s.subscribeMakeOffer(s.ctx, conn, offerResp.OfferID)
default:
return errInvalidMethod
}
@@ -240,8 +241,22 @@ func (s *wsServer) handleSigner(
}
}
func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn,
statusCh <-chan types.Status) error {
func (s *wsServer) subscribeMakeOffer(
ctx context.Context,
conn *websocket.Conn,
offerID types.Hash,
) error {
resp := &rpctypes.MakeOfferResponse{
PeerID: s.ns.net.PeerID(),
OfferID: offerID,
}
if err := writeResponse(conn, resp); err != nil {
return err
}
statusCh := s.backend.SwapManager().GetStatusChan(offerID)
for {
select {
case status, ok := <-statusCh:
@@ -266,53 +281,19 @@ func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn,
}
}
func (s *wsServer) subscribeMakeOffer(ctx context.Context, conn *websocket.Conn,
offerID types.Hash, offerExtra *types.OfferExtra) error {
resp := &rpctypes.MakeOfferResponse{
PeerID: s.ns.net.PeerID(),
OfferID: offerID,
// subscribeSwapStatus writes the swap's status transitions to the websockets
// connection when the state changes. When the swap completes, it writes the
// final status and then closes the connection. This method is not intended for
// simultaneous requests on the same swap. If more than one request is made
// (including calls to net_[make|take]OfferAndSubscribe), only one of the
// websocket connections will see any individual state transition.
func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn, offerID types.Hash) error {
statusCh := s.backend.SwapManager().GetStatusChan(offerID)
if !s.sm.HasOngoingSwap(offerID) {
return s.writeSwapExitStatus(conn, offerID)
}
if err := writeResponse(conn, resp); err != nil {
return err
}
for {
select {
case status, ok := <-offerExtra.StatusCh:
if !ok {
return nil
}
resp := &rpctypes.SubscribeSwapStatusResponse{
Status: status,
}
if err := writeResponse(conn, resp); err != nil {
return err
}
if !status.IsOngoing() {
return nil
}
case <-ctx.Done():
return nil
}
}
}
// subscribeSwapStatus writes the swap's stage to the connection every time it updates.
// when the swap completes, it writes the final status then closes the connection.
// example: `{"jsonrpc":"2.0", "method":"swap_subscribeStatus", "params": {"id": 0}, "id": 0}`
func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn, id types.Hash) error {
// we can ignore the error here, since the error will only be if the swap cannot be found
// as ongoing, in which case `writeSwapExitStatus` will look for it in the past swaps.
info, err := s.sm.GetOngoingSwap(id)
if err != nil {
return s.writeSwapExitStatus(conn, id)
}
statusCh := info.StatusCh()
for {
select {
case status, ok := <-statusCh:

View File

@@ -26,7 +26,7 @@ var (
testTimeout = time.Second * 5
)
func newServer(t *testing.T) *Server {
func newServer(t *testing.T) (*Server, *Config) {
ctx, cancel := context.WithCancel(context.Background())
cfg := &Config{
@@ -34,7 +34,7 @@ func newServer(t *testing.T) *Server {
Env: common.Development,
Address: "127.0.0.1:0", // OS assigned port
Net: new(mockNet),
ProtocolBackend: newMockProtocolBackend(),
ProtocolBackend: newMockProtocolBackend(t),
XMRTaker: new(mockXMRTaker),
XMRMaker: new(mockXMRMaker),
Namespaces: AllNamespaces(),
@@ -60,11 +60,11 @@ func newServer(t *testing.T) *Server {
wg.Wait() // wait for the server to exit
})
return s
return s, cfg
}
func TestSubscribeSwapStatus(t *testing.T) {
s := newServer(t)
s, _ := newServer(t)
c, err := wsclient.NewWsClient(s.ctx, s.WsURL())
require.NoError(t, err)
@@ -81,7 +81,7 @@ func TestSubscribeSwapStatus(t *testing.T) {
}
func TestSubscribeMakeOffer(t *testing.T) {
s := newServer(t)
s, cfg := newServer(t)
c, err := wsclient.NewWsClient(s.ctx, s.WsURL())
require.NoError(t, err)
@@ -92,6 +92,9 @@ func TestSubscribeMakeOffer(t *testing.T) {
offerResp, ch, err := c.MakeOfferAndSubscribe(min, max, exRate, types.EthAssetETH, false)
require.NoError(t, err)
require.NotEqual(t, offerResp.OfferID, testSwapID)
cfg.ProtocolBackend.SwapManager().PushNewStatus(offerResp.OfferID, types.CompletedSuccess)
select {
case status := <-ch:
require.Equal(t, types.CompletedSuccess, status)
@@ -101,7 +104,7 @@ func TestSubscribeMakeOffer(t *testing.T) {
}
func TestSubscribeTakeOffer(t *testing.T) {
s := newServer(t)
s, _ := newServer(t)
cliCtx, cancel := context.WithCancel(context.Background())
t.Cleanup(func() {