mirror of
https://github.com/AthanorLabs/atomic-swap.git
synced 2026-01-10 06:38:04 -05:00
ensure that the status channel is always initialized (#471)
This commit is contained in:
@@ -174,12 +174,6 @@ func (o *Offer) validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OfferExtra represents extra data that is passed when an offer is made.
|
||||
type OfferExtra struct {
|
||||
StatusCh chan Status `json:"-"`
|
||||
UseRelayer bool `json:"useRelayer,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalOffer deserializes a JSON offer, checking the version for compatibility before
|
||||
// attempting to deserialize the whole blob.
|
||||
func UnmarshalOffer(jsonData []byte) (*Offer, error) {
|
||||
@@ -228,3 +222,19 @@ func (o *Offer) UnmarshalJSON(data []byte) error {
|
||||
}
|
||||
return o.validate()
|
||||
}
|
||||
|
||||
// OfferExtra represents extra data that is passed when an offer is made.
|
||||
type OfferExtra struct {
|
||||
// UseRelayer forces the XMR maker to claim using the relayer even when he
|
||||
// has enough funds to make the claim himself. Setting it to false will not
|
||||
// prevent the relayer from being used if there are insufficient ETH funds
|
||||
// to claim.
|
||||
UseRelayer bool `json:"useRelayer,omitempty"`
|
||||
}
|
||||
|
||||
// NewOfferExtra creates an OfferExtra instance
|
||||
func NewOfferExtra(forceUseRelayer bool) *OfferExtra {
|
||||
return &OfferExtra{
|
||||
UseRelayer: forceUseRelayer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@@ -250,3 +251,17 @@ func TestUnmarshalOffer_VersionTooNew(t *testing.T) {
|
||||
_, err := UnmarshalOffer([]byte(offerJSON))
|
||||
require.ErrorContains(t, err, fmt.Sprintf("offer version %q not supported", unsupportedVersion))
|
||||
}
|
||||
|
||||
func TestOfferExtra_JSON(t *testing.T) {
|
||||
// Marshal test
|
||||
extra := NewOfferExtra(true)
|
||||
data, err := vjson.MarshalStruct(extra)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"useRelayer":true}`, string(data))
|
||||
|
||||
// Unmarshal test
|
||||
extra = new(OfferExtra)
|
||||
err = json.Unmarshal(data, extra)
|
||||
require.NoError(t, err)
|
||||
require.True(t, extra.UseRelayer)
|
||||
}
|
||||
|
||||
141
daemon/refund_after_restart_test.go
Normal file
141
daemon/refund_after_restart_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/athanorlabs/atomic-swap/coins"
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
"github.com/athanorlabs/atomic-swap/monero"
|
||||
"github.com/athanorlabs/atomic-swap/rpcclient/wsclient"
|
||||
"github.com/athanorlabs/atomic-swap/tests"
|
||||
)
|
||||
|
||||
func TestXMRNotLockedAndETHRefundedAfterAliceRestarts(t *testing.T) {
|
||||
minXMR := coins.StrToDecimal("1")
|
||||
maxXMR := minXMR
|
||||
exRate := coins.StrToExchangeRate("0.1")
|
||||
providesAmt, err := exRate.ToETH(minXMR)
|
||||
require.NoError(t, err)
|
||||
|
||||
bobConf := CreateTestConf(t, tests.GetMakerTestKey(t))
|
||||
monero.MineMinXMRBalance(t, bobConf.MoneroClient, coins.MoneroToPiconero(maxXMR))
|
||||
|
||||
aliceConf := CreateTestConf(t, tests.GetTakerTestKey(t))
|
||||
|
||||
timeout := 7 * time.Minute
|
||||
ctx, cancel := LaunchDaemons(t, timeout, aliceConf, bobConf)
|
||||
|
||||
// clients use a separate context and will work across server restarts
|
||||
clientCtx := context.Background()
|
||||
bc, err := wsclient.NewWsClient(clientCtx, fmt.Sprintf("ws://127.0.0.1:%d/ws", bobConf.RPCPort))
|
||||
require.NoError(t, err)
|
||||
ac, err := wsclient.NewWsClient(clientCtx, fmt.Sprintf("ws://127.0.0.1:%d/ws", aliceConf.RPCPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bob makes an offer
|
||||
makeResp, bobStatusCh, err := bc.MakeOfferAndSubscribe(minXMR, maxXMR, exRate, types.EthAssetETH, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Alice takes the offer
|
||||
aliceStatusCh, err := ac.TakeOfferAndSubscribe(makeResp.PeerID, makeResp.OfferID, providesAmt)
|
||||
require.NoError(t, err)
|
||||
|
||||
var statusWG sync.WaitGroup
|
||||
statusWG.Add(2)
|
||||
|
||||
// Alice shuts down both servers as soon as she locks her ETH
|
||||
go func() {
|
||||
defer statusWG.Done()
|
||||
for {
|
||||
select {
|
||||
case status := <-aliceStatusCh:
|
||||
t.Log("> Alice got status:", status)
|
||||
switch status {
|
||||
case types.ExpectingKeys:
|
||||
continue
|
||||
case types.ETHLocked:
|
||||
cancel() // stop both Alice's and Bob's daemons
|
||||
default:
|
||||
cancel()
|
||||
t.Errorf("Alice should not have reached status=%s", status)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Logf("Alice's context cancelled (expected)")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Bob is not playing a significant role in this test. His swapd instance is
|
||||
// shut down before he can lock any XMR and we don't bring it back online.
|
||||
go func() {
|
||||
defer statusWG.Done()
|
||||
for {
|
||||
select {
|
||||
case status := <-bobStatusCh:
|
||||
t.Log("> Bob got status:", status)
|
||||
switch status {
|
||||
case types.KeysExchanged:
|
||||
continue
|
||||
default:
|
||||
cancel()
|
||||
t.Errorf("Bob should not have reached status=%s", status)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Logf("Bob's context cancelled (expected)")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
statusWG.Wait()
|
||||
if t.Failed() {
|
||||
return
|
||||
}
|
||||
|
||||
// Make sure both servers had time to fully shut down
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// relaunch Alice's daemon
|
||||
t.Logf("daemons stopped, now re-launching Alice's daemon in isolation")
|
||||
ctx, cancel = LaunchDaemons(t, 3*time.Minute, aliceConf)
|
||||
|
||||
// This is a bug that we need to recreate Alice's websocket client here. Remove this
|
||||
// code when we fix https://github.com/AthanorLabs/atomic-swap/issues/353.
|
||||
ac, err = wsclient.NewWsClient(clientCtx, fmt.Sprintf("ws://127.0.0.1:%d/ws", aliceConf.RPCPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
aliceStatusCh, err = ac.SubscribeSwapStatus(makeResp.OfferID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure Alice completes the swap with a refund
|
||||
statusWG.Add(1)
|
||||
go func() {
|
||||
defer statusWG.Done()
|
||||
for {
|
||||
select {
|
||||
case status := <-aliceStatusCh:
|
||||
t.Log("> Alice, after restart, got status:", status)
|
||||
if !status.IsOngoing() {
|
||||
assert.Equal(t, types.CompletedRefund.String(), status.String())
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// Alice's context has a deadline. If we get here, the context
|
||||
// expired before we got any Refund status update.
|
||||
t.Errorf("Alice's context cancelled before she completed the swap")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
statusWG.Wait()
|
||||
// TODO: Add some additional checks here when the rest of the test is working
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"path"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
@@ -87,13 +88,15 @@ func CreateTestBootnode(t *testing.T) (uint16, string) {
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
dataDir := t.TempDir()
|
||||
|
||||
conf := &bootnode.Config{
|
||||
Env: common.Development,
|
||||
DataDir: t.TempDir(),
|
||||
Bootnodes: nil,
|
||||
P2PListenIP: "127.0.0.1",
|
||||
Libp2pPort: 0,
|
||||
Libp2pKeyFile: common.DefaultLibp2pKeyFileName,
|
||||
Libp2pKeyFile: path.Join(dataDir, common.DefaultLibp2pKeyFileName),
|
||||
RPCPort: uint16(rpcPort),
|
||||
}
|
||||
|
||||
|
||||
@@ -331,11 +331,10 @@ func TestDatabase_SwapTable_Update(t *testing.T) {
|
||||
|
||||
// infoB mostly the same as infoA (same ID, importantly), but with
|
||||
// a couple updated fields.
|
||||
infoB := new(swap.Info)
|
||||
*infoB = *infoA
|
||||
infoB, err := infoA.DeepCopy()
|
||||
require.NoError(t, err)
|
||||
infoB.Status = types.CompletedSuccess
|
||||
endTime := time.Now()
|
||||
infoB.EndTime = &endTime
|
||||
infoB.MarkSwapComplete()
|
||||
|
||||
err = db.PutSwap(infoB)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -83,16 +83,14 @@ func TestRecoveryDB_SwapRelayerInfo(t *testing.T) {
|
||||
rdb := newTestRecoveryDB(t)
|
||||
offerID := types.Hash{5, 6, 7, 8}
|
||||
|
||||
info := &types.OfferExtra{
|
||||
UseRelayer: true,
|
||||
}
|
||||
extra := types.NewOfferExtra(true)
|
||||
|
||||
err := rdb.PutSwapRelayerInfo(offerID, info)
|
||||
err := rdb.PutSwapRelayerInfo(offerID, extra)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := rdb.GetSwapRelayerInfo(offerID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info, res)
|
||||
require.Equal(t, extra, res)
|
||||
}
|
||||
|
||||
func TestRecoveryDB_SwapPrivateKey(t *testing.T) {
|
||||
@@ -165,13 +163,11 @@ func TestRecoveryDB_DeleteSwap(t *testing.T) {
|
||||
SwapCreatorAddr: ethcommon.HexToAddress("0xd2b5d6252d0645e4cf4bb547e82a485f527befb7"),
|
||||
}
|
||||
|
||||
info := &types.OfferExtra{
|
||||
UseRelayer: true,
|
||||
}
|
||||
extra := types.NewOfferExtra(true)
|
||||
|
||||
err = rdb.PutContractSwapInfo(offerID, si)
|
||||
require.NoError(t, err)
|
||||
err = rdb.PutSwapRelayerInfo(offerID, info)
|
||||
err = rdb.PutSwapRelayerInfo(offerID, extra)
|
||||
require.NoError(t, err)
|
||||
err = rdb.PutSwapPrivateKey(offerID, kp.SpendKey())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -262,7 +262,7 @@ func (b *backend) ClearXMRDepositAddress(offerID types.Hash) {
|
||||
// HasOngoingSwapAsTaker returns nil if we have an ongoing swap with the given peer where
|
||||
// we're the xmrtaker, otherwise returns an error.
|
||||
func (b *backend) HasOngoingSwapAsTaker(remotePeer peer.ID) error {
|
||||
swaps, err := b.swapManager.GetOngoingSwaps()
|
||||
swaps, err := b.swapManager.GetOngoingSwapsSnapshot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -293,7 +293,7 @@ func (b *backend) HandleRelayClaimRequest(
|
||||
return nil, fmt.Errorf("cannot relay taker-specific claim request; no ongoing swap for swap %s", *request.OfferID)
|
||||
}
|
||||
|
||||
info, err := b.swapManager.GetOngoingSwap(*request.OfferID)
|
||||
info, err := b.swapManager.GetOngoingSwapSnapshot(*request.OfferID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ var (
|
||||
// SwapManager is the subset of the swap.Manager interface needed by ClaimMonero
|
||||
type SwapManager interface {
|
||||
WriteSwapToDB(info *swap.Info) error
|
||||
PushNewStatus(offerID types.Hash, status types.Status)
|
||||
}
|
||||
|
||||
// GetClaimKeypair returns the private key pair required for a monero claim.
|
||||
@@ -103,6 +104,7 @@ func ClaimMonero(
|
||||
// setSweepStatus sets the swap's status as `SweepingXMR` and writes it to the db.
|
||||
func setSweepStatus(info *swap.Info, sm SwapManager) error {
|
||||
info.SetStatus(types.SweepingXMR)
|
||||
sm.PushNewStatus(info.OfferID, types.SweepingXMR)
|
||||
err := sm.WriteSwapToDB(info)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write swap to db: %w", err)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/athanorlabs/atomic-swap/coins"
|
||||
"github.com/athanorlabs/atomic-swap/common"
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
mcrypto "github.com/athanorlabs/atomic-swap/crypto/monero"
|
||||
"github.com/athanorlabs/atomic-swap/monero"
|
||||
"github.com/athanorlabs/atomic-swap/protocol/swap"
|
||||
@@ -29,6 +30,9 @@ func (*mockSwapManager) WriteSwapToDB(info *swap.Info) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockSwapManager) PushNewStatus(_ types.Hash, _ types.Status) {
|
||||
}
|
||||
|
||||
func TestClaimMonero_NoTransferBack(t *testing.T) {
|
||||
env := common.Development
|
||||
|
||||
|
||||
75
protocol/swap/status_manager.go
Normal file
75
protocol/swap/status_manager.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package swap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
)
|
||||
|
||||
// statusManager provides lookup for the status channels. Status channels are
|
||||
// ephemeral between runs of swapd.
|
||||
type statusManager struct {
|
||||
mu sync.Mutex
|
||||
statusChannels map[types.Hash]chan Status
|
||||
}
|
||||
|
||||
func newStatusManager() *statusManager {
|
||||
return &statusManager{
|
||||
mu: sync.Mutex{},
|
||||
statusChannels: make(map[types.Hash]chan Status),
|
||||
}
|
||||
}
|
||||
|
||||
// getStatusChan returns any existing status channel or a new status channel for
|
||||
// reading or writing.
|
||||
func (sm *statusManager) getStatusChan(offerID types.Hash) chan Status {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
_, ok := sm.statusChannels[offerID]
|
||||
if !ok {
|
||||
sm.statusChannels[offerID] = newStatusChannel()
|
||||
}
|
||||
|
||||
return sm.statusChannels[offerID]
|
||||
}
|
||||
|
||||
// GetStatusChan returns any existing status channel or a new status channel for
|
||||
// reading only.
|
||||
func (sm *statusManager) GetStatusChan(offerID types.Hash) <-chan Status {
|
||||
return sm.getStatusChan(offerID)
|
||||
}
|
||||
|
||||
// DeleteStatusChan deletes any status channel associated with the offer ID.
|
||||
func (sm *statusManager) DeleteStatusChan(offerID types.Hash) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
delete(sm.statusChannels, offerID)
|
||||
}
|
||||
|
||||
// PushNewStatus adds a new status to the offer ID's channel
|
||||
func (sm *statusManager) PushNewStatus(offerID types.Hash, status types.Status) {
|
||||
ch := sm.getStatusChan(offerID)
|
||||
ch <- status
|
||||
// If the status is not ongoing, existing subscribers will get the status
|
||||
// via the channel since they already have a reference to it. New
|
||||
// subscribers will get the final status from the past swaps map.
|
||||
if !status.IsOngoing() {
|
||||
// We grabbed the status channel before calling IsOngoing to avoid a
|
||||
// race condition where the status becomes complete after the check, but
|
||||
// before we grab a reference to the channel. If the status was complete
|
||||
// before we grabbed the channel, we created a new channel, which we
|
||||
// remove below.
|
||||
sm.DeleteStatusChan(offerID)
|
||||
}
|
||||
}
|
||||
|
||||
// newStatusChannel creates a status channel using the the correct size
|
||||
func newStatusChannel() chan Status {
|
||||
// The channel size should be large enough to handle the max number of
|
||||
// stages a swap can potentially go through.
|
||||
const statusChSize = 6
|
||||
ch := make(chan Status, statusChSize)
|
||||
return ch
|
||||
}
|
||||
26
protocol/swap/status_manager_test.go
Normal file
26
protocol/swap/status_manager_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package swap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
)
|
||||
|
||||
func TestStatusManager(t *testing.T) {
|
||||
offerID1 := types.Hash{0x1}
|
||||
|
||||
statusMgr := newStatusManager()
|
||||
ch1 := statusMgr.GetStatusChan(offerID1)
|
||||
ch2 := statusMgr.GetStatusChan(offerID1)
|
||||
require.Equal(t, ch1, ch2)
|
||||
|
||||
statusMgr.PushNewStatus(offerID1, types.CompletedSuccess)
|
||||
status := <-ch1
|
||||
require.Equal(t, types.CompletedSuccess, status)
|
||||
|
||||
statusMgr.DeleteStatusChan(offerID1)
|
||||
ch3 := statusMgr.GetStatusChan(offerID1)
|
||||
require.NotEqual(t, ch1, ch3)
|
||||
}
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/ChainSafe/chaindb"
|
||||
)
|
||||
|
||||
var errNoSwapWithID = errors.New("unable to find swap with given ID")
|
||||
var errNoSwapWithOfferID = errors.New("unable to find swap with given offer ID")
|
||||
|
||||
// Manager tracks current and past swaps.
|
||||
type Manager interface {
|
||||
@@ -23,10 +23,15 @@ type Manager interface {
|
||||
WriteSwapToDB(info *Info) error
|
||||
GetPastIDs() ([]types.Hash, error)
|
||||
GetPastSwap(types.Hash) (*Info, error)
|
||||
GetOngoingSwap(types.Hash) (Info, error)
|
||||
GetOngoingSwaps() ([]*Info, error)
|
||||
GetOngoingSwap(hash types.Hash) (*Info, error)
|
||||
GetOngoingSwapSnapshot(types.Hash) (*Info, error)
|
||||
GetOngoingSwapOfferIDs() ([]*types.Hash, error)
|
||||
GetOngoingSwapsSnapshot() ([]*Info, error)
|
||||
CompleteOngoingSwap(info *Info) error
|
||||
HasOngoingSwap(types.Hash) bool
|
||||
GetStatusChan(offerID types.Hash) <-chan types.Status
|
||||
DeleteStatusChan(offerID types.Hash)
|
||||
PushNewStatus(offerID types.Hash, status types.Status)
|
||||
}
|
||||
|
||||
// manager implements Manager.
|
||||
@@ -38,6 +43,7 @@ type manager struct {
|
||||
sync.RWMutex
|
||||
ongoing map[types.Hash]*Info
|
||||
past map[types.Hash]*Info
|
||||
*statusManager
|
||||
}
|
||||
|
||||
var _ Manager = (*manager)(nil)
|
||||
@@ -62,9 +68,10 @@ func NewManager(db Database) (Manager, error) {
|
||||
}
|
||||
|
||||
return &manager{
|
||||
db: db,
|
||||
ongoing: ongoing,
|
||||
past: make(map[types.Hash]*Info),
|
||||
db: db,
|
||||
ongoing: ongoing,
|
||||
past: make(map[types.Hash]*Info),
|
||||
statusManager: newStatusManager(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -140,30 +147,70 @@ func (m *manager) GetPastSwap(id types.Hash) (*Info, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetOngoingSwap returns the ongoing swap's *Info, if there is one.
|
||||
func (m *manager) GetOngoingSwap(id types.Hash) (Info, error) {
|
||||
// GetOngoingSwap returns the ongoing swap's *Info, if there is one. The
|
||||
// returned Info structure of an active swap can be modified as the swap's state
|
||||
// changes and should only be read or written by a single go process.
|
||||
func (m *manager) GetOngoingSwap(offerID types.Hash) (*Info, error) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
s, has := m.ongoing[id]
|
||||
|
||||
s, has := m.ongoing[offerID]
|
||||
if !has {
|
||||
return Info{}, errNoSwapWithID
|
||||
return nil, errNoSwapWithOfferID
|
||||
}
|
||||
|
||||
return *s, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetOngoingSwaps returns all ongoing swaps.
|
||||
func (m *manager) GetOngoingSwaps() ([]*Info, error) {
|
||||
// GetOngoingSwapSnapshot returns a copy of the ongoing swap's Info, if the
|
||||
// offerID has an ongoing swap.
|
||||
func (m *manager) GetOngoingSwapSnapshot(offerID types.Hash) (*Info, error) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
swaps := make([]*Info, len(m.ongoing))
|
||||
i := 0
|
||||
for _, s := range m.ongoing {
|
||||
sCopy := new(Info)
|
||||
*sCopy = *s
|
||||
swaps[i] = sCopy
|
||||
i++
|
||||
|
||||
s, has := m.ongoing[offerID]
|
||||
if !has {
|
||||
return nil, errNoSwapWithOfferID
|
||||
}
|
||||
|
||||
sc, err := s.DeepCopy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sc, nil
|
||||
}
|
||||
|
||||
// GetOngoingSwapOfferIDs returns a list of the offer IDs of all ongoing
|
||||
// swaps.
|
||||
func (m *manager) GetOngoingSwapOfferIDs() ([]*types.Hash, error) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
offerIDs := make([]*types.Hash, 0, len(m.ongoing))
|
||||
for _, s := range m.ongoing {
|
||||
offerIDs = append(offerIDs, &s.OfferID)
|
||||
}
|
||||
|
||||
return offerIDs, nil
|
||||
}
|
||||
|
||||
// GetOngoingSwapsSnapshot returns a copy of all ongoing swaps. If you need to
|
||||
// modify the result, call `GetOngoingSwap` on the offerID to get the "live"
|
||||
// Info object.
|
||||
func (m *manager) GetOngoingSwapsSnapshot() ([]*Info, error) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
swaps := make([]*Info, 0, len(m.ongoing))
|
||||
for _, s := range m.ongoing {
|
||||
sc, err := s.DeepCopy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
swaps = append(swaps, sc)
|
||||
}
|
||||
|
||||
return swaps, nil
|
||||
}
|
||||
|
||||
@@ -171,9 +218,10 @@ func (m *manager) GetOngoingSwaps() ([]*Info, error) {
|
||||
func (m *manager) CompleteOngoingSwap(info *Info) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
_, has := m.ongoing[info.OfferID]
|
||||
if !has {
|
||||
return errNoSwapWithID
|
||||
return errNoSwapWithOfferID
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
@@ -190,6 +238,7 @@ func (m *manager) CompleteOngoingSwap(info *Info) error {
|
||||
func (m *manager) HasOngoingSwap(id types.Hash) bool {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
|
||||
_, has := m.ongoing[id]
|
||||
return has
|
||||
}
|
||||
@@ -197,7 +246,7 @@ func (m *manager) HasOngoingSwap(id types.Hash) bool {
|
||||
func (m *manager) getSwapFromDB(id types.Hash) (*Info, error) {
|
||||
s, err := m.db.GetSwap(id)
|
||||
if errors.Is(chaindb.ErrKeyNotFound, err) {
|
||||
return nil, errNoSwapWithID
|
||||
return nil, errNoSwapWithOfferID
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -38,7 +38,6 @@ func TestNewManager(t *testing.T) {
|
||||
types.EthAssetETH,
|
||||
types.ExpectingKeys,
|
||||
100,
|
||||
nil,
|
||||
)
|
||||
db.EXPECT().PutSwap(infoA)
|
||||
err = m.AddSwap(infoA)
|
||||
@@ -54,7 +53,6 @@ func TestNewManager(t *testing.T) {
|
||||
types.EthAssetETH,
|
||||
types.CompletedSuccess,
|
||||
100,
|
||||
nil,
|
||||
)
|
||||
db.EXPECT().PutSwap(infoB)
|
||||
err = m.AddSwap(infoB)
|
||||
@@ -88,7 +86,6 @@ func TestManager_AddSwap_Ongoing(t *testing.T) {
|
||||
types.EthAssetETH,
|
||||
types.ExpectingKeys,
|
||||
100,
|
||||
nil,
|
||||
)
|
||||
|
||||
db.EXPECT().PutSwap(info)
|
||||
@@ -100,7 +97,7 @@ func TestManager_AddSwap_Ongoing(t *testing.T) {
|
||||
|
||||
s, err := m.GetOngoingSwap(types.Hash{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info, &s)
|
||||
require.Equal(t, info, s)
|
||||
require.NotNil(t, m.ongoing)
|
||||
|
||||
db.EXPECT().PutSwap(info)
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
@@ -18,8 +19,6 @@ import (
|
||||
"github.com/athanorlabs/atomic-swap/common/vjson"
|
||||
)
|
||||
|
||||
const statusChSize = 6 // the max number of stages a swap can potentially go through
|
||||
|
||||
var (
|
||||
// CurInfoVersion is the latest supported version of a serialised Info struct
|
||||
CurInfoVersion, _ = semver.NewVersion("0.3.0")
|
||||
@@ -63,8 +62,18 @@ type Info struct {
|
||||
// (and after Timeout1), the ETH-taker is able to claim, but
|
||||
// after this timeout, the ETH-taker can no longer claim, only
|
||||
// the ETH-maker can refund.
|
||||
Timeout2 *time.Time `json:"timeout2,omitempty"`
|
||||
statusCh chan types.Status `json:"-"`
|
||||
Timeout2 *time.Time `json:"timeout2,omitempty"`
|
||||
|
||||
// rwMu handles synchronization when LastStatusUpdateTime, Timeout1,
|
||||
// Timeout2 and EndTime are updated. This Info struct is modified by the
|
||||
// maker or taker's swapState go process as the state of the swap
|
||||
// progresses. The swapState go process does not need synchronization when
|
||||
// reading its own changes, but it needs to grab a write lock when modifying
|
||||
// the the structure. Readers from other go-processes only get copies of
|
||||
// this structure. They exclusively use the DeepCopy method to get their
|
||||
// copy, which grabs the read lock ensuring that they always capture the
|
||||
// up-to-date state of this Info struct.
|
||||
rwMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInfo creates a new *Info from the given parameters.
|
||||
@@ -78,7 +87,6 @@ func NewInfo(
|
||||
ethAsset types.EthAsset,
|
||||
status Status,
|
||||
moneroStartHeight uint64,
|
||||
statusCh chan types.Status,
|
||||
) *Info {
|
||||
info := &Info{
|
||||
Version: CurInfoVersion,
|
||||
@@ -92,26 +100,42 @@ func NewInfo(
|
||||
Status: status,
|
||||
LastStatusUpdateTime: time.Now(),
|
||||
MoneroStartHeight: moneroStartHeight,
|
||||
statusCh: statusCh,
|
||||
StartTime: time.Now(),
|
||||
EndTime: nil,
|
||||
Timeout1: nil,
|
||||
Timeout2: nil,
|
||||
rwMu: sync.RWMutex{},
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// StatusCh returns the swap's status update channel.
|
||||
func (i *Info) StatusCh() <-chan types.Status {
|
||||
return i.statusCh
|
||||
}
|
||||
|
||||
// SetStatus ...
|
||||
// SetStatus updates the status and status modification timestamp
|
||||
func (i *Info) SetStatus(s Status) {
|
||||
i.rwMu.Lock()
|
||||
defer i.rwMu.Unlock()
|
||||
|
||||
i.Status = s
|
||||
i.LastStatusUpdateTime = time.Now()
|
||||
if i.statusCh == nil {
|
||||
// this case only happens in tests.
|
||||
return
|
||||
}
|
||||
i.statusCh <- s
|
||||
}
|
||||
|
||||
// SetTimeouts sets the 2 timeout fields, , grabbing the needed lock before
|
||||
// modifying fields.
|
||||
func (i *Info) SetTimeouts(t1 *time.Time, t2 *time.Time) {
|
||||
i.rwMu.Lock()
|
||||
defer i.rwMu.Unlock()
|
||||
|
||||
i.Timeout1 = t1
|
||||
i.Timeout2 = t2
|
||||
}
|
||||
|
||||
// MarkSwapComplete sets the EndTime field to the current wall time, grabbing the
|
||||
// needed lock before modifying fields.
|
||||
func (i *Info) MarkSwapComplete() {
|
||||
i.rwMu.Lock()
|
||||
defer i.rwMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
i.EndTime = &now
|
||||
}
|
||||
|
||||
// IsTaker returns true if the node is the xmr-taker in the swap.
|
||||
@@ -119,32 +143,66 @@ func (i *Info) IsTaker() bool {
|
||||
return i.Provides == coins.ProvidesETH
|
||||
}
|
||||
|
||||
// UnmarshalInfo deserializes a JSON Info struct, checking the version for compatibility
|
||||
// before attempting to deserialize the whole blob.
|
||||
// UnmarshalInfo unmarshalls the passed JSON into a freshly created Info object.
|
||||
func UnmarshalInfo(jsonData []byte) (*Info, error) {
|
||||
ov := struct {
|
||||
info := new(Info)
|
||||
if err := json.Unmarshal(jsonData, info); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON deserializes a JSON Info struct, checking the version for
|
||||
// compatibility.
|
||||
func (i *Info) UnmarshalJSON(jsonData []byte) error {
|
||||
iv := struct {
|
||||
Version *semver.Version `json:"version"`
|
||||
}{}
|
||||
if err := json.Unmarshal(jsonData, &ov); err != nil {
|
||||
return nil, err
|
||||
if err := json.Unmarshal(jsonData, &iv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ov.Version == nil {
|
||||
return nil, errInfoVersionMissing
|
||||
if iv.Version == nil {
|
||||
return errInfoVersionMissing
|
||||
}
|
||||
|
||||
if ov.Version.GreaterThan(CurInfoVersion) {
|
||||
return nil, fmt.Errorf("info version %q not supported, latest is %q", ov.Version, CurInfoVersion)
|
||||
if iv.Version.GreaterThan(CurInfoVersion) {
|
||||
return fmt.Errorf("info version %q not supported, latest is %q", iv.Version, CurInfoVersion)
|
||||
}
|
||||
|
||||
info := new(Info)
|
||||
if err := vjson.UnmarshalStruct(jsonData, info); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Assuming any version less than the current version is forwards
|
||||
// compatible. If that is not the case in the future, add code here to
|
||||
// upgrade the older version to the current version when deserializing.
|
||||
// (Or error if it is completely incompatible.)
|
||||
|
||||
info.statusCh = make(chan types.Status, statusChSize)
|
||||
// Unmarshal without recursion
|
||||
type _Info Info
|
||||
if err := vjson.UnmarshalStruct(jsonData, (*_Info)(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Are there additional sanity checks we can perform on the Provided and Received amounts
|
||||
// (or other fields) here when decoding the JSON?
|
||||
return info, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeepCopy returns a deep copy of the Info data structure
|
||||
func (i *Info) DeepCopy() (*Info, error) {
|
||||
i.rwMu.RLock()
|
||||
defer i.rwMu.RUnlock()
|
||||
|
||||
// This is not the most efficient means of getting a deep copy, but for our
|
||||
// needs it is fast enough and least prone to human error, as the structure
|
||||
// has numerous nested pointer types.
|
||||
jsonData, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clone := new(Info)
|
||||
if err = clone.UnmarshalJSON(jsonData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clone, nil
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ func Test_InfoMarshal(t *testing.T) {
|
||||
types.EthAssetETH,
|
||||
types.CompletedSuccess,
|
||||
200,
|
||||
make(chan types.Status),
|
||||
)
|
||||
err := info.StartTime.UnmarshalJSON([]byte("\"2023-02-20T17:29:43.471020297-05:00\""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -134,6 +134,5 @@ func (s *swapState) checkAndSetTimeouts(t1, t2 *big.Int) error {
|
||||
func (s *swapState) setTimeouts(t1, t2 *big.Int) {
|
||||
s.t1 = time.Unix(t1.Int64(), 0)
|
||||
s.t2 = time.Unix(t2.Int64(), 0)
|
||||
s.info.Timeout1 = &s.t1
|
||||
s.info.Timeout2 = &s.t2
|
||||
s.info.SetTimeouts(&s.t1, &s.t2)
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestSwapState_handleEvent_EventContractReady(t *testing.T) {
|
||||
|
||||
// runContractEventWatcher will trigger EventContractReady,
|
||||
// which will then set the next expected event to EventExit.
|
||||
for status := range s.info.StatusCh() {
|
||||
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
|
||||
if !status.IsOngoing() {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -84,12 +84,17 @@ func NewInstance(cfg *Config) (*Instance, error) {
|
||||
}
|
||||
|
||||
func (inst *Instance) checkForOngoingSwaps() error {
|
||||
swaps, err := inst.backend.SwapManager().GetOngoingSwaps()
|
||||
ongoingIDs, err := inst.backend.SwapManager().GetOngoingSwapOfferIDs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, s := range swaps {
|
||||
for _, offerID := range ongoingIDs {
|
||||
s, err := inst.backend.SwapManager().GetOngoingSwap(*offerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.Provides != coins.ProvidesXMR {
|
||||
continue
|
||||
}
|
||||
@@ -178,7 +183,7 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
|
||||
if err != nil {
|
||||
// we can ignore the error; if the key doesn't exist,
|
||||
// then no relayer was set for this swap.
|
||||
relayerInfo = &types.OfferExtra{}
|
||||
relayerInfo = types.NewOfferExtra(false)
|
||||
}
|
||||
|
||||
ss, err := newSwapStateFromOngoing(
|
||||
|
||||
@@ -52,7 +52,7 @@ func (s *swapState) HandleProtocolMessage(msg common.Message) error {
|
||||
|
||||
func (s *swapState) clearNextExpectedEvent(status types.Status) {
|
||||
s.nextExpectedEvent = EventNoneType
|
||||
s.info.SetStatus(status)
|
||||
s.updateStatus(status)
|
||||
}
|
||||
|
||||
func (s *swapState) setNextExpectedEvent(event EventType) error {
|
||||
@@ -72,7 +72,7 @@ func (s *swapState) setNextExpectedEvent(event EventType) error {
|
||||
panic("status corresponding to event cannot be UnknownStatus")
|
||||
}
|
||||
|
||||
s.info.SetStatus(status)
|
||||
s.updateStatus(status)
|
||||
err := s.Backend.SwapManager().WriteSwapToDB(s.info)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -10,12 +10,10 @@ import (
|
||||
|
||||
"github.com/ChainSafe/chaindb"
|
||||
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
|
||||
logging "github.com/ipfs/go-log"
|
||||
)
|
||||
|
||||
const statusChSize = 6 // the max number of stages a swap can potentially go through
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
)
|
||||
|
||||
var (
|
||||
log = logging.Logger("offers")
|
||||
@@ -49,13 +47,9 @@ func NewManager(dataDir string, db Database) (*Manager, error) {
|
||||
offers := make(map[types.Hash]*offerWithExtra)
|
||||
|
||||
for _, offer := range savedOffers {
|
||||
extra := &types.OfferExtra{
|
||||
StatusCh: make(chan types.Status, statusChSize),
|
||||
}
|
||||
|
||||
offers[offer.ID] = &offerWithExtra{
|
||||
offer: offer,
|
||||
extra: extra,
|
||||
extra: types.NewOfferExtra(false),
|
||||
}
|
||||
|
||||
log.Infof("loaded offer %s from database", offer.ID)
|
||||
@@ -83,10 +77,7 @@ func (m *Manager) GetOffer(id types.Hash) (*types.Offer, *types.OfferExtra, erro
|
||||
}
|
||||
|
||||
// AddOffer adds a new offer to the manager and returns its OffersExtra data
|
||||
func (m *Manager) AddOffer(
|
||||
offer *types.Offer,
|
||||
useRelayer bool,
|
||||
) (*types.OfferExtra, error) {
|
||||
func (m *Manager) AddOffer(offer *types.Offer, useRelayer bool) (*types.OfferExtra, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -101,10 +92,7 @@ func (m *Manager) AddOffer(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
extra := &types.OfferExtra{
|
||||
StatusCh: make(chan types.Status, statusChSize),
|
||||
UseRelayer: useRelayer,
|
||||
}
|
||||
extra := types.NewOfferExtra(useRelayer)
|
||||
|
||||
m.offers[id] = &offerWithExtra{
|
||||
offer: offer,
|
||||
|
||||
@@ -112,9 +112,6 @@ func newSwapStateFromStart(
|
||||
// and we'll send our own after this function returns.
|
||||
// see HandleInitiateMessage().
|
||||
stage := types.KeysExchanged
|
||||
if offerExtra.StatusCh == nil {
|
||||
offerExtra.StatusCh = make(chan types.Status, 7)
|
||||
}
|
||||
|
||||
if offerExtra.UseRelayer {
|
||||
if err := b.RecoveryDB().PutSwapRelayerInfo(offer.ID, offerExtra); err != nil {
|
||||
@@ -146,7 +143,6 @@ func newSwapStateFromStart(
|
||||
offer.EthAsset,
|
||||
stage,
|
||||
moneroStartHeight,
|
||||
offerExtra.StatusCh,
|
||||
)
|
||||
|
||||
if err = b.SwapManager().AddSwap(info); err != nil {
|
||||
@@ -171,7 +167,8 @@ func newSwapStateFromStart(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offerExtra.StatusCh <- stage
|
||||
s.SwapManager().PushNewStatus(offer.ID, stage)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -247,6 +244,7 @@ func checkIfAlreadyClaimed(
|
||||
func completeSwap(info *swap.Info, b backend.Backend, om *offers.Manager) error {
|
||||
// set swap to completed
|
||||
info.SetStatus(types.CompletedSuccess)
|
||||
b.SwapManager().PushNewStatus(info.OfferID, types.CompletedSuccess)
|
||||
err := b.SwapManager().CompleteOngoingSwap(info)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark swap %s as completed: %s", info.OfferID, err)
|
||||
@@ -452,6 +450,11 @@ func (s *swapState) SendKeysMessage() common.Message {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *swapState) updateStatus(status types.Status) {
|
||||
s.info.SetStatus(status)
|
||||
s.SwapManager().PushNewStatus(s.OfferID(), status)
|
||||
}
|
||||
|
||||
// ExpectedAmount returns the amount received, or expected to be received, at the end of the swap
|
||||
func (s *swapState) ExpectedAmount() *apd.Decimal {
|
||||
return s.info.ExpectedAmount
|
||||
|
||||
@@ -44,7 +44,7 @@ func newTestSwapStateAndDB(t *testing.T) (*Instance, *swapState, *offers.MockDat
|
||||
xmrmaker.backend,
|
||||
testPeerID,
|
||||
types.NewOffer("", new(apd.Decimal), new(apd.Decimal), new(coins.ExchangeRate), types.EthAssetETH),
|
||||
&types.OfferExtra{},
|
||||
types.NewOfferExtra(false),
|
||||
xmrmaker.offerManager,
|
||||
coins.MoneroToPiconero(coins.StrToDecimal("0.05")),
|
||||
desiredAmount,
|
||||
@@ -270,7 +270,7 @@ func TestSwapState_HandleProtocolMessage_NotifyETHLocked_timeout(t *testing.T) {
|
||||
|
||||
go s.runT1ExpirationHandler()
|
||||
|
||||
for status := range s.info.StatusCh() {
|
||||
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
|
||||
if status == types.CompletedSuccess {
|
||||
break
|
||||
} else if !status.IsOngoing() {
|
||||
@@ -319,7 +319,8 @@ func TestSwapState_handleRefund(t *testing.T) {
|
||||
|
||||
// runContractEventWatcher will trigger EventETHRefunded,
|
||||
// which will then set the next expected event to EventExit.
|
||||
for status := range s.info.StatusCh() {
|
||||
statusCh := s.SwapManager().GetStatusChan(s.info.OfferID)
|
||||
for status := range statusCh {
|
||||
if !status.IsOngoing() {
|
||||
break
|
||||
}
|
||||
@@ -376,7 +377,7 @@ func TestSwapState_Exit_Reclaim(t *testing.T) {
|
||||
|
||||
// runContractEventWatcher will trigger EventETHRefunded,
|
||||
// which will then set the next expected event to EventExit.
|
||||
for status := range s.info.StatusCh() {
|
||||
for status := range s.SwapManager().GetStatusChan(s.info.OfferID) {
|
||||
if !status.IsOngoing() {
|
||||
require.Equal(t, types.CompletedRefund.String(), status.String())
|
||||
break
|
||||
@@ -418,7 +419,7 @@ func TestSwapState_Exit_Success(t *testing.T) {
|
||||
max := coins.StrToDecimal("0.2")
|
||||
rate := coins.ToExchangeRate(coins.StrToDecimal("0.1"))
|
||||
s.offer = types.NewOffer(coins.ProvidesXMR, min, max, rate, types.EthAssetETH)
|
||||
s.info.SetStatus(types.CompletedSuccess)
|
||||
s.updateStatus(types.CompletedSuccess)
|
||||
err := s.Exit()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -441,7 +442,7 @@ func TestSwapState_Exit_Refunded(t *testing.T) {
|
||||
_, err := b.MakeOffer(s.offer, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.info.SetStatus(types.CompletedRefund)
|
||||
s.updateStatus(types.CompletedRefund)
|
||||
err = s.Exit()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -64,12 +64,17 @@ func NewInstance(cfg *Config) (*Instance, error) {
|
||||
}
|
||||
|
||||
func (inst *Instance) checkForOngoingSwaps() error {
|
||||
swaps, err := inst.backend.SwapManager().GetOngoingSwaps()
|
||||
ongoingIDs, err := inst.backend.SwapManager().GetOngoingSwapOfferIDs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, s := range swaps {
|
||||
for _, offerID := range ongoingIDs {
|
||||
s, err := inst.backend.SwapManager().GetOngoingSwap(*offerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.Provides != coins.ProvidesETH {
|
||||
continue
|
||||
}
|
||||
@@ -171,7 +176,7 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
|
||||
}
|
||||
|
||||
// completeSwap is called in the case where we find an ongoing swap in the db on startup,
|
||||
// and the swap already has the counterpary's swap secret stored.
|
||||
// and the swap already has the counterparty's swap secret stored.
|
||||
// In this case, we simply claim the XMR, as we have both secrets required.
|
||||
// It's unlikely for this case to ever be hit, unless the daemon was shut down in-between
|
||||
// us finding the counterparty's secret and claiming the XMR.
|
||||
|
||||
@@ -40,7 +40,7 @@ func (s *swapState) HandleProtocolMessage(msg common.Message) error {
|
||||
|
||||
func (s *swapState) clearNextExpectedEvent(status types.Status) {
|
||||
s.nextExpectedEvent = EventNoneType
|
||||
s.info.SetStatus(status)
|
||||
s.updateStatus(status)
|
||||
}
|
||||
|
||||
func (s *swapState) setNextExpectedEvent(event EventType) error {
|
||||
@@ -61,7 +61,7 @@ func (s *swapState) setNextExpectedEvent(event EventType) error {
|
||||
}
|
||||
|
||||
log.Debugf("setting status to %s", status)
|
||||
s.info.SetStatus(status)
|
||||
s.updateStatus(status)
|
||||
return s.Backend.SwapManager().WriteSwapToDB(s.info)
|
||||
}
|
||||
|
||||
|
||||
@@ -104,7 +104,6 @@ func newSwapStateFromStart(
|
||||
ethAsset types.EthAsset,
|
||||
) (*swapState, error) {
|
||||
stage := types.ExpectingKeys
|
||||
statusCh := make(chan types.Status, 16)
|
||||
|
||||
moneroStartNumber, err := b.XMRClient().GetHeight()
|
||||
if err != nil {
|
||||
@@ -136,7 +135,6 @@ func newSwapStateFromStart(
|
||||
ethAsset,
|
||||
stage,
|
||||
moneroStartNumber,
|
||||
statusCh,
|
||||
)
|
||||
if err = b.SwapManager().AddSwap(info); err != nil {
|
||||
return nil, err
|
||||
@@ -157,7 +155,8 @@ func newSwapStateFromStart(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
statusCh <- stage
|
||||
s.SwapManager().PushNewStatus(offerID, stage)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -318,6 +317,11 @@ func (s *swapState) SendKeysMessage() common.Message {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *swapState) updateStatus(status types.Status) {
|
||||
s.info.SetStatus(status)
|
||||
s.SwapManager().PushNewStatus(s.OfferID(), status)
|
||||
}
|
||||
|
||||
// ExpectedAmount returns the amount received, or expected to be received, at the end of the swap
|
||||
func (s *swapState) ExpectedAmount() *apd.Decimal {
|
||||
return s.info.ExpectedAmount
|
||||
|
||||
@@ -48,7 +48,7 @@ func setupSwapStateUntilETHLocked(t *testing.T) (*swapState, uint64) {
|
||||
// shutdown swap state, re-create from ongoing
|
||||
s.cancel()
|
||||
|
||||
rdb.EXPECT().GetCounterpartySwapKeys(s.info.OfferID).Return(
|
||||
rdb.EXPECT().GetCounterpartySwapKeys(s.OfferID()).Return(
|
||||
makerKeys.PublicKeyPair.SpendKey(),
|
||||
makerKeys.PrivateKeyPair.ViewKey(),
|
||||
nil,
|
||||
|
||||
@@ -252,7 +252,7 @@ func TestSwapState_HandleProtocolMessage_SendKeysMessage_Refund(t *testing.T) {
|
||||
require.Equal(t, xmrmakerKeysAndProof.PrivateKeyPair.ViewKey().String(), s.xmrmakerPrivateViewKey.String())
|
||||
|
||||
// ensure we refund before t1
|
||||
for status := range s.info.StatusCh() {
|
||||
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
|
||||
if status == types.CompletedRefund {
|
||||
// check this is before t1
|
||||
// TODO: remove the 10-second buffer, this is needed for now
|
||||
@@ -346,7 +346,7 @@ func TestSwapState_NotifyXMRLock_Refund(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, EventETHClaimedType, s.nextExpectedEvent)
|
||||
|
||||
for status := range s.info.StatusCh() {
|
||||
for status := range s.SwapManager().GetStatusChan(s.OfferID()) {
|
||||
if status == types.CompletedRefund {
|
||||
// check this is after t2
|
||||
require.Less(t, s.t2, time.Now())
|
||||
@@ -368,7 +368,7 @@ func TestExit_afterSendKeysMessage(t *testing.T) {
|
||||
s.nextExpectedEvent = EventKeysReceivedType
|
||||
err := s.Exit()
|
||||
require.NoError(t, err)
|
||||
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
|
||||
info, err := s.SwapManager().GetPastSwap(s.OfferID())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, types.CompletedAbort, info.Status)
|
||||
}
|
||||
@@ -395,7 +395,7 @@ func TestExit_afterNotifyXMRLock(t *testing.T) {
|
||||
err = s.Exit()
|
||||
require.NoError(t, err)
|
||||
|
||||
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
|
||||
info, err := s.SwapManager().GetPastSwap(s.OfferID())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, types.CompletedRefund, info.Status)
|
||||
}
|
||||
@@ -422,7 +422,7 @@ func TestExit_afterNotifyClaimed(t *testing.T) {
|
||||
err = s.Exit()
|
||||
require.NoError(t, err)
|
||||
|
||||
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
|
||||
info, err := s.SwapManager().GetPastSwap(s.OfferID())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, types.CompletedRefund, info.Status)
|
||||
}
|
||||
@@ -450,7 +450,7 @@ func TestExit_invalidNextMessageType(t *testing.T) {
|
||||
err = s.Exit()
|
||||
require.True(t, errors.Is(err, errUnexpectedEventType))
|
||||
|
||||
info, err := s.SwapManager().GetPastSwap(s.info.OfferID)
|
||||
info, err := s.SwapManager().GetPastSwap(s.OfferID())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, types.CompletedAbort, info.Status)
|
||||
}
|
||||
|
||||
@@ -5,19 +5,23 @@ package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ChainSafe/chaindb"
|
||||
"github.com/MarinX/monerorpc/wallet"
|
||||
"github.com/cockroachdb/apd/v3"
|
||||
ethcommon "github.com/ethereum/go-ethereum/common"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
libp2ptest "github.com/libp2p/go-libp2p/core/test"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/athanorlabs/atomic-swap/coins"
|
||||
"github.com/athanorlabs/atomic-swap/common"
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
mcrypto "github.com/athanorlabs/atomic-swap/crypto/monero"
|
||||
"github.com/athanorlabs/atomic-swap/db"
|
||||
"github.com/athanorlabs/atomic-swap/ethereum/extethclient"
|
||||
"github.com/athanorlabs/atomic-swap/net/message"
|
||||
"github.com/athanorlabs/atomic-swap/protocol/swap"
|
||||
@@ -67,32 +71,21 @@ func (*mockNet) CloseProtocolStream(_ types.Hash) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
type mockSwapManager struct{}
|
||||
func mockSwapManager(t *testing.T) swap.Manager {
|
||||
db, err := db.NewDatabase(&chaindb.Config{
|
||||
DataDir: t.TempDir(),
|
||||
InMemory: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
func (*mockSwapManager) WriteSwapToDB(_ *swap.Info) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*mockSwapManager) GetPastIDs() ([]types.Hash, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*mockSwapManager) GetPastSwap(_ types.Hash) (*swap.Info, error) {
|
||||
return &swap.Info{}, nil
|
||||
}
|
||||
|
||||
func (*mockSwapManager) GetOngoingSwaps() ([]*swap.Info, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*mockSwapManager) GetOngoingSwap(id types.Hash) (swap.Info, error) {
|
||||
statusCh := make(chan types.Status, 1)
|
||||
statusCh <- types.CompletedSuccess
|
||||
sm, err := swap.NewManager(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
one := apd.New(1, 0)
|
||||
return *swap.NewInfo(
|
||||
|
||||
sm.AddSwap(swap.NewInfo(
|
||||
testPeerID,
|
||||
id,
|
||||
testSwapID,
|
||||
coins.ProvidesETH,
|
||||
one,
|
||||
one,
|
||||
@@ -100,20 +93,11 @@ func (*mockSwapManager) GetOngoingSwap(id types.Hash) (swap.Info, error) {
|
||||
types.EthAssetETH,
|
||||
types.CompletedSuccess,
|
||||
1,
|
||||
statusCh,
|
||||
), nil
|
||||
}
|
||||
))
|
||||
|
||||
func (*mockSwapManager) AddSwap(_ *swap.Info) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
sm.PushNewStatus(testSwapID, types.CompletedSuccess)
|
||||
|
||||
func (*mockSwapManager) CompleteOngoingSwap(_ *swap.Info) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*mockSwapManager) HasOngoingSwap(_ types.Hash) bool {
|
||||
panic("not implemented")
|
||||
return sm
|
||||
}
|
||||
|
||||
type mockXMRTaker struct{}
|
||||
@@ -153,10 +137,7 @@ func (m *mockXMRMaker) GetOngoingSwapState(_ types.Hash) common.SwapState {
|
||||
}
|
||||
|
||||
func (*mockXMRMaker) MakeOffer(_ *types.Offer, _ bool) (*types.OfferExtra, error) {
|
||||
offerExtra := &types.OfferExtra{
|
||||
StatusCh: make(chan types.Status, 1),
|
||||
}
|
||||
offerExtra.StatusCh <- types.CompletedSuccess
|
||||
offerExtra := types.NewOfferExtra(false)
|
||||
return offerExtra, nil
|
||||
}
|
||||
|
||||
@@ -193,12 +174,12 @@ func (*mockSwapState) OfferID() types.Hash {
|
||||
}
|
||||
|
||||
type mockProtocolBackend struct {
|
||||
sm *mockSwapManager
|
||||
sm swap.Manager
|
||||
}
|
||||
|
||||
func newMockProtocolBackend() *mockProtocolBackend {
|
||||
func newMockProtocolBackend(t *testing.T) *mockProtocolBackend {
|
||||
return &mockProtocolBackend{
|
||||
sm: new(mockSwapManager),
|
||||
sm: mockSwapManager(t),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
39
rpc/net.go
39
rpc/net.go
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/athanorlabs/atomic-swap/common/rpctypes"
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
"github.com/athanorlabs/atomic-swap/net/message"
|
||||
"github.com/athanorlabs/atomic-swap/protocol/swap"
|
||||
)
|
||||
|
||||
const defaultSearchTime = time.Second * 12
|
||||
@@ -37,12 +38,12 @@ type NetService struct {
|
||||
net Net
|
||||
xmrtaker XMRTaker
|
||||
xmrmaker XMRMaker
|
||||
sm SwapManager
|
||||
sm swap.Manager
|
||||
isBootnode bool
|
||||
}
|
||||
|
||||
// NewNetService ...
|
||||
func NewNetService(net Net, xmrtaker XMRTaker, xmrmaker XMRMaker, sm SwapManager, isBootnode bool) *NetService {
|
||||
func NewNetService(net Net, xmrtaker XMRTaker, xmrmaker XMRMaker, sm swap.Manager, isBootnode bool) *NetService {
|
||||
return &NetService{
|
||||
net: net,
|
||||
xmrtaker: xmrtaker,
|
||||
@@ -160,7 +161,7 @@ func (s *NetService) TakeOffer(
|
||||
return errUnsupportedForBootnode
|
||||
}
|
||||
|
||||
_, err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount)
|
||||
err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -168,13 +169,10 @@ func (s *NetService) TakeOffer(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, providesAmount *apd.Decimal) (
|
||||
<-chan types.Status,
|
||||
error,
|
||||
) {
|
||||
func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, providesAmount *apd.Decimal) error {
|
||||
queryResp, err := s.net.Query(makerPeerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
var offer *types.Offer
|
||||
@@ -185,12 +183,12 @@ func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, provides
|
||||
}
|
||||
}
|
||||
if offer == nil {
|
||||
return nil, errNoOfferWithID
|
||||
return errNoOfferWithID
|
||||
}
|
||||
|
||||
swapState, err := s.xmrtaker.InitiateProtocol(makerPeerID, providesAmount, offer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initiate protocol: %w", err)
|
||||
return fmt.Errorf("failed to initiate protocol: %w", err)
|
||||
}
|
||||
|
||||
skm := swapState.SendKeysMessage().(*message.SendKeysMessage)
|
||||
@@ -201,15 +199,10 @@ func (s *NetService) takeOffer(makerPeerID peer.ID, offerID types.Hash, provides
|
||||
if err = swapState.Exit(); err != nil {
|
||||
log.Warnf("Swap exit failure: %s", err)
|
||||
}
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := s.sm.GetOngoingSwap(offerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return info.StatusCh(), nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// TakeOfferSyncResponse ...
|
||||
@@ -228,7 +221,7 @@ func (s *NetService) TakeOfferSync(
|
||||
return errUnsupportedForBootnode
|
||||
}
|
||||
|
||||
if _, err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount); err != nil {
|
||||
if err := s.takeOffer(req.PeerID, req.OfferID, req.ProvidesAmount); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -263,7 +256,7 @@ func (s *NetService) MakeOffer(
|
||||
return errUnsupportedForBootnode
|
||||
}
|
||||
|
||||
offerResp, _, err := s.makeOffer(req)
|
||||
offerResp, err := s.makeOffer(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -271,7 +264,7 @@ func (s *NetService) MakeOffer(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NetService) makeOffer(req *rpctypes.MakeOfferRequest) (*rpctypes.MakeOfferResponse, *types.OfferExtra, error) {
|
||||
func (s *NetService) makeOffer(req *rpctypes.MakeOfferRequest) (*rpctypes.MakeOfferResponse, error) {
|
||||
offer := types.NewOffer(
|
||||
coins.ProvidesXMR,
|
||||
req.MinAmount,
|
||||
@@ -280,13 +273,13 @@ func (s *NetService) makeOffer(req *rpctypes.MakeOfferRequest) (*rpctypes.MakeOf
|
||||
req.EthAsset,
|
||||
)
|
||||
|
||||
offerExtra, err := s.xmrmaker.MakeOffer(offer, req.UseRelayer)
|
||||
_, err := s.xmrmaker.MakeOffer(offer, req.UseRelayer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rpctypes.MakeOfferResponse{
|
||||
PeerID: s.net.PeerID(),
|
||||
OfferID: offer.ID,
|
||||
}, offerExtra, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNet_Discover(t *testing.T) {
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
|
||||
|
||||
req := &rpctypes.DiscoverRequest{
|
||||
Provides: "",
|
||||
@@ -28,7 +28,7 @@ func TestNet_Discover(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNet_Query(t *testing.T) {
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
|
||||
|
||||
req := &rpctypes.QueryPeerRequest{
|
||||
PeerID: "12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",
|
||||
@@ -42,7 +42,7 @@ func TestNet_Query(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNet_TakeOffer(t *testing.T) {
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
|
||||
|
||||
req := &rpctypes.TakeOfferRequest{
|
||||
PeerID: "12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",
|
||||
@@ -55,7 +55,7 @@ func TestNet_TakeOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNet_TakeOfferSync(t *testing.T) {
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, new(mockSwapManager), false)
|
||||
ns := NewNetService(new(mockNet), new(mockXMRTaker), nil, mockSwapManager(t), false)
|
||||
|
||||
req := &rpctypes.TakeOfferRequest{
|
||||
PeerID: "12D3KooWDqCzbjexHEa8Rut7bzxHFpRMZyDRW1L6TGkL1KY24JH5",
|
||||
|
||||
@@ -30,7 +30,7 @@ type Metrics struct {
|
||||
|
||||
func pastSwapsMetric(
|
||||
factory promauto.Factory,
|
||||
swapManager SwapManager,
|
||||
swapManager swap.Manager,
|
||||
status swap.Status,
|
||||
statusLabel string,
|
||||
) prometheus.GaugeFunc {
|
||||
@@ -108,7 +108,7 @@ func SetupMetrics(
|
||||
Help: "The number of ongoing swaps",
|
||||
},
|
||||
func() float64 {
|
||||
swaps, err := swapManager.GetOngoingSwaps()
|
||||
swaps, err := swapManager.GetOngoingSwapsSnapshot()
|
||||
if err != nil {
|
||||
return float64(-1)
|
||||
}
|
||||
|
||||
@@ -263,6 +263,3 @@ type XMRMaker interface {
|
||||
ClearOffers([]types.Hash) error
|
||||
GetMoneroBalance() (*mcrypto.Address, *wallet.GetBalanceResponse, error)
|
||||
}
|
||||
|
||||
// SwapManager ...
|
||||
type SwapManager = swap.Manager
|
||||
|
||||
14
rpc/swap.go
14
rpc/swap.go
@@ -24,7 +24,7 @@ import (
|
||||
// SwapService handles information about ongoing or past swaps.
|
||||
type SwapService struct {
|
||||
ctx context.Context
|
||||
sm SwapManager
|
||||
sm swap.Manager
|
||||
xmrtaker XMRTaker
|
||||
xmrmaker XMRMaker
|
||||
net Net
|
||||
@@ -35,7 +35,7 @@ type SwapService struct {
|
||||
// NewSwapService ...
|
||||
func NewSwapService(
|
||||
ctx context.Context,
|
||||
sm SwapManager,
|
||||
sm swap.Manager,
|
||||
xmrtaker XMRTaker,
|
||||
xmrmaker XMRMaker,
|
||||
net Net,
|
||||
@@ -165,17 +165,17 @@ func (s *SwapService) GetOngoing(_ *http.Request, req *GetOngoingRequest, resp *
|
||||
)
|
||||
|
||||
if req.OfferID == nil {
|
||||
swaps, err = s.sm.GetOngoingSwaps()
|
||||
swaps, err = s.sm.GetOngoingSwapsSnapshot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
info, err := s.sm.GetOngoingSwap(*req.OfferID) //nolint:govet
|
||||
info, err := s.sm.GetOngoingSwapSnapshot(*req.OfferID) //nolint:govet
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
swaps = []*swap.Info{&info}
|
||||
swaps = []*swap.Info{info}
|
||||
}
|
||||
|
||||
resp.Swaps = make([]*OngoingSwap, len(swaps))
|
||||
@@ -221,7 +221,7 @@ type GetStatusResponse struct {
|
||||
|
||||
// GetStatus returns the status of the ongoing swap, if there is one.
|
||||
func (s *SwapService) GetStatus(_ *http.Request, req *GetStatusRequest, resp *GetStatusResponse) error {
|
||||
info, err := s.sm.GetOngoingSwap(req.ID)
|
||||
info, err := s.sm.GetOngoingSwapSnapshot(req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -272,7 +272,7 @@ type CancelResponse struct {
|
||||
|
||||
// Cancel attempts to cancel the currently ongoing swap, if there is one.
|
||||
func (s *SwapService) Cancel(_ *http.Request, req *CancelRequest, resp *CancelResponse) error {
|
||||
info, err := s.sm.GetOngoingSwap(req.OfferID)
|
||||
info, err := s.sm.GetOngoingSwapSnapshot(req.OfferID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get ongoing swap: %w", err)
|
||||
}
|
||||
|
||||
87
rpc/ws.go
87
rpc/ws.go
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/athanorlabs/atomic-swap/common/types"
|
||||
"github.com/athanorlabs/atomic-swap/common/vjson"
|
||||
mcrypto "github.com/athanorlabs/atomic-swap/crypto/monero"
|
||||
"github.com/athanorlabs/atomic-swap/protocol/swap"
|
||||
|
||||
ethcommon "github.com/ethereum/go-ethereum/common"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -29,13 +30,13 @@ func checkOriginFunc(_ *http.Request) bool {
|
||||
|
||||
type wsServer struct {
|
||||
ctx context.Context
|
||||
sm SwapManager
|
||||
sm swap.Manager
|
||||
ns *NetService
|
||||
backend ProtocolBackend
|
||||
taker XMRTaker
|
||||
}
|
||||
|
||||
func newWsServer(ctx context.Context, sm SwapManager, ns *NetService, backend ProtocolBackend,
|
||||
func newWsServer(ctx context.Context, sm swap.Manager, ns *NetService, backend ProtocolBackend,
|
||||
taker XMRTaker) *wsServer {
|
||||
s := &wsServer{
|
||||
ctx: ctx,
|
||||
@@ -142,12 +143,12 @@ func (s *wsServer) handleRequest(conn *websocket.Conn, req *rpctypes.Request) er
|
||||
return fmt.Errorf("failed to unmarshal parameters: %w", err)
|
||||
}
|
||||
|
||||
ch, err := s.ns.takeOffer(params.PeerID, params.OfferID, params.ProvidesAmount)
|
||||
err := s.ns.takeOffer(params.PeerID, params.OfferID, params.ProvidesAmount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.subscribeTakeOffer(s.ctx, conn, ch)
|
||||
return s.subscribeSwapStatus(s.ctx, conn, params.OfferID)
|
||||
case rpctypes.SubscribeMakeOffer:
|
||||
if s.ns == nil {
|
||||
return errNamespaceNotEnabled
|
||||
@@ -158,12 +159,12 @@ func (s *wsServer) handleRequest(conn *websocket.Conn, req *rpctypes.Request) er
|
||||
return fmt.Errorf("failed to unmarshal parameters: %w", err)
|
||||
}
|
||||
|
||||
offerResp, offerExtra, err := s.ns.makeOffer(params)
|
||||
offerResp, err := s.ns.makeOffer(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.subscribeMakeOffer(s.ctx, conn, offerResp.OfferID, offerExtra)
|
||||
return s.subscribeMakeOffer(s.ctx, conn, offerResp.OfferID)
|
||||
default:
|
||||
return errInvalidMethod
|
||||
}
|
||||
@@ -240,8 +241,22 @@ func (s *wsServer) handleSigner(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn,
|
||||
statusCh <-chan types.Status) error {
|
||||
func (s *wsServer) subscribeMakeOffer(
|
||||
ctx context.Context,
|
||||
conn *websocket.Conn,
|
||||
offerID types.Hash,
|
||||
) error {
|
||||
resp := &rpctypes.MakeOfferResponse{
|
||||
PeerID: s.ns.net.PeerID(),
|
||||
OfferID: offerID,
|
||||
}
|
||||
|
||||
if err := writeResponse(conn, resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statusCh := s.backend.SwapManager().GetStatusChan(offerID)
|
||||
|
||||
for {
|
||||
select {
|
||||
case status, ok := <-statusCh:
|
||||
@@ -266,53 +281,19 @@ func (s *wsServer) subscribeTakeOffer(ctx context.Context, conn *websocket.Conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *wsServer) subscribeMakeOffer(ctx context.Context, conn *websocket.Conn,
|
||||
offerID types.Hash, offerExtra *types.OfferExtra) error {
|
||||
resp := &rpctypes.MakeOfferResponse{
|
||||
PeerID: s.ns.net.PeerID(),
|
||||
OfferID: offerID,
|
||||
// subscribeSwapStatus writes the swap's status transitions to the websockets
|
||||
// connection when the state changes. When the swap completes, it writes the
|
||||
// final status and then closes the connection. This method is not intended for
|
||||
// simultaneous requests on the same swap. If more than one request is made
|
||||
// (including calls to net_[make|take]OfferAndSubscribe), only one of the
|
||||
// websocket connections will see any individual state transition.
|
||||
func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn, offerID types.Hash) error {
|
||||
statusCh := s.backend.SwapManager().GetStatusChan(offerID)
|
||||
|
||||
if !s.sm.HasOngoingSwap(offerID) {
|
||||
return s.writeSwapExitStatus(conn, offerID)
|
||||
}
|
||||
|
||||
if err := writeResponse(conn, resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case status, ok := <-offerExtra.StatusCh:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp := &rpctypes.SubscribeSwapStatusResponse{
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if err := writeResponse(conn, resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !status.IsOngoing() {
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// subscribeSwapStatus writes the swap's stage to the connection every time it updates.
|
||||
// when the swap completes, it writes the final status then closes the connection.
|
||||
// example: `{"jsonrpc":"2.0", "method":"swap_subscribeStatus", "params": {"id": 0}, "id": 0}`
|
||||
func (s *wsServer) subscribeSwapStatus(ctx context.Context, conn *websocket.Conn, id types.Hash) error {
|
||||
// we can ignore the error here, since the error will only be if the swap cannot be found
|
||||
// as ongoing, in which case `writeSwapExitStatus` will look for it in the past swaps.
|
||||
info, err := s.sm.GetOngoingSwap(id)
|
||||
if err != nil {
|
||||
return s.writeSwapExitStatus(conn, id)
|
||||
}
|
||||
|
||||
statusCh := info.StatusCh()
|
||||
for {
|
||||
select {
|
||||
case status, ok := <-statusCh:
|
||||
|
||||
@@ -26,7 +26,7 @@ var (
|
||||
testTimeout = time.Second * 5
|
||||
)
|
||||
|
||||
func newServer(t *testing.T) *Server {
|
||||
func newServer(t *testing.T) (*Server, *Config) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cfg := &Config{
|
||||
@@ -34,7 +34,7 @@ func newServer(t *testing.T) *Server {
|
||||
Env: common.Development,
|
||||
Address: "127.0.0.1:0", // OS assigned port
|
||||
Net: new(mockNet),
|
||||
ProtocolBackend: newMockProtocolBackend(),
|
||||
ProtocolBackend: newMockProtocolBackend(t),
|
||||
XMRTaker: new(mockXMRTaker),
|
||||
XMRMaker: new(mockXMRMaker),
|
||||
Namespaces: AllNamespaces(),
|
||||
@@ -60,11 +60,11 @@ func newServer(t *testing.T) *Server {
|
||||
wg.Wait() // wait for the server to exit
|
||||
})
|
||||
|
||||
return s
|
||||
return s, cfg
|
||||
}
|
||||
|
||||
func TestSubscribeSwapStatus(t *testing.T) {
|
||||
s := newServer(t)
|
||||
s, _ := newServer(t)
|
||||
|
||||
c, err := wsclient.NewWsClient(s.ctx, s.WsURL())
|
||||
require.NoError(t, err)
|
||||
@@ -81,7 +81,7 @@ func TestSubscribeSwapStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSubscribeMakeOffer(t *testing.T) {
|
||||
s := newServer(t)
|
||||
s, cfg := newServer(t)
|
||||
|
||||
c, err := wsclient.NewWsClient(s.ctx, s.WsURL())
|
||||
require.NoError(t, err)
|
||||
@@ -92,6 +92,9 @@ func TestSubscribeMakeOffer(t *testing.T) {
|
||||
offerResp, ch, err := c.MakeOfferAndSubscribe(min, max, exRate, types.EthAssetETH, false)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, offerResp.OfferID, testSwapID)
|
||||
|
||||
cfg.ProtocolBackend.SwapManager().PushNewStatus(offerResp.OfferID, types.CompletedSuccess)
|
||||
|
||||
select {
|
||||
case status := <-ch:
|
||||
require.Equal(t, types.CompletedSuccess, status)
|
||||
@@ -101,7 +104,7 @@ func TestSubscribeMakeOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSubscribeTakeOffer(t *testing.T) {
|
||||
s := newServer(t)
|
||||
s, _ := newServer(t)
|
||||
|
||||
cliCtx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(func() {
|
||||
|
||||
Reference in New Issue
Block a user