mirror of
https://github.com/AthanorLabs/atomic-swap.git
synced 2026-01-10 06:38:04 -05:00
net: add initiate, query and relay timeouts (#378)
Co-authored-by: Dmitry Holodov <dimalinux@protonmail.com>
This commit is contained in:
26
net/host.go
26
net/host.go
@@ -7,6 +7,8 @@ package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -29,6 +31,7 @@ const (
|
||||
ProtocolID = "/atomic-swap/0.3"
|
||||
maxMessageSize = 1 << 17
|
||||
maxRelayMessageSize = 2048
|
||||
connectionTimeout = time.Second * 5
|
||||
)
|
||||
|
||||
var log = logging.Logger("net")
|
||||
@@ -220,3 +223,26 @@ func readStreamMessage(stream libp2pnetwork.Stream, maxMessageSize uint32) (comm
|
||||
|
||||
return message.DecodeMessage(msgBytes)
|
||||
}
|
||||
|
||||
// nextStreamMessage returns a channel that will receive the next message from the stream.
|
||||
// if there is an error reading from the stream, the channel will be closed, thus
|
||||
// the received value will be nil.
|
||||
func nextStreamMessage(stream libp2pnetwork.Stream, maxMessageSize uint32) <-chan common.Message {
|
||||
ch := make(chan common.Message)
|
||||
go func() {
|
||||
for {
|
||||
msg, err := readStreamMessage(stream, maxMessageSize)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Warnf("failed to read stream message: %s", err)
|
||||
}
|
||||
close(ch)
|
||||
return
|
||||
}
|
||||
|
||||
ch <- msg
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
@@ -20,8 +20,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
swapID = "/swap/0"
|
||||
protocolTimeout = time.Second * 5
|
||||
swapID = "/swap/0"
|
||||
)
|
||||
|
||||
// Initiate attempts to initiate a swap with the given peer by sending a SendKeysMessage,
|
||||
@@ -36,7 +35,7 @@ func (h *Host) Initiate(who peer.AddrInfo, sendKeysMessage common.Message, s com
|
||||
return errSwapAlreadyInProgress
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(h.ctx, protocolTimeout)
|
||||
ctx, cancel := context.WithTimeout(h.ctx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
if h.h.Connectedness(who.ID) != libp2pnetwork.Connected {
|
||||
@@ -66,10 +65,38 @@ func (h *Host) Initiate(who peer.AddrInfo, sendKeysMessage common.Message, s com
|
||||
isTaker: true,
|
||||
}
|
||||
|
||||
go h.handleProtocolStreamInner(stream, s)
|
||||
go h.receiveInitiateResponse(stream, s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Host) receiveInitiateResponse(stream libp2pnetwork.Stream, s SwapState) {
|
||||
defer h.handleProtocolStreamClose(stream, s)
|
||||
|
||||
const initiateResponseTimeout = time.Minute
|
||||
|
||||
select {
|
||||
case msg := <-nextStreamMessage(stream, maxMessageSize):
|
||||
if msg == nil {
|
||||
log.Errorf("failed to read initial SendKeysMessage response")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("received protocol=%s message from peer=%s type=%s",
|
||||
stream.Protocol(), stream.Conn().RemotePeer(), message.TypeToString(msg.Type()))
|
||||
|
||||
err := s.HandleProtocolMessage(msg)
|
||||
if err != nil {
|
||||
log.Warnf("failed to handle protocol message: err=%s", err)
|
||||
return
|
||||
}
|
||||
case <-time.After(initiateResponseTimeout):
|
||||
log.Errorf("timed out waiting for SendKeysMessage response")
|
||||
return
|
||||
}
|
||||
|
||||
h.handleProtocolStreamInner(stream, s)
|
||||
}
|
||||
|
||||
// handleProtocolStream is called when there is an incoming protocol stream.
|
||||
func (h *Host) handleProtocolStream(stream libp2pnetwork.Stream) {
|
||||
if h.makerHandler == nil {
|
||||
@@ -129,18 +156,7 @@ func (h *Host) handleProtocolStream(stream libp2pnetwork.Stream) {
|
||||
|
||||
// handleProtocolStreamInner is called to handle a protocol stream, in both ingoing and outgoing cases.
|
||||
func (h *Host) handleProtocolStreamInner(stream libp2pnetwork.Stream, s SwapState) {
|
||||
defer func() {
|
||||
log.Debugf("closing stream: peer=%s protocol=%s", stream.Conn().RemotePeer(), stream.Protocol())
|
||||
_ = stream.Close()
|
||||
|
||||
log.Debugf("exiting swap...")
|
||||
if err := s.Exit(); err != nil {
|
||||
log.Errorf("failed to exit protocol: err=%s", err)
|
||||
}
|
||||
h.swapMu.Lock()
|
||||
delete(h.swaps, s.OfferID())
|
||||
h.swapMu.Unlock()
|
||||
}()
|
||||
defer h.handleProtocolStreamClose(stream, s)
|
||||
|
||||
for {
|
||||
msg, err := readStreamMessage(stream, maxMessageSize)
|
||||
@@ -164,3 +180,16 @@ func (h *Host) handleProtocolStreamInner(stream libp2pnetwork.Stream, s SwapStat
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Host) handleProtocolStreamClose(stream libp2pnetwork.Stream, s SwapState) {
|
||||
log.Debugf("closing stream: peer=%s protocol=%s", stream.Conn().RemotePeer(), stream.Protocol())
|
||||
_ = stream.Close()
|
||||
|
||||
log.Debugf("exiting swap...")
|
||||
if err := s.Exit(); err != nil {
|
||||
log.Errorf("failed to exit protocol: %s", err)
|
||||
}
|
||||
h.swapMu.Lock()
|
||||
delete(h.swaps, s.OfferID())
|
||||
h.swapMu.Unlock()
|
||||
}
|
||||
|
||||
32
net/query.go
32
net/query.go
@@ -5,6 +5,7 @@ package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +18,6 @@ import (
|
||||
|
||||
const (
|
||||
queryProtocolID = "/query/0"
|
||||
queryTimeout = time.Second * 5
|
||||
)
|
||||
|
||||
func (h *Host) handleQueryStream(stream libp2pnetwork.Stream) {
|
||||
@@ -34,7 +34,7 @@ func (h *Host) handleQueryStream(stream libp2pnetwork.Stream) {
|
||||
|
||||
// Query queries the given peer for its offers.
|
||||
func (h *Host) Query(who peer.ID) (*QueryResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(h.ctx, queryTimeout)
|
||||
ctx, cancel := context.WithTimeout(h.ctx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := h.h.Connect(ctx, peer.AddrInfo{ID: who}); err != nil {
|
||||
@@ -56,17 +56,23 @@ func (h *Host) Query(who peer.ID) (*QueryResponse, error) {
|
||||
}
|
||||
|
||||
func receiveQueryResponse(stream libp2pnetwork.Stream) (*QueryResponse, error) {
|
||||
msg, err := readStreamMessage(stream, maxMessageSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading QueryResponse: %w", err)
|
||||
}
|
||||
const queryResponseTimeout = time.Second * 15
|
||||
|
||||
resp, ok := msg.(*QueryResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected %s message but received %s",
|
||||
message.TypeToString(message.QueryResponseType),
|
||||
message.TypeToString(msg.Type()))
|
||||
}
|
||||
select {
|
||||
case msg := <-nextStreamMessage(stream, maxMessageSize):
|
||||
if msg == nil {
|
||||
return nil, errors.New("failed to read QueryResponse")
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
resp, ok := msg.(*QueryResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected %s message but received %s",
|
||||
message.TypeToString(message.QueryResponseType),
|
||||
message.TypeToString(msg.Type()))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
case <-time.After(queryResponseTimeout):
|
||||
return nil, errors.New("timed out waiting for QueryResponse")
|
||||
}
|
||||
}
|
||||
|
||||
43
net/relay.go
43
net/relay.go
@@ -5,6 +5,7 @@ package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +17,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
relayProtocolID = "/relay/0"
|
||||
relayClaimTimeout = time.Second * 30 // TODO: Vet this value
|
||||
relayProtocolID = "/relay/0"
|
||||
|
||||
// RelayerProvidesStr is the DHT namespace advertised by nodes willing to relay
|
||||
// claims for arbitrary XMR makers.
|
||||
@@ -93,13 +93,7 @@ func (h *Host) handleRelayStream(stream libp2pnetwork.Stream) {
|
||||
|
||||
// SubmitClaimToRelayer sends a request to relay a swap claim to a peer.
|
||||
func (h *Host) SubmitClaimToRelayer(relayerID peer.ID, request *RelayClaimRequest) (*RelayClaimResponse, error) {
|
||||
// The timeout should be short enough, that the Maker can try multiple relayers
|
||||
// before T1 expires even if the receiving node accepts the relay request and
|
||||
// just sits on it without doing anything.
|
||||
// TODO: https://github.com/AthanorLabs/atomic-swap/issues/375
|
||||
// The context below needs extension to cover the response. Right now
|
||||
// only covers the Connect(...).
|
||||
ctx, cancel := context.WithTimeout(h.ctx, relayClaimTimeout)
|
||||
ctx, cancel := context.WithTimeout(h.ctx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := h.h.Connect(ctx, peer.AddrInfo{ID: relayerID}); err != nil {
|
||||
@@ -123,17 +117,26 @@ func (h *Host) SubmitClaimToRelayer(relayerID peer.ID, request *RelayClaimReques
|
||||
}
|
||||
|
||||
func receiveRelayClaimResponse(stream libp2pnetwork.Stream) (*RelayClaimResponse, error) {
|
||||
msg, err := readStreamMessage(stream, maxRelayMessageSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read RelayClaimResponse: %w", err)
|
||||
}
|
||||
// The timeout should be short enough, that the Maker can try multiple relayers
|
||||
// before T1 expires even if the receiving node accepts the relay request and
|
||||
// just sits on it without doing anything.
|
||||
const relayResponseTimeout = time.Second * 45
|
||||
|
||||
resp, ok := msg.(*RelayClaimResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected %s message but received %s",
|
||||
message.TypeToString(message.RelayClaimResponseType),
|
||||
message.TypeToString(msg.Type()))
|
||||
}
|
||||
select {
|
||||
case msg := <-nextStreamMessage(stream, maxMessageSize):
|
||||
if msg == nil {
|
||||
return nil, errors.New("failed to read RelayClaimResponse")
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
resp, ok := msg.(*RelayClaimResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected %s message but received %s",
|
||||
message.TypeToString(message.RelayClaimResponseType),
|
||||
message.TypeToString(msg.Type()))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
case <-time.After(relayResponseTimeout):
|
||||
return nil, errors.New("timed out waiting for QueryResponse")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func TestHost_SubmitClaimToRelayer_dhtRelayer(t *testing.T) {
|
||||
// possible privacy data leaks, but in this case it is because hb is not
|
||||
// a DHT advertising relayer.
|
||||
_, err = hb.SubmitClaimToRelayer(ha.PeerID(), createTestClaimRequest())
|
||||
require.ErrorContains(t, err, "failed to read RelayClaimResponse: EOF")
|
||||
require.ErrorContains(t, err, "failed to read RelayClaimResponse")
|
||||
}
|
||||
|
||||
func TestHost_SubmitClaimToRelayer_xmrTakerRelayer(t *testing.T) {
|
||||
@@ -104,7 +104,7 @@ func TestHost_SubmitClaimToRelayer_xmrTakerRelayer(t *testing.T) {
|
||||
|
||||
// fail, because there is no ongoing swap between ha and hb
|
||||
_, err := hb.SubmitClaimToRelayer(ha.PeerID(), request)
|
||||
require.ErrorContains(t, err, "failed to read RelayClaimResponse: EOF")
|
||||
require.ErrorContains(t, err, "failed to read RelayClaimResponse")
|
||||
|
||||
// create an ongoing swap between ha and hb
|
||||
swapState := &mockSwapState{offerID: offerID}
|
||||
|
||||
@@ -284,6 +284,5 @@ func (s *swapState) handleEventETHRefunded(e *EventETHRefunded) error {
|
||||
}
|
||||
|
||||
s.clearNextExpectedEvent(types.CompletedRefund)
|
||||
s.CloseProtocolStream(s.OfferID())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -458,6 +458,8 @@ func (s *swapState) exit() error {
|
||||
log.Debugf("attempting to exit swap: nextExpectedEvent=%v", s.nextExpectedEvent)
|
||||
|
||||
defer func() {
|
||||
s.CloseProtocolStream(s.OfferID())
|
||||
|
||||
err := s.SwapManager().CompleteOngoingSwap(s.info)
|
||||
if err != nil {
|
||||
log.Warnf("failed to mark swap %s as completed: %s", s.offer.ID, err)
|
||||
|
||||
@@ -324,7 +324,6 @@ func (s *swapState) handleEventETHClaimed(event *EventETHClaimed) error {
|
||||
}
|
||||
|
||||
s.clearNextExpectedEvent(types.CompletedSuccess)
|
||||
s.CloseProtocolStream(s.OfferID())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -287,32 +286,11 @@ func newSwapState(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.waitForSendKeysMessage()
|
||||
go s.runHandleEvents()
|
||||
go s.runContractEventWatcher()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *swapState) waitForSendKeysMessage() {
|
||||
waitDuration := time.Minute * 5
|
||||
timer := time.After(waitDuration)
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-timer:
|
||||
}
|
||||
|
||||
// check if we've received a response from the counterparty yet
|
||||
if reflect.TypeOf(s.nextExpectedEvent) != reflect.TypeOf(&EventKeysReceived{}) {
|
||||
return
|
||||
}
|
||||
|
||||
// if not, just exit the swap
|
||||
if err := s.Exit(); err != nil {
|
||||
log.Warnf("Swap exit failure: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SendKeysMessage ...
|
||||
func (s *swapState) SendKeysMessage() common.Message {
|
||||
return &message.SendKeysMessage{
|
||||
@@ -349,6 +327,8 @@ func (s *swapState) Exit() error {
|
||||
// exit is the same as Exit, but assumes the calling code block already holds the swapState lock.
|
||||
func (s *swapState) exit() error {
|
||||
defer func() {
|
||||
s.CloseProtocolStream(s.OfferID())
|
||||
|
||||
err := s.SwapManager().CompleteOngoingSwap(s.info)
|
||||
if err != nil {
|
||||
log.Warnf("failed to mark swap %s as completed: %s", s.info.OfferID, err)
|
||||
|
||||
Reference in New Issue
Block a user