diff --git a/protocol/xmrtaker/errors.go b/protocol/xmrtaker/errors.go index 7ba31852..792ed872 100644 --- a/protocol/xmrtaker/errors.go +++ b/protocol/xmrtaker/errors.go @@ -48,3 +48,27 @@ func errContractAddrMismatch(addr string) error { //nolint:lll return fmt.Errorf("cannot recover from swap where contract address is not the one loaded at start-up; please restart with --contract-address=%s", addr) } + +type errAmountProvidedTooLow struct { + providedAmount *apd.Decimal + minAmount *apd.Decimal +} + +func (e errAmountProvidedTooLow) Error() string { + return fmt.Sprintf("%s ETH provided is under offer minimum of %s XMR", + e.providedAmount.String(), + e.minAmount.String(), + ) +} + +type errAmountProvidedTooHigh struct { + providedAmount *apd.Decimal + maxAmount *apd.Decimal +} + +func (e errAmountProvidedTooHigh) Error() string { + return fmt.Sprintf("%s ETH provided is over offer maximum of %s XMR", + e.providedAmount.String(), + e.maxAmount.String(), + ) +} diff --git a/protocol/xmrtaker/net.go b/protocol/xmrtaker/net.go index 71eea01a..fbfec3f0 100644 --- a/protocol/xmrtaker/net.go +++ b/protocol/xmrtaker/net.go @@ -27,10 +27,24 @@ func (inst *Instance) InitiateProtocol( providesAmount *apd.Decimal, offer *types.Offer, ) (common.SwapState, error) { + err := coins.ValidatePositive("providesAmount", coins.NumEtherDecimals, providesAmount) + if err != nil { + return nil, err + } + expectedAmount, err := offer.ExchangeRate.ToXMR(providesAmount) if err != nil { return nil, err } + + if expectedAmount.Cmp(offer.MinAmount) < 0 { + return nil, errAmountProvidedTooLow{providesAmount, offer.MinAmount} + } + + if expectedAmount.Cmp(offer.MaxAmount) > 0 { + return nil, errAmountProvidedTooHigh{providesAmount, offer.MaxAmount} + } + providedAmount, err := pcommon.GetEthAssetAmount( inst.backend.Ctx(), inst.backend.ETHClient(), diff --git a/protocol/xmrtaker/net_test.go b/protocol/xmrtaker/net_test.go index f1c70238..d253e341 100644 --- a/protocol/xmrtaker/net_test.go +++ b/protocol/xmrtaker/net_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/athanorlabs/atomic-swap/coins" + "github.com/athanorlabs/atomic-swap/common" "github.com/athanorlabs/atomic-swap/common/types" ) @@ -26,13 +27,50 @@ func newTestXMRTaker(t *testing.T) *Instance { return xmrtaker } +func initiate( + xmrtaker *Instance, + providesAmount *apd.Decimal, + minAmount *apd.Decimal, + maxAmount *apd.Decimal, +) (*types.Offer, common.SwapState, error) { + offer := types.NewOffer( + coins.ProvidesETH, + minAmount, + maxAmount, + coins.ToExchangeRate(apd.New(1, 0)), + types.EthAssetETH, + ) + s, err := xmrtaker.InitiateProtocol(testPeerID, providesAmount, offer) + return offer, s, err +} + func TestXMRTaker_InitiateProtocol(t *testing.T) { a := newTestXMRTaker(t) zero := new(apd.Decimal) one := apd.New(1, 0) - offer := types.NewOffer(coins.ProvidesETH, zero, zero, coins.ToExchangeRate(one), types.EthAssetETH) - providesAmount := apd.New(333, -2) // 3.33 - s, err := a.InitiateProtocol(testPeerID, providesAmount, offer) + + // Provided between minAmount and maxAmount + offer, s, err := initiate(a, apd.New(1, -1), zero, one) // 0.1 require.NoError(t, err) require.Equal(t, a.swapStates[offer.ID], s) + + // Provided with too many decimals + _, s, err = initiate(a, apd.New(1, -50), zero, one) // 10^-50 + require.Error(t, err) + require.Equal(t, nil, s) + + // Provided with a negative number + _, s, err = initiate(a, apd.New(-1, 0), zero, one) // -1 + require.Error(t, err) + require.Equal(t, nil, s) + + // Provided over maxAmount + _, s, err = initiate(a, apd.New(2, 0), one, one) // 2 + require.Error(t, err) + require.Equal(t, nil, s) + + // Provided under minAmount + _, s, err = initiate(a, apd.New(1, -1), one, one) // 0.1 + require.Error(t, err) + require.Equal(t, nil, s) }