mirror of
https://github.com/AthanorLabs/atomic-swap.git
synced 2026-01-10 06:38:04 -05:00
fix: remove ws eth endpoint requirement, update taker exit (#358)
This commit is contained in:
@@ -89,7 +89,12 @@ func (db *Database) PutOffer(offer *types.Offer) error {
|
||||
}
|
||||
|
||||
key := offer.ID
|
||||
return db.offerTable.Put(key[:], val)
|
||||
err = db.offerTable.Put(key[:], val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.offerTable.Flush()
|
||||
}
|
||||
|
||||
// DeleteOffer deletes an offer from the database.
|
||||
@@ -188,7 +193,12 @@ func (db *Database) PutSwap(s *swap.Info) error {
|
||||
}
|
||||
|
||||
key := s.ID
|
||||
return db.swapTable.Put(key[:], val)
|
||||
err = db.swapTable.Put(key[:], val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.swapTable.Flush()
|
||||
}
|
||||
|
||||
// HasSwap returns whether the db contains a swap with the given ID.
|
||||
|
||||
@@ -45,7 +45,12 @@ func (db *RecoveryDB) PutSwapRelayerInfo(id types.Hash, info *types.OfferExtra)
|
||||
}
|
||||
|
||||
key := getRecoveryDBKey(id, relayerInfoPrefix)
|
||||
return db.db.Put(key, val)
|
||||
err = db.db.Put(key, val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.db.Flush()
|
||||
}
|
||||
|
||||
// GetSwapRelayerInfo ...
|
||||
@@ -75,7 +80,12 @@ func (db *RecoveryDB) PutContractSwapInfo(id types.Hash, info *EthereumSwapInfo)
|
||||
}
|
||||
|
||||
key := getRecoveryDBKey(id, contractSwapInfoPrefix)
|
||||
return db.db.Put(key, val)
|
||||
err = db.db.Put(key, val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.db.Flush()
|
||||
}
|
||||
|
||||
// GetContractSwapInfo returns the contract swap ID (a hash of the `SwapFactorySwap` structure) and
|
||||
@@ -104,7 +114,12 @@ func (db *RecoveryDB) PutSwapPrivateKey(id types.Hash, sk *mcrypto.PrivateSpendK
|
||||
}
|
||||
|
||||
key := getRecoveryDBKey(id, swapPrivateKeyPrefix)
|
||||
return db.db.Put(key[:], val)
|
||||
err = db.db.Put(key, val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.db.Flush()
|
||||
}
|
||||
|
||||
// GetSwapPrivateKey returns the swap private key share, if it exists.
|
||||
@@ -132,7 +147,12 @@ func (db *RecoveryDB) PutCounterpartySwapPrivateKey(id types.Hash, kp *mcrypto.P
|
||||
}
|
||||
|
||||
key := getRecoveryDBKey(id, counterpartySwapPrivateKeyPrefix)
|
||||
return db.db.Put(key[:], val)
|
||||
err = db.db.Put(key, val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.db.Flush()
|
||||
}
|
||||
|
||||
// GetCounterpartySwapPrivateKey returns the counterparty's swap private key, if it exists.
|
||||
@@ -168,7 +188,14 @@ func (db *RecoveryDB) PutCounterpartySwapKeys(id types.Hash, sk *mcrypto.PublicK
|
||||
}
|
||||
|
||||
key := getRecoveryDBKey(id, counterpartySwapKeysPrefix)
|
||||
return db.db.Put(key[:], val)
|
||||
log.Debugf("PutCounterpartySwapKeys %s", key)
|
||||
err = db.db.Put(key, val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("flushing db")
|
||||
return db.db.Flush()
|
||||
}
|
||||
|
||||
// GetCounterpartySwapKeys is called during recovery to retrieve the counterparty's swap keys.
|
||||
@@ -189,7 +216,13 @@ func (db *RecoveryDB) GetCounterpartySwapKeys(id types.Hash) (*mcrypto.PublicKey
|
||||
}
|
||||
|
||||
// DeleteSwap deletes all recovery info from the db for the given swap.
|
||||
// TODO: this is currently unimplemented
|
||||
func (db *RecoveryDB) DeleteSwap(id types.Hash) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteSwap is currently unused.
|
||||
func (db *RecoveryDB) deleteSwap(id types.Hash) error {
|
||||
keys := [][]byte{
|
||||
getRecoveryDBKey(id, relayerInfoPrefix),
|
||||
getRecoveryDBKey(id, contractSwapInfoPrefix),
|
||||
|
||||
@@ -177,7 +177,7 @@ func TestRecoveryDB_DeleteSwap(t *testing.T) {
|
||||
err = rdb.PutCounterpartySwapKeys(offerID, kp.SpendKey().Public(), kp.ViewKey())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = rdb.DeleteSwap(offerID)
|
||||
err = rdb.deleteSwap(offerID)
|
||||
require.NoError(t, err)
|
||||
_, err = rdb.GetContractSwapInfo(offerID)
|
||||
require.EqualError(t, chaindb.ErrKeyNotFound, err.Error())
|
||||
|
||||
@@ -26,22 +26,18 @@ func WaitForEthBlockAfterTimestamp(ctx context.Context, ec *ethclient.Client, ts
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// subscribe to new block headers
|
||||
headers := make(chan *ethtypes.Header)
|
||||
defer close(headers)
|
||||
sub, err := ec.SubscribeNewHead(ctx, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
ticker := time.NewTicker(time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case err := <-sub.Err():
|
||||
return nil, err
|
||||
case header := <-headers:
|
||||
case <-ticker.C:
|
||||
header, err := ec.HeaderByNumber(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if header.Time >= uint64(ts.Unix()) {
|
||||
return header, nil
|
||||
}
|
||||
|
||||
@@ -56,11 +56,6 @@ func NewEventFilter(
|
||||
|
||||
// Start starts the EventFilter. It watches the chain for logs.
|
||||
func (f *EventFilter) Start() error {
|
||||
header, err := f.ec.HeaderByNumber(f.ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
@@ -71,10 +66,11 @@ func (f *EventFilter) Start() error {
|
||||
|
||||
currHeader, err := f.ec.HeaderByNumber(f.ctx, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get header in event watcher: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if currHeader.Number.Cmp(header.Number) <= 0 {
|
||||
if currHeader.Number.Cmp(f.filterQuery.FromBlock) <= 0 {
|
||||
// no new blocks, don't do anything
|
||||
continue
|
||||
}
|
||||
@@ -82,9 +78,12 @@ func (f *EventFilter) Start() error {
|
||||
// let's see if we have logs
|
||||
logs, err := f.ec.FilterLogs(f.ctx, f.filterQuery)
|
||||
if err != nil {
|
||||
log.Errorf("failed to filter logs for topic %s: %s", f.topic, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("filtered for logs from block %s to block %s", f.filterQuery.FromBlock, currHeader.Number)
|
||||
|
||||
for _, l := range logs {
|
||||
if l.Topics[0] != f.topic {
|
||||
continue
|
||||
@@ -95,12 +94,11 @@ func (f *EventFilter) Start() error {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("watcher for topic %s found log in block %d", f.topic, l.BlockNumber)
|
||||
f.logCh <- l
|
||||
}
|
||||
|
||||
// the filter is inclusive of the latest block when `ToBlock` is nil, so we add 1
|
||||
f.filterQuery.FromBlock = new(big.Int).Add(currHeader.Number, big.NewInt(1))
|
||||
header = currHeader
|
||||
f.filterQuery.FromBlock = currHeader.Number
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ func checkContractSwapID(msg *message.NotifyETHLocked) error {
|
||||
func (s *swapState) checkContract(txHash ethcommon.Hash) error {
|
||||
tx, _, err := s.ETHClient().Raw().TransactionByHash(s.ctx, txHash)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to get newSwap transaction %s by hash: %w", txHash, err)
|
||||
}
|
||||
|
||||
if tx.To() == nil || *(tx.To()) != s.contractAddr {
|
||||
|
||||
@@ -163,7 +163,7 @@ func waitForClaimReceipt(
|
||||
const (
|
||||
checkInterval = time.Second // time between transaction polls
|
||||
maxWait = time.Minute // max wait for the tx to be included in a block
|
||||
maxNotFound = 5 // max failures where the tx is not even found in the mempool
|
||||
maxNotFound = 10 // max failures where the tx is not even found in the mempool
|
||||
)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
@@ -142,12 +142,12 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
|
||||
|
||||
ethSwapInfo, err := inst.backend.RecoveryDB().GetContractSwapInfo(s.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get info for ongoing swap, id %s: %s", s.ID, err)
|
||||
return fmt.Errorf("failed to get contract info for ongoing swap from db with swap id %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
sk, err := inst.backend.RecoveryDB().GetSwapPrivateKey(s.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get private key for ongoing swap, id %s: %s", s.ID, err)
|
||||
return fmt.Errorf("failed to get private key for ongoing swap from db with swap id %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
kp, err := sk.AsPrivateKeyPair()
|
||||
@@ -172,7 +172,7 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
|
||||
kp,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new swap state for ongoing swap, id %s: %s", s.ID, err)
|
||||
return fmt.Errorf("failed to create new swap state for ongoing swap, id %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
inst.swapMu.Lock()
|
||||
|
||||
@@ -182,6 +182,7 @@ func newSwapStateFromOngoing(
|
||||
return nil, errInvalidStageForRecovery
|
||||
}
|
||||
|
||||
log.Debugf("restarting swap from eth block number %s", ethSwapInfo.StartNumber)
|
||||
s, err := newSwapState(
|
||||
b, offer, offerExtra, om, ethSwapInfo.StartNumber, info.MoneroStartHeight, info,
|
||||
)
|
||||
|
||||
@@ -34,7 +34,7 @@ type Instance struct {
|
||||
// non-nil if a swap is currently happening, nil otherwise
|
||||
// map of offer IDs -> ongoing swaps
|
||||
swapStates map[types.Hash]*swapState
|
||||
swapMu sync.Mutex // lock for above map
|
||||
swapMu sync.RWMutex // lock for above map
|
||||
}
|
||||
|
||||
// Config contains the configuration values for a new XMRTaker instance.
|
||||
@@ -86,7 +86,8 @@ func (inst *Instance) checkForOngoingSwaps() error {
|
||||
|
||||
err = inst.createOngoingSwap(s)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Errorf("%s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,12 +117,12 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
|
||||
|
||||
ethSwapInfo, err := inst.backend.RecoveryDB().GetContractSwapInfo(s.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get info for ongoing swap, id %s: %s", s.ID, err)
|
||||
return fmt.Errorf("failed to get contract info for ongoing swap from db with swap id %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
sk, err := inst.backend.RecoveryDB().GetSwapPrivateKey(s.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get private key for ongoing swap, id %s: %s", s.ID, err)
|
||||
return fmt.Errorf("failed to get private key for ongoing swap from db with swap id %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
kp, err := sk.AsPrivateKeyPair()
|
||||
@@ -139,7 +140,7 @@ func (inst *Instance) createOngoingSwap(s *swap.Info) error {
|
||||
kp,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new swap state for ongoing swap, id %s: %s", s.ID, err)
|
||||
return fmt.Errorf("failed to create new swap state for ongoing swap, id %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
inst.swapStates[s.ID] = ss
|
||||
@@ -226,6 +227,8 @@ func (inst *Instance) Refund(offerID types.Hash) (ethcommon.Hash, error) {
|
||||
|
||||
// GetOngoingSwapState ...
|
||||
func (inst *Instance) GetOngoingSwapState(offerID types.Hash) common.SwapState {
|
||||
inst.swapMu.RLock()
|
||||
defer inst.swapMu.RUnlock()
|
||||
return inst.swapStates[offerID]
|
||||
}
|
||||
|
||||
|
||||
@@ -120,6 +120,8 @@ func (s *swapState) handleSendKeysMessage(msg *message.SendKeysMessage) (common.
|
||||
return nil, fmt.Errorf("failed to set xmrmaker keys: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("stored XMR maker's keys, going to lock ETH")
|
||||
|
||||
txHash, err := s.lockAsset()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to lock ethereum asset in contract: %w", err)
|
||||
@@ -273,7 +275,7 @@ func (s *swapState) runT1ExpirationHandler() {
|
||||
if err != nil {
|
||||
// TODO: Do we propagate this error? If we retry, the logic should probably be inside
|
||||
// WaitForTimestamp. (#162)
|
||||
log.Errorf("Failure waiting for T1 timeout: err=%s", err)
|
||||
log.Errorf("failure waiting for T1 timeout: %s", err)
|
||||
return
|
||||
}
|
||||
s.handleT1Expired()
|
||||
|
||||
@@ -390,16 +390,21 @@ func (s *swapState) exit() error {
|
||||
txHash, err := s.tryRefund()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), revertSwapCompleted) {
|
||||
return s.tryClaim()
|
||||
// note: this should NOT ever error; it could if the ethclient
|
||||
// or monero clients crash during the course of the claim,
|
||||
// but that would be very bad.
|
||||
err = s.tryClaim()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to claim even though swap was completed on-chain: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.clearNextExpectedEvent(types.CompletedAbort)
|
||||
log.Errorf("failed to refund: err=%s", err)
|
||||
return err
|
||||
return fmt.Errorf("failed to refund: %w", err)
|
||||
}
|
||||
|
||||
s.clearNextExpectedEvent(types.CompletedRefund)
|
||||
log.Infof("refunded ether: transaction hash=%s", txHash)
|
||||
return nil
|
||||
case EventNoneType:
|
||||
// the swap completed already, do nothing
|
||||
return nil
|
||||
@@ -408,8 +413,6 @@ func (s *swapState) exit() error {
|
||||
s.clearNextExpectedEvent(types.CompletedAbort)
|
||||
return errUnexpectedEventType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// doRefund is called by the RPC function swap_refund.
|
||||
@@ -439,7 +442,7 @@ func (s *swapState) tryRefund() (ethcommon.Hash, error) {
|
||||
|
||||
switch stage {
|
||||
case contracts.StageInvalid:
|
||||
return ethcommon.Hash{}, errRefundInvalid
|
||||
return ethcommon.Hash{}, fmt.Errorf("%w: contract swap ID: %s", errRefundInvalid, s.contractSwapID)
|
||||
case contracts.StageCompleted:
|
||||
return ethcommon.Hash{}, errRefundSwapCompleted
|
||||
case contracts.StagePending, contracts.StageReady:
|
||||
@@ -459,7 +462,7 @@ func (s *swapState) tryRefund() (ethcommon.Hash, error) {
|
||||
isReady, s.t0.Sub(ts).Seconds(), s.t1.Sub(ts).Seconds())
|
||||
|
||||
if ts.Before(s.t0) && !isReady {
|
||||
txHash, err := s.refund()
|
||||
txHash, err := s.refund() //nolint:govet
|
||||
// TODO: Have refund() return errors that we can use errors.Is to check against
|
||||
if err == nil {
|
||||
return txHash, nil
|
||||
@@ -484,17 +487,34 @@ func (s *swapState) tryRefund() (ethcommon.Hash, error) {
|
||||
// it won't handle those events while this function is executing.)
|
||||
log.Infof("waiting until time %s to refund", s.t1)
|
||||
|
||||
event := <-s.eventCh
|
||||
log.Debugf("got event %s while waiting for T1", event)
|
||||
switch event.(type) {
|
||||
case *EventShouldRefund:
|
||||
waitCtx, waitCtxCancel := context.WithCancel(s.ctx)
|
||||
defer waitCtxCancel()
|
||||
|
||||
waitCh := make(chan error)
|
||||
go func() {
|
||||
waitCh <- s.ETHClient().WaitForTimestamp(waitCtx, s.t1)
|
||||
close(waitCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case event := <-s.eventCh:
|
||||
log.Debugf("got event %s while waiting for T1", event)
|
||||
switch event.(type) {
|
||||
case *EventShouldRefund:
|
||||
return s.refund()
|
||||
case *EventETHClaimed:
|
||||
// we should claim; returning this error
|
||||
// causes the calling function to claim
|
||||
return ethcommon.Hash{}, fmt.Errorf(revertSwapCompleted)
|
||||
default:
|
||||
panic(fmt.Sprintf("got unexpected event while waiting for Claimed/T1: %s", event))
|
||||
}
|
||||
case err = <-waitCh:
|
||||
if err != nil {
|
||||
return ethcommon.Hash{}, fmt.Errorf("failed to wait for T1: %w", err)
|
||||
}
|
||||
|
||||
return s.refund()
|
||||
case *EventETHClaimed:
|
||||
// we should claim; returning this error
|
||||
// causes the calling function to claim
|
||||
return ethcommon.Hash{}, fmt.Errorf(revertSwapCompleted)
|
||||
default:
|
||||
panic(fmt.Sprintf("got unexpected event while waiting for Claimed/T1: %s", event))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,6 +601,8 @@ func (s *swapState) lockAsset() (ethcommon.Hash, error) {
|
||||
cmtXMRTaker := s.secp256k1Pub.Keccak256()
|
||||
cmtXMRMaker := s.xmrmakerSecp256k1PublicKey.Keccak256()
|
||||
|
||||
log.Debugf("locking ETH in contract")
|
||||
|
||||
nonce := generateNonce()
|
||||
txHash, receipt, err := s.sender.NewSwap(
|
||||
cmtXMRMaker,
|
||||
@@ -660,10 +682,12 @@ func (s *swapState) ready() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stage != contracts.StagePending {
|
||||
return fmt.Errorf("cannot set contract to ready when swap stage is %s", contracts.StageToString(stage))
|
||||
}
|
||||
_, receipt, err := s.sender.SetReady(s.contractSwap)
|
||||
|
||||
txHash, receipt, err := s.sender.SetReady(s.contractSwap)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), revertSwapCompleted) && !s.info.Status.IsOngoing() {
|
||||
return nil
|
||||
@@ -671,7 +695,7 @@ func (s *swapState) ready() error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("contract set to ready in block %d", receipt.BlockNumber)
|
||||
log.Debugf("contract set to ready in block %d, tx %s", receipt.BlockNumber, txHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -303,6 +303,12 @@ func (s *SwapService) Cancel(_ *http.Request, req *CancelRequest, resp *CancelRe
|
||||
ss = s.xmrmaker.GetOngoingSwapState(req.OfferID)
|
||||
}
|
||||
|
||||
if ss == nil {
|
||||
return fmt.Errorf("failed to find swap state with ID %s", req.OfferID)
|
||||
}
|
||||
|
||||
// Exit() is safe to be called concurrently, since it since it puts an exit event
|
||||
// into the swap state's eventCh, and events are handled sequentially.
|
||||
if err = ss.Exit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user