check that XMR amount peer will send is same as expected (#105)

This commit is contained in:
noot
2022-03-24 10:55:21 -04:00
committed by GitHub
parent 5da8a30630
commit 8bb6499a90
12 changed files with 64 additions and 59 deletions

View File

@@ -15,7 +15,7 @@ type Type byte
const (
QueryResponseType Type = iota //nolint
SendKeysType
NotifyETHLockedType // TODO: rename to NotifyETHLockType
NotifyETHLockedType
NotifyXMRLockType
NotifyReadyType
NotifyClaimedType

View File

@@ -93,12 +93,12 @@ func (s *swapState) checkMessageType(msg net.Message) error {
}
func (s *swapState) handleSendKeysMessage(msg *net.SendKeysMessage) (net.Message, error) {
// TODO: get user to confirm amount they will receive!!
s.info.SetReceivedAmount(msg.ProvidedAmount)
log.Infof(color.New(color.Bold).Sprintf("you will be receiving %v XMR", msg.ProvidedAmount))
exchangeRate := msg.ProvidedAmount / s.info.ProvidedAmount()
s.info.SetExchangeRate(types.ExchangeRate(exchangeRate))
if msg.ProvidedAmount < s.info.ReceivedAmount() {
return nil, fmt.Errorf("receiving amount is not the same as expected: got %v, expected %v",
msg.ProvidedAmount,
s.info.ReceivedAmount(),
)
}
if msg.PublicSpendKey == "" || msg.PrivateViewKey == "" {
return nil, errMissingKeys
@@ -128,6 +128,8 @@ func (s *swapState) handleSendKeysMessage(msg *net.SendKeysMessage) (net.Message
return nil, err
}
log.Infof(color.New(color.Bold).Sprintf("you will be receiving %v XMR", msg.ProvidedAmount))
s.setBobKeys(sk, vk, secp256k1Pub)
err = s.lockETH(s.providedAmountInWei())
if err != nil {
@@ -150,6 +152,8 @@ func (s *swapState) handleSendKeysMessage(msg *net.SendKeysMessage) (net.Message
const timeoutBuffer = time.Second * 5
until := time.Until(s.t0)
log.Debugf("time until refund: %ds", until.Seconds())
select {
case <-s.ctx.Done():
return
@@ -170,6 +174,21 @@ func (s *swapState) handleSendKeysMessage(msg *net.SendKeysMessage) (net.Message
log.Infof("got our ETH back: tx hash=%s", txhash)
if s == nil {
log.Error("swap state is nil")
return
}
if s.alice == nil {
log.Error("s.alice is nil")
return
}
if s.alice.net == nil {
log.Error("s.alice.net is nil")
return
}
// send NotifyRefund msg
if err := s.alice.net.SendSwapMessage(&message.NotifyRefund{
TxHash: txhash.String(),

View File

@@ -17,15 +17,19 @@ func (a *Instance) Provides() types.ProvidesCoin {
// InitiateProtocol is called when an RPC call is made from the user to initiate a swap.
// The input units are ether that we will provide.
func (a *Instance) InitiateProtocol(providesAmount float64) (common.SwapState, error) {
if err := a.initiate(common.EtherToWei(providesAmount)); err != nil {
func (a *Instance) InitiateProtocol(providesAmount float64, offer *types.Offer) (common.SwapState, error) {
receivedAmount := offer.ExchangeRate.ToXMR(providesAmount)
err := a.initiate(common.EtherToWei(providesAmount), common.MoneroToPiconero(receivedAmount),
offer.ExchangeRate)
if err != nil {
return nil, err
}
return a.swapState, nil
}
func (a *Instance) initiate(providesAmount common.EtherAmount) error {
func (a *Instance) initiate(providesAmount common.EtherAmount, receivedAmount common.MoneroAmount,
exchangeRate types.ExchangeRate) error {
a.swapMu.Lock()
defer a.swapMu.Unlock()
@@ -43,7 +47,8 @@ func (a *Instance) initiate(providesAmount common.EtherAmount) error {
return errors.New("balance lower than amount to be provided")
}
a.swapState, err = newSwapState(a, pcommon.GetSwapInfoFilepath(a.basepath), providesAmount)
a.swapState, err = newSwapState(a, pcommon.GetSwapInfoFilepath(a.basepath), providesAmount,
receivedAmount, exchangeRate)
if err != nil {
return err
}

View File

@@ -3,12 +3,16 @@ package alice
import (
"testing"
"github.com/noot/atomic-swap/common/types"
"github.com/stretchr/testify/require"
)
func TestAlice_InitiateProtocol(t *testing.T) {
a := newTestAlice(t)
s, err := a.InitiateProtocol(3.33)
s, err := a.InitiateProtocol(3.33, &types.Offer{
ExchangeRate: 1,
})
require.NoError(t, err)
require.Equal(t, a.swapState, s)
}

View File

@@ -2,6 +2,7 @@ package alice
import (
"testing"
"time"
"github.com/noot/atomic-swap/common"
@@ -11,6 +12,7 @@ import (
func newTestRecoveryState(t *testing.T) *recoveryState {
inst, s := newTestInstance(t)
inst.swapTimeout = time.Second * 10
akp, err := generateKeys()
require.NoError(t, err)

View File

@@ -66,7 +66,8 @@ type swapState struct {
claimedCh chan struct{}
}
func newSwapState(a *Instance, infofile string, providesAmount common.EtherAmount) (*swapState, error) {
func newSwapState(a *Instance, infofile string, providesAmount common.EtherAmount,
receivedAmount common.MoneroAmount, exhangeRate types.ExchangeRate) (*swapState, error) {
txOpts, err := bind.NewKeyedTransactorWithChainID(a.ethPrivKey, a.chainID)
if err != nil {
return nil, err
@@ -78,7 +79,8 @@ func newSwapState(a *Instance, infofile string, providesAmount common.EtherAmoun
stage := types.ExpectingKeys
statusCh := make(chan types.Status, 16)
statusCh <- stage
info := pswap.NewInfo(types.ProvidesETH, providesAmount.AsEther(), 0, 0, stage, statusCh)
info := pswap.NewInfo(types.ProvidesETH, providesAmount.AsEther(), receivedAmount.AsMonero(),
exhangeRate, stage, statusCh)
if err := a.swapManager.AddSwap(info); err != nil {
return nil, err
}
@@ -379,7 +381,7 @@ func (s *swapState) lockETH(amount common.EtherAmount) error {
tx, err := s.alice.contract.NewSwap(s.txOpts,
cmtBob, cmtAlice, s.bobAddress, big.NewInt(int64(s.alice.swapTimeout.Seconds())))
if err != nil {
return fmt.Errorf("failed to deploy Swap.sol: %w", err)
return fmt.Errorf("failed to instantiate swap on-chain: %w", err)
}
log.Debugf("instantiating swap on-chain: amount=%s txHash=%s", amount, tx.Hash())

View File

@@ -71,9 +71,8 @@ func newTestAlice(t *testing.T) *Instance {
func newTestInstance(t *testing.T) (*Instance, *swapState) {
alice := newTestAlice(t)
swapState, err := newSwapState(alice, infofile, common.NewEtherAmount(1))
swapState, err := newSwapState(alice, infofile, common.NewEtherAmount(1), common.MoneroAmount(0), 1)
require.NoError(t, err)
swapState.info.SetReceivedAmount(1)
return alice, swapState
}
@@ -173,7 +172,6 @@ func TestSwapState_NotifyXMRLock(t *testing.T) {
err = s.lockETH(common.NewEtherAmount(1))
require.NoError(t, err)
s.info.SetReceivedAmount(0)
kp := mcrypto.SumSpendAndViewKeys(bobKeysAndProof.PublicKeyPair, s.pubkeys)
xmrAddr := kp.Address(common.Mainnet)
@@ -209,7 +207,6 @@ func TestSwapState_NotifyXMRLock_Refund(t *testing.T) {
err = s.lockETH(common.NewEtherAmount(1))
require.NoError(t, err)
s.info.SetReceivedAmount(0)
kp := mcrypto.SumSpendAndViewKeys(bobKeysAndProof.PublicKeyPair, s.pubkeys)
xmrAddr := kp.Address(common.Mainnet)
@@ -246,6 +243,7 @@ func TestSwapState_NotifyXMRLock_Refund(t *testing.T) {
func TestSwapState_NotifyClaimed(t *testing.T) {
_, s := newTestInstance(t)
defer s.cancel()
s.alice.swapTimeout = time.Minute * 2
s.alice.client = monero.NewClient(common.DefaultBobMoneroEndpoint)
err := s.alice.client.OpenWallet("test-wallet", "")
@@ -269,7 +267,7 @@ func TestSwapState_NotifyClaimed(t *testing.T) {
require.NoError(t, err)
require.False(t, done)
require.NotNil(t, resp)
require.Equal(t, defaultTimeoutDuration, s.t1.Sub(s.t0))
require.Equal(t, time.Minute*2, s.t1.Sub(s.t0))
require.Equal(t, msg.PublicSpendKey, s.bobPublicSpendKey.Hex())
require.Equal(t, msg.PrivateViewKey, s.bobPrivateViewKey.Hex())
@@ -282,7 +280,6 @@ func TestSwapState_NotifyClaimed(t *testing.T) {
_ = daemonClient.GenerateBlocks(bobAddr.Address, 60)
amt := common.MoneroAmount(1)
s.info.SetReceivedAmount(amt.AsMonero())
kp := mcrypto.SumSpendAndViewKeys(s.pubkeys, s.pubkeys)
xmrAddr := kp.Address(common.Mainnet)

View File

@@ -72,7 +72,7 @@ func newTestBob(t *testing.T) *Instance {
bobAddr, err := bob.client.GetAddress(0)
require.NoError(t, err)
_ = bob.daemonClient.GenerateBlocks(bobAddr.Address, 256)
_ = bob.daemonClient.GenerateBlocks(bobAddr.Address, 512)
err = bob.client.Refresh()
require.NoError(t, err)
return bob

View File

@@ -70,16 +70,6 @@ func (i *Info) StatusCh() <-chan types.Status {
return i.statusCh
}
// SetReceivedAmount ...
func (i *Info) SetReceivedAmount(a float64) {
i.receivedAmount = a
}
// SetExchangeRate ...
func (i *Info) SetExchangeRate(r types.ExchangeRate) {
i.exchangeRate = r
}
// SetStatus ...
func (i *Info) SetStatus(s Status) {
if i == nil {

View File

@@ -161,10 +161,14 @@ func (s *NetService) takeOffer(multiaddr, offerID string,
return 0, nil, "", err
}
var found bool
for _, offer := range queryResp.Offers {
if offer.GetID().String() == offerID {
var (
found bool
offer *types.Offer
)
for _, maybeOffer := range queryResp.Offers {
if maybeOffer.GetID().String() == offerID {
found = true
offer = maybeOffer
break
}
}
@@ -173,7 +177,7 @@ func (s *NetService) takeOffer(multiaddr, offerID string,
return 0, nil, "", errors.New("peer does not have offer with given ID")
}
swapState, err := s.alice.InitiateProtocol(providesAmount)
swapState, err := s.alice.InitiateProtocol(providesAmount, offer)
if err != nil {
return 0, nil, "", err
}
@@ -210,31 +214,13 @@ type TakeOfferSyncResponse struct {
// It synchronously waits until the swap is completed before returning its status.
func (s *NetService) TakeOfferSync(_ *http.Request, req *TakeOfferRequest,
resp *TakeOfferSyncResponse) error {
swapState, err := s.alice.InitiateProtocol(req.ProvidesAmount)
id, _, infofile, err := s.takeOffer(req.Multiaddr, req.OfferID, req.ProvidesAmount)
if err != nil {
return err
}
skm, err := swapState.SendKeysMessage()
if err != nil {
return err
}
skm.OfferID = req.OfferID
skm.ProvidedAmount = req.ProvidesAmount
who, err := net.StringToAddrInfo(req.Multiaddr)
if err != nil {
return err
}
if err = s.net.Initiate(who, skm, swapState); err != nil {
_ = swapState.Exit()
return err
}
resp.ID = swapState.ID()
resp.InfoFile = swapState.InfoFile()
resp.ID = id
resp.InfoFile = infofile
const checkSwapSleepDuration = time.Millisecond * 100

View File

@@ -114,7 +114,7 @@ type Protocol interface {
// Alice ...
type Alice interface {
Protocol
InitiateProtocol(providesAmount float64) (common.SwapState, error)
InitiateProtocol(providesAmount float64, offer *types.Offer) (common.SwapState, error)
Refund() (ethcommon.Hash, error)
SetSwapTimeout(timeout time.Duration)
}

View File

@@ -86,7 +86,7 @@ func (*mockAlice) SetGasPrice(gasPrice uint64) {}
func (*mockAlice) GetOngoingSwapState() common.SwapState {
return new(mockSwapState)
}
func (*mockAlice) InitiateProtocol(providesAmount float64) (common.SwapState, error) {
func (*mockAlice) InitiateProtocol(providesAmount float64, _ *types.Offer) (common.SwapState, error) {
return new(mockSwapState), nil
}
func (*mockAlice) Refund() (ethcommon.Hash, error) {