Compare commits

...

9 Commits

Author SHA1 Message Date
Vlado Pajić
a37c325f99 utilize quic 2025-09-09 13:13:43 +02:00
Radosław Kamiński
62388a7a20 refactor(rendezvous): Split Rendezvous Protobuf and add tests (#1671) 2025-09-09 09:12:43 +01:00
vladopajic
27051164db chore(tests): utilize quic transport in pubsub tests (#1667) 2025-09-08 21:30:21 +00:00
Gabriel Cruz
f41009461b fix: revert reviewdog workaround (#1672) 2025-09-08 16:35:47 -04:00
vladopajic
c3faabf522 chore(quic): add tests from common interop (#1662) 2025-09-08 16:06:22 +00:00
Gabriel Cruz
10f7f5c68a chore(autonat-v2): add server config (#1669) 2025-09-08 15:23:23 +00:00
Gabriel Cruz
f345026900 fix(linters): use workaround for reviewdog bug (#1668) 2025-09-08 14:48:03 +00:00
vladopajic
5d6578a06f chore: splitRPCMsg improvements (#1665) 2025-09-08 11:06:55 -03:00
Gabriel Cruz
871a5d047f feat(autonat-v2): add server (#1658) 2025-09-04 13:27:49 -04:00
18 changed files with 1662 additions and 941 deletions

View File

@@ -8,7 +8,7 @@ json_serialization;https://github.com/status-im/nim-json-serialization@#2b1c5eb1
metrics;https://github.com/status-im/nim-metrics@#6142e433fc8ea9b73379770a788017ac528d46ff
ngtcp2;https://github.com/status-im/nim-ngtcp2@#9456daa178c655bccd4a3c78ad3b8cce1f0add73
nimcrypto;https://github.com/cheatfate/nimcrypto@#19c41d6be4c00b4a2c8000583bd30cf8ceb5f4b1
quic;https://github.com/vacp2p/nim-quic@#cae13c2d22ba2730c979486cf89b88927045c3ae
quic;https://github.com/vacp2p/nim-quic@#9370190ded18d78a5a9990f57aa8cbbf947f3891
results;https://github.com/arnetheduck/nim-results@#df8113dda4c2d74d460a8fa98252b0b771bf1f27
secp256k1;https://github.com/status-im/nim-secp256k1@#f808ed5e7a7bfc42204ec7830f14b7a42b63c284
serialization;https://github.com/status-im/nim-serialization@#548d0adc9797a10b2db7f788b804330306293088

View File

@@ -10,7 +10,7 @@ skipDirs = @["tests", "examples", "Nim", "tools", "scripts", "docs"]
requires "nim >= 2.0.0",
"nimcrypto >= 0.6.0 & < 0.7.0", "dnsclient >= 0.3.0 & < 0.4.0", "bearssl >= 0.2.5",
"chronicles >= 0.11.0 & < 0.12.0", "chronos >= 4.0.4", "metrics", "secp256k1",
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.15",
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.16",
"https://github.com/vacp2p/nim-jwt.git#18f8378de52b241f321c1f9ea905456e89b95c6f"
let nimc = getEnv("NIMC", "nim") # Which nim compiler to use

View File

@@ -26,7 +26,8 @@ import
transports/[transport, tcptransport, wstransport, memorytransport],
muxers/[muxer, mplex/mplex, yamux/yamux],
protocols/[identify, secure/secure, secure/noise, rendezvous],
protocols/connectivity/[autonat/server, relay/relay, relay/client, relay/rtransport],
protocols/connectivity/
[autonat/server, autonatv2/server, relay/relay, relay/client, relay/rtransport],
connmanager,
upgrademngrs/muxedupgrade,
observedaddrmanager,
@@ -74,6 +75,8 @@ type
nameResolver: NameResolver
peerStoreCapacity: Opt[int]
autonat: bool
autonatV2: bool
autonatV2Config: AutonatV2Config
autotls: AutotlsService
circuitRelay: Relay
rdv: RendezVous
@@ -280,6 +283,13 @@ proc withAutonat*(b: SwitchBuilder): SwitchBuilder =
b.autonat = true
b
proc withAutonatV2*(
b: SwitchBuilder, config: AutonatV2Config = AutonatV2Config.new()
): SwitchBuilder =
b.autonatV2 = true
b.autonatV2Config = config
b
when defined(libp2p_autotls_support):
proc withAutotls*(
b: SwitchBuilder, config: AutotlsConfig = AutotlsConfig.new()
@@ -379,7 +389,10 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
switch.mount(identify)
if b.autonat:
if b.autonatV2:
let autonatV2 = AutonatV2.new(switch, config = b.autonatV2Config)
switch.mount(autonatV2)
elif b.autonat:
let autonat = Autonat.new(switch)
switch.mount(autonat)
@@ -395,10 +408,75 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
return switch
type TransportType* {.pure.} = enum
QUIC
TCP
Memory
proc newStandardSwitchBuilder*(
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] = newSeq[MultiAddress](),
transport: TransportType = TransportType.TCP,
secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
transportFlags: set[ServerFlags] = {},
rng = newRng(),
inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes,
maxConnections = MaxConnections,
maxIn = -1,
maxOut = -1,
maxConnsPerPeer = MaxConnectionsPerPeer,
nameResolver: NameResolver = nil,
sendSignedPeerRecord = false,
peerStoreCapacity = 1000,
): SwitchBuilder {.raises: [LPError], public.} =
## Helper for common switch configurations.
var b = SwitchBuilder
.new()
.withRng(rng)
.withSignedPeerRecord(sendSignedPeerRecord)
.withMaxConnections(maxConnections)
.withMaxIn(maxIn)
.withMaxOut(maxOut)
.withMaxConnsPerPeer(maxConnsPerPeer)
.withPeerStore(capacity = peerStoreCapacity)
.withNameResolver(nameResolver)
.withNoise()
var addrs =
when addrs is MultiAddress:
@[addrs]
else:
addrs
case transport
of TransportType.QUIC:
when defined(libp2p_quic_support):
if addrs.len == 0:
addrs = @[MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet()]
b = b.withQuicTransport().withAddresses(addrs)
else:
raiseAssert "QUIC not supported in this build"
of TransportType.TCP:
if addrs.len == 0:
addrs = @[MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet()]
b = b.withTcpTransport(transportFlags).withAddresses(addrs).withMplex(
inTimeout, outTimeout
)
of TransportType.Memory:
if addrs.len == 0:
addrs = @[MultiAddress.init(MemoryAutoAddress).tryGet()]
b = b.withMemoryTransport().withAddresses(addrs).withMplex(inTimeout, outTimeout)
privKey.withValue(pkey):
b = b.withPrivateKey(pkey)
b
proc newStandardSwitch*(
privKey = none(PrivateKey),
addrs: MultiAddress | seq[MultiAddress] =
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
addrs: MultiAddress | seq[MultiAddress] = newSeq[MultiAddress](),
transport: TransportType = TransportType.TCP,
secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
transportFlags: set[ServerFlags] = {},
rng = newRng(),
@@ -412,28 +490,20 @@ proc newStandardSwitch*(
sendSignedPeerRecord = false,
peerStoreCapacity = 1000,
): Switch {.raises: [LPError], public.} =
## Helper for common switch configurations.
let addrs =
when addrs is MultiAddress:
@[addrs]
else:
addrs
var b = SwitchBuilder
.new()
.withAddresses(addrs)
.withRng(rng)
.withSignedPeerRecord(sendSignedPeerRecord)
.withMaxConnections(maxConnections)
.withMaxIn(maxIn)
.withMaxOut(maxOut)
.withMaxConnsPerPeer(maxConnsPerPeer)
.withPeerStore(capacity = peerStoreCapacity)
.withMplex(inTimeout, outTimeout)
.withTcpTransport(transportFlags)
.withNameResolver(nameResolver)
.withNoise()
privKey.withValue(pkey):
b = b.withPrivateKey(pkey)
b.build()
newStandardSwitchBuilder(
privKey = privKey,
addrs = addrs,
secureManagers = secureManagers,
transportFlags = transportFlags,
rng = rng,
inTimeout = inTimeout,
outTimeout = outTimeout,
maxConnections = maxConnections,
maxIn = maxIn,
maxOut = maxOut,
maxConnsPerPeer = maxConnsPerPeer,
nameResolver = nameResolver,
sendSignedPeerRecord = sendSignedPeerRecord,
peerStoreCapacity = peerStoreCapacity,
)
.build()

View File

@@ -11,7 +11,7 @@
import chronos
import results
import peerid, stream/connection, transports/transport
import peerid, stream/connection, transports/transport, muxers/muxer
export results
@@ -65,6 +65,23 @@ method dial*(
method addTransport*(self: Dial, transport: Transport) {.base.} =
doAssert(false, "[Dial.addTransport] abstract method not implemented!")
method dialAndUpgrade*(
self: Dial, peerId: Opt[PeerId], address: MultiAddress, dir = Direction.Out
): Future[Muxer] {.base, async: (raises: [CancelledError]).} =
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")
method dialAndUpgrade*(
self: Dial, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
base, async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
.} =
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")
method negotiateStream*(
self: Dial, conn: Connection, protos: seq[string]
): Future[Connection] {.base, async: (raises: [CatchableError]).} =
doAssert(false, "[Dial.negotiateStream] abstract method not implemented!")
method tryDial*(
self: Dial, peerId: PeerId, addrs: seq[MultiAddress]
): Future[Opt[MultiAddress]] {.

View File

@@ -43,7 +43,7 @@ type Dialer* = ref object of Dial
peerStore: PeerStore
nameResolver: NameResolver
proc dialAndUpgrade(
proc dialAndUpgrade*(
self: Dialer,
peerId: Opt[PeerId],
hostname: string,
@@ -139,7 +139,7 @@ proc expandDnsAddr(
else:
result.add((resolvedAddress, peerId))
proc dialAndUpgrade(
proc dialAndUpgrade*(
self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
): Future[Muxer] {.
async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
@@ -284,7 +284,7 @@ method connect*(
return
(await self.internalConnect(Opt.none(PeerId), @[address], false)).connection.peerId
proc negotiateStream(
proc negotiateStream*(
self: Dialer, conn: Connection, protos: seq[string]
): Future[Connection] {.async: (raises: [CatchableError]).} =
trace "Negotiating stream", conn, protos

View File

@@ -0,0 +1,279 @@
# Nim-LibP2P
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import results
import chronos, chronicles
import
../../../../libp2p/[
switch,
muxers/muxer,
dialer,
multiaddress,
transports/transport,
multicodec,
peerid,
protobuf/minprotobuf,
utils/ipaddr,
],
../../protocol,
./types
logScope:
topics = "libp2p autonat v2 server"
type AutonatV2Config* = object
dialTimeout: Duration
dialDataSize: uint64
amplificationAttackTimeout: Duration
allowPrivateAddresses: bool
type AutonatV2* = ref object of LPProtocol
switch*: Switch
config: AutonatV2Config
proc new*(
T: typedesc[AutonatV2Config],
dialTimeout: Duration = DefaultDialTimeout,
dialDataSize: uint64 = DefaultDialDataSize,
amplificationAttackTimeout: Duration = DefaultAmplificationAttackDialTimeout,
allowPrivateAddresses: bool = false,
): T =
T(
dialTimeout: dialTimeout,
dialDataSize: dialDataSize,
amplificationAttackTimeout: amplificationAttackTimeout,
allowPrivateAddresses: allowPrivateAddresses,
)
proc sendDialResponse(
conn: Connection,
status: ResponseStatus,
addrIdx: Opt[AddrIdx] = Opt.none(AddrIdx),
dialStatus: Opt[DialStatus] = Opt.none(DialStatus),
) {.async: (raises: [CancelledError, LPStreamError]).} =
await conn.writeLp(
AutonatV2Msg(
msgType: MsgType.DialResponse,
dialResp: DialResponse(status: status, addrIdx: addrIdx, dialStatus: dialStatus),
).encode().buffer
)
proc findObservedIPAddr*(
conn: Connection, req: DialRequest
): Future[Opt[MultiAddress]] {.async: (raises: [CancelledError, LPStreamError]).} =
let observedAddr = conn.observedAddr.valueOr:
await conn.sendDialResponse(ResponseStatus.EInternalError)
return Opt.none(MultiAddress)
let isRelayed = observedAddr.contains(multiCodec("p2p-circuit")).valueOr:
error "Invalid observed address"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return Opt.none(MultiAddress)
if isRelayed:
error "Invalid observed address: relayed address"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return Opt.none(MultiAddress)
let hostIp = observedAddr[0].valueOr:
error "Invalid observed address"
await conn.sendDialResponse(ResponseStatus.EInternalError)
return Opt.none(MultiAddress)
return Opt.some(hostIp)
proc dialBack(
conn: Connection, nonce: Nonce
): Future[DialStatus] {.
async: (raises: [CancelledError, DialFailedError, LPStreamError])
.} =
try:
# send dial back
await conn.writeLp(DialBack(nonce: nonce).encode().buffer)
# receive DialBackResponse
let dialBackResp = DialBackResponse.decode(
initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))
).valueOr:
error "DialBack failed, could not decode DialBackResponse"
return DialStatus.EDialBackError
except LPStreamRemoteClosedError as exc:
# failed because of nonce error (remote reset the stream): EDialBackError
error "DialBack failed, remote closed the connection", description = exc.msg
return DialStatus.EDialBackError
# TODO: failed because of client or server resources: EDialError
trace "DialBack successful"
return DialStatus.Ok
proc handleDialDataResponses(
self: AutonatV2, conn: Connection
) {.async: (raises: [CancelledError, AutonatV2Error, LPStreamError]).} =
var dataReceived: uint64 = 0
while dataReceived < self.config.dialDataSize:
let msg = AutonatV2Msg.decode(
initProtoBuffer(await conn.readLp(AutonatV2DialDataResponseLpSize))
).valueOr:
raise newException(AutonatV2Error, "Received malformed message")
debug "Received message", msgType = $msg.msgType
if msg.msgType != MsgType.DialDataResponse:
raise
newException(AutonatV2Error, "Expecting DialDataResponse, got " & $msg.msgType)
let resp = msg.dialDataResp
dataReceived += resp.data.len.uint64
debug "received data",
dataReceived = resp.data.len.uint64, totalDataReceived = dataReceived
proc amplificationAttackPrevention(
self: AutonatV2, conn: Connection, addrIdx: AddrIdx
): Future[bool] {.async: (raises: [CancelledError, LPStreamError]).} =
# send DialDataRequest
await conn.writeLp(
AutonatV2Msg(
msgType: MsgType.DialDataRequest,
dialDataReq: DialDataRequest(addrIdx: addrIdx, numBytes: self.config.dialDataSize),
).encode().buffer
)
# recieve DialDataResponses until we're satisfied
try:
if not await self.handleDialDataResponses(conn).withTimeout(self.config.dialTimeout):
error "Amplification attack prevention timeout",
timeout = self.config.amplificationAttackTimeout, peer = conn.peerId
return false
except AutonatV2Error as exc:
error "Amplification attack prevention failed", description = exc.msg
return false
return true
proc canDial(self: AutonatV2, addrs: MultiAddress): bool =
let (ipv4Support, ipv6Support) = self.switch.peerInfo.listenAddrs.ipSupport()
addrs[0].withValue(addrIp):
if IP4.match(addrIp) and not ipv4Support:
return false
if IP6.match(addrIp) and not ipv6Support:
return false
try:
if not self.config.allowPrivateAddresses and isPrivate($addrIp):
return false
except ValueError:
warn "Unable to parse IP address, skipping", addrs = $addrIp
return false
for t in self.switch.transports:
if t.handles(addrs):
return true
return false
proc forceNewConnection(
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
): Future[Opt[Connection]] {.async: (raises: [CancelledError]).} =
## Bypasses connManager to force a new connection to ``pid``
## instead of reusing a preexistent one
try:
let mux = await self.switch.dialer.dialAndUpgrade(Opt.some(pid), addrs)
if mux.isNil():
return Opt.none(Connection)
return Opt.some(
await self.switch.negotiateStream(
await mux.newStream(), @[$AutonatV2Codec.DialBack]
)
)
except CancelledError as exc:
raise exc
except CatchableError:
return Opt.none(Connection)
proc chooseDialAddr(
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
): Future[Opt[(Connection, AddrIdx)]] {.async: (raises: [CancelledError]).} =
for i, ma in addrs:
if self.canDial(ma):
debug "Trying to dial", chosenAddrs = ma, addrIdx = i
let conn = (await self.forceNewConnection(pid, @[ma])).valueOr:
return Opt.none((Connection, AddrIdx))
return Opt.some((conn, i.AddrIdx))
return Opt.none((Connection, AddrIdx))
proc handleDialRequest(
self: AutonatV2, conn: Connection, req: DialRequest
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
let observedIPAddr = (await conn.findObservedIPAddr(req)).valueOr:
error "Could not find observed IP address"
return
let (dialBackConn, addrIdx) = (await self.chooseDialAddr(conn.peerId, req.addrs)).valueOr:
error "No dialable addresses found"
await conn.sendDialResponse(ResponseStatus.EDialRefused)
return
defer:
await dialBackConn.close()
# if observed address for peer is not in address list to try
# then we perform Amplification Attack Prevention
if not ipAddrMatches(observedIPAddr, req.addrs):
debug "Starting amplification attack prevention",
observedIPAddr = observedIPAddr, testAddr = req.addrs[addrIdx]
# send DialDataRequest and wait until dataReceived is enough
if not await self.amplificationAttackPrevention(conn, addrIdx):
return
debug "Sending DialBack",
nonce = req.nonce, addrIdx = addrIdx, addr = req.addrs[addrIdx]
let dialStatus = await dialBackConn.dialBack(req.nonce)
await conn.sendDialResponse(
ResponseStatus.Ok, addrIdx = Opt.some(addrIdx), dialStatus = Opt.some(dialStatus)
)
proc new*(
T: typedesc[AutonatV2],
switch: Switch,
config: AutonatV2Config = AutonatV2Config.new(),
): T =
let autonatV2 = T(switch: switch, config: config)
proc handleStream(
conn: Connection, proto: string
) {.async: (raises: [CancelledError]).} =
defer:
await conn.close()
let msg =
try:
AutonatV2Msg.decode(initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))).valueOr:
trace "Unable to decode AutonatV2Msg"
return
except LPStreamError as exc:
debug "Could not receive AutonatV2Msg", description = exc.msg
return
debug "Received message", msgType = $msg.msgType
if msg.msgType != MsgType.DialRequest:
debug "Expecting DialRequest", receivedMsgType = msg.msgType
return
try:
await autonatV2.handleDialRequest(conn, msg.dialReq)
except CancelledError as exc:
raise exc
except LPStreamRemoteClosedError as exc:
debug "Connection closed by peer", description = exc.msg, peer = conn.peerId
except LPStreamError as exc:
debug "Stream Error", description = exc.msg
except DialFailedError as exc:
debug "Could not dial peer", description = exc.msg, peer = conn.peerId
autonatV2.handler = handleStream
autonatV2.codec = $AutonatV2Codec.DialRequest
autonatV2

View File

@@ -14,6 +14,14 @@ import
../../../multiaddress, ../../../peerid, ../../../protobuf/minprotobuf, ../../../switch
from ../autonat/types import NetworkReachability
const
DefaultDialTimeout*: Duration = 15.seconds
DefaultAmplificationAttackDialTimeout*: Duration = 3.seconds
DefaultDialDataSize*: uint64 = 50 * 1024 # 50 KiB > 50 KB
AutonatV2MsgLpSize*: int = 1024
# readLp needs to receive more than 4096 bytes (since it's a DialDataResponse) + overhead
AutonatV2DialDataResponseLpSize*: int = 5000
type
AutonatV2Codec* {.pure.} = enum
DialRequest = "/libp2p/autonat/2/dial-request"

View File

@@ -478,19 +478,16 @@ iterator splitRPCMsg(
## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send.
## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes.
var currentRPCMsg = rpcMsg
currentRPCMsg.messages = newSeq[Message]()
var currentSize = byteSize(currentRPCMsg)
var currentRPCMsg = RPCMsg()
var currentSize = 0
for msg in rpcMsg.messages:
let msgSize = byteSize(msg)
# Check if adding the next message will exceed maxSize
if float(currentSize + msgSize) * 1.1 > float(maxSize):
# Guessing 10% protobuf overhead
if currentRPCMsg.messages.len == 0:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
if currentSize + msgSize > maxSize:
if msgSize > maxSize:
warn "message too big to sent", peer, rpcMsg = shortLog(msg)
continue # Skip this message
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
@@ -502,11 +499,9 @@ iterator splitRPCMsg(
currentSize += msgSize
# Check if there is a non-empty currentRPCMsg left to be added
if currentSize > 0 and currentRPCMsg.messages.len > 0:
if currentRPCMsg.messages.len > 0:
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
yield encodeRpcMsg(currentRPCMsg, anonymize)
else:
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
proc send*(
p: PubSubPeer,
@@ -542,8 +537,11 @@ proc send*(
sendMetrics(msg)
encodeRpcMsg(msg, anonymize)
if encoded.len > p.maxMessageSize and msg.messages.len > 1:
for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize):
# Messages should not exceed 90% of maxMessageSize. Guessing 10% protobuf overhead.
let maxEncodedMsgSize = (p.maxMessageSize * 90) div 100
if encoded.len > maxEncodedMsgSize and msg.messages.len > 1:
for encodedSplitMsg in splitRPCMsg(p, msg, maxEncodedMsgSize, anonymize):
asyncSpawn p.sendEncoded(encodedSplitMsg, isHighPriority, useCustomConn)
else:
# If the message size is within limits, send it as is

View File

@@ -1,851 +1,3 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import ./rendezvous/rendezvous
{.push raises: [].}
import tables, sequtils, sugar, sets
import metrics except collect
import chronos, chronicles, bearssl/rand, stew/[byteutils, objects]
import
./protocol,
../protobuf/minprotobuf,
../switch,
../routing_record,
../utils/heartbeat,
../stream/connection,
../utils/offsettedseq,
../utils/semaphore,
../discovery/discoverymngr
export chronicles
logScope:
topics = "libp2p discovery rendezvous"
declareCounter(libp2p_rendezvous_register, "number of advertise requests")
declareCounter(libp2p_rendezvous_discover, "number of discovery requests")
declareGauge(libp2p_rendezvous_registered, "number of registered peers")
declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
const
RendezVousCodec* = "/rendezvous/1.0.0"
# Default minimum TTL per libp2p spec
MinimumDuration* = 2.hours
# Lower validation limit to accommodate Waku requirements
MinimumAcceptedDuration = 1.minutes
MaximumDuration = 72.hours
MaximumMessageLen = 1 shl 22 # 4MB
MinimumNamespaceLen = 1
MaximumNamespaceLen = 255
RegistrationLimitPerPeer* = 1000
DiscoverLimit = 1000'u64
SemaphoreDefaultSize = 5
type
MessageType {.pure.} = enum
Register = 0
RegisterResponse = 1
Unregister = 2
Discover = 3
DiscoverResponse = 4
ResponseStatus = enum
Ok = 0
InvalidNamespace = 100
InvalidSignedPeerRecord = 101
InvalidTTL = 102
InvalidCookie = 103
NotAuthorized = 200
InternalError = 300
Unavailable = 400
Cookie = object
offset: uint64
ns: Opt[string]
Register = object
ns: string
signedPeerRecord: seq[byte]
ttl*: Opt[uint64] # in seconds
RegisterResponse = object
status: ResponseStatus
text: Opt[string]
ttl: Opt[uint64] # in seconds
Unregister = object
ns: string
Discover = object
ns: Opt[string]
limit: Opt[uint64]
cookie: Opt[seq[byte]]
DiscoverResponse = object
registrations: seq[Register]
cookie: Opt[seq[byte]]
status: ResponseStatus
text: Opt[string]
Message = object
msgType: MessageType
register: Opt[Register]
registerResponse: Opt[RegisterResponse]
unregister: Opt[Unregister]
discover: Opt[Discover]
discoverResponse: Opt[DiscoverResponse]
proc encode(c: Cookie): ProtoBuffer =
result = initProtoBuffer()
result.write(1, c.offset)
if c.ns.isSome():
result.write(2, c.ns.get())
result.finish()
proc encode(r: Register): ProtoBuffer =
result = initProtoBuffer()
result.write(1, r.ns)
result.write(2, r.signedPeerRecord)
r.ttl.withValue(ttl):
result.write(3, ttl)
result.finish()
proc encode(rr: RegisterResponse): ProtoBuffer =
result = initProtoBuffer()
result.write(1, rr.status.uint)
rr.text.withValue(text):
result.write(2, text)
rr.ttl.withValue(ttl):
result.write(3, ttl)
result.finish()
proc encode(u: Unregister): ProtoBuffer =
result = initProtoBuffer()
result.write(1, u.ns)
result.finish()
proc encode(d: Discover): ProtoBuffer =
result = initProtoBuffer()
if d.ns.isSome():
result.write(1, d.ns.get())
d.limit.withValue(limit):
result.write(2, limit)
d.cookie.withValue(cookie):
result.write(3, cookie)
result.finish()
proc encode(dr: DiscoverResponse): ProtoBuffer =
result = initProtoBuffer()
for reg in dr.registrations:
result.write(1, reg.encode())
dr.cookie.withValue(cookie):
result.write(2, cookie)
result.write(3, dr.status.uint)
dr.text.withValue(text):
result.write(4, text)
result.finish()
proc encode(msg: Message): ProtoBuffer =
result = initProtoBuffer()
result.write(1, msg.msgType.uint)
msg.register.withValue(register):
result.write(2, register.encode())
msg.registerResponse.withValue(registerResponse):
result.write(3, registerResponse.encode())
msg.unregister.withValue(unregister):
result.write(4, unregister.encode())
msg.discover.withValue(discover):
result.write(5, discover.encode())
msg.discoverResponse.withValue(discoverResponse):
result.write(6, discoverResponse.encode())
result.finish()
proc decode(_: typedesc[Cookie], buf: seq[byte]): Opt[Cookie] =
var
c: Cookie
ns: string
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, c.offset)
r2 = pb.getField(2, ns)
if r1.isErr() or r2.isErr():
return Opt.none(Cookie)
if r2.get(false):
c.ns = Opt.some(ns)
Opt.some(c)
proc decode(_: typedesc[Register], buf: seq[byte]): Opt[Register] =
var
r: Register
ttl: uint64
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, r.ns)
r2 = pb.getRequiredField(2, r.signedPeerRecord)
r3 = pb.getField(3, ttl)
if r1.isErr() or r2.isErr() or r3.isErr():
return Opt.none(Register)
if r3.get(false):
r.ttl = Opt.some(ttl)
Opt.some(r)
proc decode(_: typedesc[RegisterResponse], buf: seq[byte]): Opt[RegisterResponse] =
var
rr: RegisterResponse
statusOrd: uint
text: string
ttl: uint64
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, statusOrd)
r2 = pb.getField(2, text)
r3 = pb.getField(3, ttl)
if r1.isErr() or r2.isErr() or r3.isErr() or
not checkedEnumAssign(rr.status, statusOrd):
return Opt.none(RegisterResponse)
if r2.get(false):
rr.text = Opt.some(text)
if r3.get(false):
rr.ttl = Opt.some(ttl)
Opt.some(rr)
proc decode(_: typedesc[Unregister], buf: seq[byte]): Opt[Unregister] =
var u: Unregister
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, u.ns)
if r1.isErr():
return Opt.none(Unregister)
Opt.some(u)
proc decode(_: typedesc[Discover], buf: seq[byte]): Opt[Discover] =
var
d: Discover
limit: uint64
cookie: seq[byte]
ns: string
let
pb = initProtoBuffer(buf)
r1 = pb.getField(1, ns)
r2 = pb.getField(2, limit)
r3 = pb.getField(3, cookie)
if r1.isErr() or r2.isErr() or r3.isErr:
return Opt.none(Discover)
if r1.get(false):
d.ns = Opt.some(ns)
if r2.get(false):
d.limit = Opt.some(limit)
if r3.get(false):
d.cookie = Opt.some(cookie)
Opt.some(d)
proc decode(_: typedesc[DiscoverResponse], buf: seq[byte]): Opt[DiscoverResponse] =
var
dr: DiscoverResponse
registrations: seq[seq[byte]]
cookie: seq[byte]
statusOrd: uint
text: string
let
pb = initProtoBuffer(buf)
r1 = pb.getRepeatedField(1, registrations)
r2 = pb.getField(2, cookie)
r3 = pb.getRequiredField(3, statusOrd)
r4 = pb.getField(4, text)
if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or
not checkedEnumAssign(dr.status, statusOrd):
return Opt.none(DiscoverResponse)
for reg in registrations:
var r: Register
let regOpt = Register.decode(reg).valueOr:
return
dr.registrations.add(regOpt)
if r2.get(false):
dr.cookie = Opt.some(cookie)
if r4.get(false):
dr.text = Opt.some(text)
Opt.some(dr)
proc decode(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
var
msg: Message
statusOrd: uint
pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer
let pb = initProtoBuffer(buf)
?pb.getRequiredField(1, statusOrd).toOpt
if not checkedEnumAssign(msg.msgType, statusOrd):
return Opt.none(Message)
if ?pb.getField(2, pbr).optValue:
msg.register = Register.decode(pbr.buffer)
if msg.register.isNone():
return Opt.none(Message)
if ?pb.getField(3, pbrr).optValue:
msg.registerResponse = RegisterResponse.decode(pbrr.buffer)
if msg.registerResponse.isNone():
return Opt.none(Message)
if ?pb.getField(4, pbu).optValue:
msg.unregister = Unregister.decode(pbu.buffer)
if msg.unregister.isNone():
return Opt.none(Message)
if ?pb.getField(5, pbd).optValue:
msg.discover = Discover.decode(pbd.buffer)
if msg.discover.isNone():
return Opt.none(Message)
if ?pb.getField(6, pbdr).optValue:
msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer)
if msg.discoverResponse.isNone():
return Opt.none(Message)
Opt.some(msg)
type
RendezVousError* = object of DiscoveryError
RegisteredData = object
expiration*: Moment
peerId*: PeerId
data*: Register
RendezVous* = ref object of LPProtocol
# Registered needs to be an offsetted sequence
# because we need stable index for the cookies.
registered*: OffsettedSeq[RegisteredData]
# Namespaces is a table whose key is a salted namespace and
# the value is the index sequence corresponding to this
# namespace in the offsettedqueue.
namespaces*: Table[string, seq[int]]
rng: ref HmacDrbgContext
salt: string
expiredDT: Moment
registerDeletionLoop: Future[void]
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
# + make the heartbeat sleep duration "smarter"
sema: AsyncSemaphore
peers: seq[PeerId]
cookiesSaved*: Table[PeerId, Table[string, seq[byte]]]
switch: Switch
minDuration: Duration
maxDuration: Duration
minTTL: uint64
maxTTL: uint64
proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] =
if spr.len == 0:
return err("Empty peer record")
let signedEnv = ?SignedPeerRecord.decode(spr).mapErr(x => $x)
if signedEnv.data.peerId != peerId:
return err("Bad Peer ID")
return ok()
proc sendRegisterResponse(
conn: Connection, ttl: uint64
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.RegisterResponse,
registerResponse: Opt.some(RegisterResponse(status: Ok, ttl: Opt.some(ttl))),
)
)
await conn.writeLp(msg.buffer)
proc sendRegisterResponseError(
conn: Connection, status: ResponseStatus, text: string = ""
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.RegisterResponse,
registerResponse: Opt.some(RegisterResponse(status: status, text: Opt.some(text))),
)
)
await conn.writeLp(msg.buffer)
proc sendDiscoverResponse(
conn: Connection, s: seq[Register], cookie: Cookie
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.DiscoverResponse,
discoverResponse: Opt.some(
DiscoverResponse(
status: Ok, registrations: s, cookie: Opt.some(cookie.encode().buffer)
)
),
)
)
await conn.writeLp(msg.buffer)
proc sendDiscoverResponseError(
conn: Connection, status: ResponseStatus, text: string = ""
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.DiscoverResponse,
discoverResponse: Opt.some(DiscoverResponse(status: status, text: Opt.some(text))),
)
)
await conn.writeLp(msg.buffer)
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
for data in rdv.registered:
if data.peerId == peerId:
result.inc()
proc save(
rdv: RendezVous, ns: string, peerId: PeerId, r: Register, update: bool = true
) =
let nsSalted = ns & rdv.salt
discard rdv.namespaces.hasKeyOrPut(nsSalted, newSeq[int]())
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == peerId:
if update == false:
return
rdv.registered[index].expiration = rdv.expiredDT
rdv.registered.add(
RegisteredData(
peerId: peerId,
expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds,
data: r,
)
)
rdv.namespaces[nsSalted].add(rdv.registered.high)
# rdv.registerEvent.fire()
except KeyError as e:
doAssert false, "Should have key: " & e.msg
proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] =
trace "Received Register", peerId = conn.peerId, ns = r.ns
libp2p_rendezvous_register.inc()
if r.ns.len < MinimumNamespaceLen or r.ns.len > MaximumNamespaceLen:
return conn.sendRegisterResponseError(InvalidNamespace)
let ttl = r.ttl.get(rdv.minTTL)
if ttl < rdv.minTTL or ttl > rdv.maxTTL:
return conn.sendRegisterResponseError(InvalidTTL)
let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId)
if pr.isErr():
return conn.sendRegisterResponseError(InvalidSignedPeerRecord, pr.error())
if rdv.countRegister(conn.peerId) >= RegistrationLimitPerPeer:
return conn.sendRegisterResponseError(NotAuthorized, "Registration limit reached")
rdv.save(r.ns, conn.peerId, r)
libp2p_rendezvous_registered.inc()
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
conn.sendRegisterResponse(ttl)
proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
trace "Received Unregister", peerId = conn.peerId, ns = u.ns
let nsSalted = u.ns & rdv.salt
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == conn.peerId:
rdv.registered[index].expiration = rdv.expiredDT
libp2p_rendezvous_registered.dec()
except KeyError:
return
proc discover(
rdv: RendezVous, conn: Connection, d: Discover
) {.async: (raises: [CancelledError, LPStreamError]).} =
trace "Received Discover", peerId = conn.peerId, ns = d.ns
libp2p_rendezvous_discover.inc()
if d.ns.isSome() and d.ns.get().len > MaximumNamespaceLen:
await conn.sendDiscoverResponseError(InvalidNamespace)
return
var limit = min(DiscoverLimit, d.limit.get(DiscoverLimit))
var cookie =
if d.cookie.isSome():
try:
Cookie.decode(d.cookie.tryGet()).tryGet()
except CatchableError:
await conn.sendDiscoverResponseError(InvalidCookie)
return
else:
# Start from the current lowest index (inclusive)
Cookie(offset: rdv.registered.low().uint64)
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get():
# Namespace changed: start from the beginning of that namespace
cookie = Cookie(offset: rdv.registered.low().uint64)
elif cookie.offset < rdv.registered.low().uint64:
# Cookie behind available range: reset to current low
cookie.offset = rdv.registered.low().uint64
elif cookie.offset > (rdv.registered.high() + 1).uint64:
# Cookie ahead of available range: reset to one past current high (empty page)
cookie.offset = (rdv.registered.high() + 1).uint64
let namespaces =
if d.ns.isSome():
try:
rdv.namespaces[d.ns.get() & rdv.salt]
except KeyError:
await conn.sendDiscoverResponseError(InvalidNamespace)
return
else:
toSeq(max(cookie.offset.int, rdv.registered.offset) .. rdv.registered.high())
if namespaces.len() == 0:
await conn.sendDiscoverResponse(@[], Cookie())
return
var nextOffset = cookie.offset
let n = Moment.now()
var s = collect(newSeq()):
for index in namespaces:
var reg = rdv.registered[index]
if limit == 0:
break
if reg.expiration < n or index.uint64 < cookie.offset:
continue
limit.dec()
nextOffset = index.uint64 + 1
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
reg.data
rdv.rng.shuffle(s)
await conn.sendDiscoverResponse(s, Cookie(offset: nextOffset, ns: d.ns))
proc advertisePeer(
rdv: RendezVous, peer: PeerId, msg: seq[byte]
) {.async: (raises: [CancelledError]).} =
proc advertiseWrap() {.async: (raises: []).} =
try:
let conn = await rdv.switch.dial(peer, RendezVousCodec)
defer:
await conn.close()
await conn.writeLp(msg)
let
buf = await conn.readLp(4096)
msgRecv = Message.decode(buf).tryGet()
if msgRecv.msgType != MessageType.RegisterResponse:
trace "Unexpected register response", peer, msgType = msgRecv.msgType
elif msgRecv.registerResponse.tryGet().status != ResponseStatus.Ok:
trace "Refuse to register", peer, response = msgRecv.registerResponse
else:
trace "Successfully registered", peer, response = msgRecv.registerResponse
except CatchableError as exc:
trace "exception in the advertise", description = exc.msg
finally:
rdv.sema.release()
await rdv.sema.acquire()
await advertiseWrap()
proc advertise*(
rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId]
) {.async: (raises: [CancelledError, AdvertiseError]).} =
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
raise newException(AdvertiseError, "Invalid namespace")
if ttl < rdv.minDuration or ttl > rdv.maxDuration:
raise newException(AdvertiseError, "Invalid time to live: " & $ttl)
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
raise newException(AdvertiseError, "Wrong Signed Peer Record")
let
r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
rdv.save(ns, rdv.switch.peerInfo.peerId, r)
let futs = collect(newSeq()):
for peer in peers:
trace "Send Advertise", peerId = peer, ns
rdv.advertisePeer(peer, msg.buffer).withTimeout(5.seconds)
await allFutures(futs)
method advertise*(
rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration
) {.base, async: (raises: [CancelledError, AdvertiseError]).} =
await rdv.advertise(ns, ttl, rdv.peers)
proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
let
nsSalted = ns & rdv.salt
n = Moment.now()
try:
collect(newSeq()):
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].expiration > n:
let res = SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).valueOr:
continue
res.data
except KeyError as exc:
@[]
proc request*(
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int, peers: seq[PeerId]
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
var
s: Table[PeerId, (PeerRecord, Register)]
limit: uint64
d = Discover(ns: ns)
if l <= 0 or l > DiscoverLimit.int:
raise newException(AdvertiseError, "Invalid limit")
if ns.isSome() and ns.get().len > MaximumNamespaceLen:
raise newException(AdvertiseError, "Invalid namespace")
limit = l.uint64
proc requestPeer(
peer: PeerId
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
let conn = await rdv.switch.dial(peer, RendezVousCodec)
defer:
await conn.close()
d.limit = Opt.some(limit)
d.cookie =
if ns.isSome():
try:
Opt.some(rdv.cookiesSaved[peer][ns.get()])
except KeyError, CatchableError:
Opt.none(seq[byte])
else:
Opt.none(seq[byte])
await conn.writeLp(
encode(Message(msgType: MessageType.Discover, discover: Opt.some(d))).buffer
)
let
buf = await conn.readLp(MaximumMessageLen)
msgRcv = Message.decode(buf).valueOr:
debug "Message undecodable"
return
if msgRcv.msgType != MessageType.DiscoverResponse:
debug "Unexpected discover response", msgType = msgRcv.msgType
return
let resp = msgRcv.discoverResponse.valueOr:
debug "Discover response is empty"
return
if resp.status != ResponseStatus.Ok:
trace "Cannot discover", ns, status = resp.status, text = resp.text
return
resp.cookie.withValue(cookie):
if ns.isSome:
let namespace = ns.get()
if cookie.len() < 1000 and
rdv.cookiesSaved.hasKeyOrPut(peer, {namespace: cookie}.toTable()):
try:
rdv.cookiesSaved[peer][namespace] = cookie
except KeyError:
raiseAssert "checked with hasKeyOrPut"
for r in resp.registrations:
if limit == 0:
return
let ttl = r.ttl.get(rdv.maxTTL + 1)
if ttl > rdv.maxTTL:
continue
let
spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr:
continue
pr = spr.data
if s.hasKey(pr.peerId):
let (prSaved, rSaved) =
try:
s[pr.peerId]
except KeyError:
raiseAssert "checked with hasKey"
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or
prSaved.seqNo < pr.seqNo:
s[pr.peerId] = (pr, r)
else:
s[pr.peerId] = (pr, r)
limit.dec()
if ns.isSome():
for (_, r) in s.values():
rdv.save(ns.get(), peer, r, false)
for peer in peers:
if limit == 0:
break
if RendezVousCodec notin rdv.switch.peerStore[ProtoBook][peer]:
continue
try:
trace "Send Request", peerId = peer, ns
await peer.requestPeer()
except CancelledError as e:
raise e
except DialFailedError as e:
trace "failed to dial a peer", description = e.msg
except LPStreamError as e:
trace "failed to communicate with a peer", description = e.msg
return toSeq(s.values()).mapIt(it[0])
proc request*(
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
await rdv.request(ns, l, rdv.peers)
proc request*(
rdv: RendezVous, l: int = DiscoverLimit.int
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
await rdv.request(Opt.none(string), l, rdv.peers)
proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
let nsSalted = ns & rdv.salt
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
rdv.registered[index].expiration = rdv.expiredDT
except KeyError:
return
proc unsubscribe*(
rdv: RendezVous, ns: string, peerIds: seq[PeerId]
) {.async: (raises: [RendezVousError, CancelledError]).} =
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
raise newException(RendezVousError, "Invalid namespace")
let msg = encode(
Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns)))
)
proc unsubscribePeer(peerId: PeerId) {.async: (raises: []).} =
try:
let conn = await rdv.switch.dial(peerId, RendezVousCodec)
defer:
await conn.close()
await conn.writeLp(msg.buffer)
except CatchableError as exc:
trace "exception while unsubscribing", description = exc.msg
let futs = collect(newSeq()):
for peer in peerIds:
unsubscribePeer(peer)
await allFutures(futs)
proc unsubscribe*(
rdv: RendezVous, ns: string
) {.async: (raises: [RendezVousError, CancelledError]).} =
rdv.unsubscribeLocally(ns)
await rdv.unsubscribe(ns, rdv.peers)
proc setup*(rdv: RendezVous, switch: Switch) =
rdv.switch = switch
proc handlePeer(
peerId: PeerId, event: PeerEvent
) {.async: (raises: [CancelledError]).} =
if event.kind == PeerEventKind.Joined:
rdv.peers.add(peerId)
elif event.kind == PeerEventKind.Left:
rdv.peers.keepItIf(it != peerId)
rdv.switch.addPeerEventHandler(handlePeer, Joined)
rdv.switch.addPeerEventHandler(handlePeer, Left)
proc new*(
T: typedesc[RendezVous],
rng: ref HmacDrbgContext = newRng(),
minDuration = MinimumDuration,
maxDuration = MaximumDuration,
): T {.raises: [RendezVousError].} =
if minDuration < MinimumAcceptedDuration:
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
if maxDuration > MaximumDuration:
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
if minDuration >= maxDuration:
raise newException(RendezVousError, "Minimum TTL longer than maximum")
let
minTTL = minDuration.seconds.uint64
maxTTL = maxDuration.seconds.uint64
let rdv = T(
rng: rng,
salt: string.fromBytes(generateBytes(rng[], 8)),
registered: initOffsettedSeq[RegisteredData](),
expiredDT: Moment.now() - 1.days,
#registerEvent: newAsyncEvent(),
sema: newAsyncSemaphore(SemaphoreDefaultSize),
minDuration: minDuration,
maxDuration: maxDuration,
minTTL: minTTL,
maxTTL: maxTTL,
)
logScope:
topics = "libp2p discovery rendezvous"
proc handleStream(
conn: Connection, proto: string
) {.async: (raises: [CancelledError]).} =
try:
let
buf = await conn.readLp(4096)
msg = Message.decode(buf).tryGet()
case msg.msgType
of MessageType.Register:
await rdv.register(conn, msg.register.tryGet())
of MessageType.RegisterResponse:
trace "Got an unexpected Register Response", response = msg.registerResponse
of MessageType.Unregister:
rdv.unregister(conn, msg.unregister.tryGet())
of MessageType.Discover:
await rdv.discover(conn, msg.discover.tryGet())
of MessageType.DiscoverResponse:
trace "Got an unexpected Discover Response", response = msg.discoverResponse
except CancelledError as exc:
trace "cancelled rendezvous handler"
raise exc
except CatchableError as exc:
trace "exception in rendezvous handler", description = exc.msg
finally:
await conn.close()
rdv.handler = handleStream
rdv.codec = RendezVousCodec
return rdv
proc new*(
T: typedesc[RendezVous],
switch: Switch,
rng: ref HmacDrbgContext = newRng(),
minDuration = MinimumDuration,
maxDuration = MaximumDuration,
): T {.raises: [RendezVousError].} =
let rdv = T.new(rng, minDuration, maxDuration)
rdv.setup(switch)
return rdv
proc deletesRegister*(
rdv: RendezVous, interval = 1.minutes
) {.async: (raises: [CancelledError]).} =
heartbeat "Register timeout", interval:
let n = Moment.now()
var total = 0
rdv.registered.flushIfIt(it.expiration < n)
for data in rdv.namespaces.mvalues():
data.keepItIf(it >= rdv.registered.offset)
total += data.len
libp2p_rendezvous_registered.set(int64(total))
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
method start*(
rdv: RendezVous
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
let fut = newFuture[void]()
fut.complete()
if not rdv.registerDeletionLoop.isNil:
warn "Starting rendezvous twice"
return fut
rdv.registerDeletionLoop = rdv.deletesRegister()
rdv.started = true
fut
method stop*(rdv: RendezVous): Future[void] {.async: (raises: [], raw: true).} =
let fut = newFuture[void]()
fut.complete()
if rdv.registerDeletionLoop.isNil:
warn "Stopping rendezvous without starting it"
return fut
rdv.started = false
rdv.registerDeletionLoop.cancelSoon()
rdv.registerDeletionLoop = nil
fut
export rendezvous

View File

@@ -0,0 +1,275 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import results
import stew/objects
import ../../protobuf/minprotobuf
type
MessageType* {.pure.} = enum
Register = 0
RegisterResponse = 1
Unregister = 2
Discover = 3
DiscoverResponse = 4
ResponseStatus* = enum
Ok = 0
InvalidNamespace = 100
InvalidSignedPeerRecord = 101
InvalidTTL = 102
InvalidCookie = 103
NotAuthorized = 200
InternalError = 300
Unavailable = 400
Cookie* = object
offset*: uint64
ns*: Opt[string]
Register* = object
ns*: string
signedPeerRecord*: seq[byte]
ttl*: Opt[uint64] # in seconds
RegisterResponse* = object
status*: ResponseStatus
text*: Opt[string]
ttl*: Opt[uint64] # in seconds
Unregister* = object
ns*: string
Discover* = object
ns*: Opt[string]
limit*: Opt[uint64]
cookie*: Opt[seq[byte]]
DiscoverResponse* = object
registrations*: seq[Register]
cookie*: Opt[seq[byte]]
status*: ResponseStatus
text*: Opt[string]
Message* = object
msgType*: MessageType
register*: Opt[Register]
registerResponse*: Opt[RegisterResponse]
unregister*: Opt[Unregister]
discover*: Opt[Discover]
discoverResponse*: Opt[DiscoverResponse]
proc encode*(c: Cookie): ProtoBuffer =
result = initProtoBuffer()
result.write(1, c.offset)
if c.ns.isSome():
result.write(2, c.ns.get())
result.finish()
proc encode*(r: Register): ProtoBuffer =
result = initProtoBuffer()
result.write(1, r.ns)
result.write(2, r.signedPeerRecord)
r.ttl.withValue(ttl):
result.write(3, ttl)
result.finish()
proc encode*(rr: RegisterResponse): ProtoBuffer =
result = initProtoBuffer()
result.write(1, rr.status.uint)
rr.text.withValue(text):
result.write(2, text)
rr.ttl.withValue(ttl):
result.write(3, ttl)
result.finish()
proc encode*(u: Unregister): ProtoBuffer =
result = initProtoBuffer()
result.write(1, u.ns)
result.finish()
proc encode*(d: Discover): ProtoBuffer =
result = initProtoBuffer()
if d.ns.isSome():
result.write(1, d.ns.get())
d.limit.withValue(limit):
result.write(2, limit)
d.cookie.withValue(cookie):
result.write(3, cookie)
result.finish()
proc encode*(dr: DiscoverResponse): ProtoBuffer =
result = initProtoBuffer()
for reg in dr.registrations:
result.write(1, reg.encode())
dr.cookie.withValue(cookie):
result.write(2, cookie)
result.write(3, dr.status.uint)
dr.text.withValue(text):
result.write(4, text)
result.finish()
proc encode*(msg: Message): ProtoBuffer =
result = initProtoBuffer()
result.write(1, msg.msgType.uint)
msg.register.withValue(register):
result.write(2, register.encode())
msg.registerResponse.withValue(registerResponse):
result.write(3, registerResponse.encode())
msg.unregister.withValue(unregister):
result.write(4, unregister.encode())
msg.discover.withValue(discover):
result.write(5, discover.encode())
msg.discoverResponse.withValue(discoverResponse):
result.write(6, discoverResponse.encode())
result.finish()
proc decode*(_: typedesc[Cookie], buf: seq[byte]): Opt[Cookie] =
var
c: Cookie
ns: string
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, c.offset)
r2 = pb.getField(2, ns)
if r1.isErr() or r2.isErr():
return Opt.none(Cookie)
if r2.get(false):
c.ns = Opt.some(ns)
Opt.some(c)
proc decode*(_: typedesc[Register], buf: seq[byte]): Opt[Register] =
var
r: Register
ttl: uint64
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, r.ns)
r2 = pb.getRequiredField(2, r.signedPeerRecord)
r3 = pb.getField(3, ttl)
if r1.isErr() or r2.isErr() or r3.isErr():
return Opt.none(Register)
if r3.get(false):
r.ttl = Opt.some(ttl)
Opt.some(r)
proc decode*(_: typedesc[RegisterResponse], buf: seq[byte]): Opt[RegisterResponse] =
var
rr: RegisterResponse
statusOrd: uint
text: string
ttl: uint64
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, statusOrd)
r2 = pb.getField(2, text)
r3 = pb.getField(3, ttl)
if r1.isErr() or r2.isErr() or r3.isErr() or
not checkedEnumAssign(rr.status, statusOrd):
return Opt.none(RegisterResponse)
if r2.get(false):
rr.text = Opt.some(text)
if r3.get(false):
rr.ttl = Opt.some(ttl)
Opt.some(rr)
proc decode*(_: typedesc[Unregister], buf: seq[byte]): Opt[Unregister] =
var u: Unregister
let
pb = initProtoBuffer(buf)
r1 = pb.getRequiredField(1, u.ns)
if r1.isErr():
return Opt.none(Unregister)
Opt.some(u)
proc decode*(_: typedesc[Discover], buf: seq[byte]): Opt[Discover] =
var
d: Discover
limit: uint64
cookie: seq[byte]
ns: string
let
pb = initProtoBuffer(buf)
r1 = pb.getField(1, ns)
r2 = pb.getField(2, limit)
r3 = pb.getField(3, cookie)
if r1.isErr() or r2.isErr() or r3.isErr:
return Opt.none(Discover)
if r1.get(false):
d.ns = Opt.some(ns)
if r2.get(false):
d.limit = Opt.some(limit)
if r3.get(false):
d.cookie = Opt.some(cookie)
Opt.some(d)
proc decode*(_: typedesc[DiscoverResponse], buf: seq[byte]): Opt[DiscoverResponse] =
var
dr: DiscoverResponse
registrations: seq[seq[byte]]
cookie: seq[byte]
statusOrd: uint
text: string
let
pb = initProtoBuffer(buf)
r1 = pb.getRepeatedField(1, registrations)
r2 = pb.getField(2, cookie)
r3 = pb.getRequiredField(3, statusOrd)
r4 = pb.getField(4, text)
if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or
not checkedEnumAssign(dr.status, statusOrd):
return Opt.none(DiscoverResponse)
for reg in registrations:
var r: Register
let regOpt = Register.decode(reg).valueOr:
return
dr.registrations.add(regOpt)
if r2.get(false):
dr.cookie = Opt.some(cookie)
if r4.get(false):
dr.text = Opt.some(text)
Opt.some(dr)
proc decode*(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
var
msg: Message
statusOrd: uint
pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer
let pb = initProtoBuffer(buf)
?pb.getRequiredField(1, statusOrd).toOpt
if not checkedEnumAssign(msg.msgType, statusOrd):
return Opt.none(Message)
if ?pb.getField(2, pbr).optValue:
msg.register = Register.decode(pbr.buffer)
if msg.register.isNone():
return Opt.none(Message)
if ?pb.getField(3, pbrr).optValue:
msg.registerResponse = RegisterResponse.decode(pbrr.buffer)
if msg.registerResponse.isNone():
return Opt.none(Message)
if ?pb.getField(4, pbu).optValue:
msg.unregister = Unregister.decode(pbu.buffer)
if msg.unregister.isNone():
return Opt.none(Message)
if ?pb.getField(5, pbd).optValue:
msg.discover = Discover.decode(pbd.buffer)
if msg.discover.isNone():
return Opt.none(Message)
if ?pb.getField(6, pbdr).optValue:
msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer)
if msg.discoverResponse.isNone():
return Opt.none(Message)
Opt.some(msg)

View File

@@ -0,0 +1,589 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
{.push raises: [].}
import tables, sequtils, sugar, sets
import metrics except collect
import chronos, chronicles, bearssl/rand, stew/[byteutils, objects]
import
./protobuf,
../protocol,
../../protobuf/minprotobuf,
../../switch,
../../routing_record,
../../utils/heartbeat,
../../stream/connection,
../../utils/offsettedseq,
../../utils/semaphore,
../../discovery/discoverymngr
export chronicles
logScope:
topics = "libp2p discovery rendezvous"
declareCounter(libp2p_rendezvous_register, "number of advertise requests")
declareCounter(libp2p_rendezvous_discover, "number of discovery requests")
declareGauge(libp2p_rendezvous_registered, "number of registered peers")
declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
const
RendezVousCodec* = "/rendezvous/1.0.0"
# Default minimum TTL per libp2p spec
MinimumDuration* = 2.hours
# Lower validation limit to accommodate Waku requirements
MinimumAcceptedDuration = 1.minutes
MaximumDuration = 72.hours
MaximumMessageLen = 1 shl 22 # 4MB
MinimumNamespaceLen = 1
MaximumNamespaceLen = 255
RegistrationLimitPerPeer* = 1000
DiscoverLimit = 1000'u64
SemaphoreDefaultSize = 5
type
RendezVousError* = object of DiscoveryError
RegisteredData = object
expiration*: Moment
peerId*: PeerId
data*: Register
RendezVous* = ref object of LPProtocol
# Registered needs to be an offsetted sequence
# because we need stable index for the cookies.
registered*: OffsettedSeq[RegisteredData]
# Namespaces is a table whose key is a salted namespace and
# the value is the index sequence corresponding to this
# namespace in the offsettedqueue.
namespaces*: Table[string, seq[int]]
rng: ref HmacDrbgContext
salt: string
expiredDT: Moment
registerDeletionLoop: Future[void]
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
# + make the heartbeat sleep duration "smarter"
sema: AsyncSemaphore
peers: seq[PeerId]
cookiesSaved*: Table[PeerId, Table[string, seq[byte]]]
switch: Switch
minDuration: Duration
maxDuration: Duration
minTTL: uint64
maxTTL: uint64
proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] =
if spr.len == 0:
return err("Empty peer record")
let signedEnv = ?SignedPeerRecord.decode(spr).mapErr(x => $x)
if signedEnv.data.peerId != peerId:
return err("Bad Peer ID")
return ok()
proc sendRegisterResponse(
conn: Connection, ttl: uint64
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.RegisterResponse,
registerResponse: Opt.some(RegisterResponse(status: Ok, ttl: Opt.some(ttl))),
)
)
await conn.writeLp(msg.buffer)
proc sendRegisterResponseError(
conn: Connection, status: ResponseStatus, text: string = ""
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.RegisterResponse,
registerResponse: Opt.some(RegisterResponse(status: status, text: Opt.some(text))),
)
)
await conn.writeLp(msg.buffer)
proc sendDiscoverResponse(
conn: Connection, s: seq[Register], cookie: Cookie
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.DiscoverResponse,
discoverResponse: Opt.some(
DiscoverResponse(
status: Ok, registrations: s, cookie: Opt.some(cookie.encode().buffer)
)
),
)
)
await conn.writeLp(msg.buffer)
proc sendDiscoverResponseError(
conn: Connection, status: ResponseStatus, text: string = ""
) {.async: (raises: [CancelledError, LPStreamError]).} =
let msg = encode(
Message(
msgType: MessageType.DiscoverResponse,
discoverResponse: Opt.some(DiscoverResponse(status: status, text: Opt.some(text))),
)
)
await conn.writeLp(msg.buffer)
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
for data in rdv.registered:
if data.peerId == peerId:
result.inc()
proc save(
rdv: RendezVous, ns: string, peerId: PeerId, r: Register, update: bool = true
) =
let nsSalted = ns & rdv.salt
discard rdv.namespaces.hasKeyOrPut(nsSalted, newSeq[int]())
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == peerId:
if update == false:
return
rdv.registered[index].expiration = rdv.expiredDT
rdv.registered.add(
RegisteredData(
peerId: peerId,
expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds,
data: r,
)
)
rdv.namespaces[nsSalted].add(rdv.registered.high)
# rdv.registerEvent.fire()
except KeyError as e:
doAssert false, "Should have key: " & e.msg
proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] =
trace "Received Register", peerId = conn.peerId, ns = r.ns
libp2p_rendezvous_register.inc()
if r.ns.len < MinimumNamespaceLen or r.ns.len > MaximumNamespaceLen:
return conn.sendRegisterResponseError(InvalidNamespace)
let ttl = r.ttl.get(rdv.minTTL)
if ttl < rdv.minTTL or ttl > rdv.maxTTL:
return conn.sendRegisterResponseError(InvalidTTL)
let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId)
if pr.isErr():
return conn.sendRegisterResponseError(InvalidSignedPeerRecord, pr.error())
if rdv.countRegister(conn.peerId) >= RegistrationLimitPerPeer:
return conn.sendRegisterResponseError(NotAuthorized, "Registration limit reached")
rdv.save(r.ns, conn.peerId, r)
libp2p_rendezvous_registered.inc()
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
conn.sendRegisterResponse(ttl)
proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
trace "Received Unregister", peerId = conn.peerId, ns = u.ns
let nsSalted = u.ns & rdv.salt
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == conn.peerId:
rdv.registered[index].expiration = rdv.expiredDT
libp2p_rendezvous_registered.dec()
except KeyError:
return
proc discover(
rdv: RendezVous, conn: Connection, d: Discover
) {.async: (raises: [CancelledError, LPStreamError]).} =
trace "Received Discover", peerId = conn.peerId, ns = d.ns
libp2p_rendezvous_discover.inc()
if d.ns.isSome() and d.ns.get().len > MaximumNamespaceLen:
await conn.sendDiscoverResponseError(InvalidNamespace)
return
var limit = min(DiscoverLimit, d.limit.get(DiscoverLimit))
var cookie =
if d.cookie.isSome():
try:
Cookie.decode(d.cookie.tryGet()).tryGet()
except CatchableError:
await conn.sendDiscoverResponseError(InvalidCookie)
return
else:
# Start from the current lowest index (inclusive)
Cookie(offset: rdv.registered.low().uint64)
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get():
# Namespace changed: start from the beginning of that namespace
cookie = Cookie(offset: rdv.registered.low().uint64)
elif cookie.offset < rdv.registered.low().uint64:
# Cookie behind available range: reset to current low
cookie.offset = rdv.registered.low().uint64
elif cookie.offset > (rdv.registered.high() + 1).uint64:
# Cookie ahead of available range: reset to one past current high (empty page)
cookie.offset = (rdv.registered.high() + 1).uint64
let namespaces =
if d.ns.isSome():
try:
rdv.namespaces[d.ns.get() & rdv.salt]
except KeyError:
await conn.sendDiscoverResponseError(InvalidNamespace)
return
else:
toSeq(max(cookie.offset.int, rdv.registered.offset) .. rdv.registered.high())
if namespaces.len() == 0:
await conn.sendDiscoverResponse(@[], Cookie())
return
var nextOffset = cookie.offset
let n = Moment.now()
var s = collect(newSeq()):
for index in namespaces:
var reg = rdv.registered[index]
if limit == 0:
break
if reg.expiration < n or index.uint64 < cookie.offset:
continue
limit.dec()
nextOffset = index.uint64 + 1
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
reg.data
rdv.rng.shuffle(s)
await conn.sendDiscoverResponse(s, Cookie(offset: nextOffset, ns: d.ns))
proc advertisePeer(
rdv: RendezVous, peer: PeerId, msg: seq[byte]
) {.async: (raises: [CancelledError]).} =
proc advertiseWrap() {.async: (raises: []).} =
try:
let conn = await rdv.switch.dial(peer, RendezVousCodec)
defer:
await conn.close()
await conn.writeLp(msg)
let
buf = await conn.readLp(4096)
msgRecv = Message.decode(buf).tryGet()
if msgRecv.msgType != MessageType.RegisterResponse:
trace "Unexpected register response", peer, msgType = msgRecv.msgType
elif msgRecv.registerResponse.tryGet().status != ResponseStatus.Ok:
trace "Refuse to register", peer, response = msgRecv.registerResponse
else:
trace "Successfully registered", peer, response = msgRecv.registerResponse
except CatchableError as exc:
trace "exception in the advertise", description = exc.msg
finally:
rdv.sema.release()
await rdv.sema.acquire()
await advertiseWrap()
proc advertise*(
rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId]
) {.async: (raises: [CancelledError, AdvertiseError]).} =
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
raise newException(AdvertiseError, "Invalid namespace")
if ttl < rdv.minDuration or ttl > rdv.maxDuration:
raise newException(AdvertiseError, "Invalid time to live: " & $ttl)
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
raise newException(AdvertiseError, "Wrong Signed Peer Record")
let
r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
rdv.save(ns, rdv.switch.peerInfo.peerId, r)
let futs = collect(newSeq()):
for peer in peers:
trace "Send Advertise", peerId = peer, ns
rdv.advertisePeer(peer, msg.buffer).withTimeout(5.seconds)
await allFutures(futs)
method advertise*(
rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration
) {.base, async: (raises: [CancelledError, AdvertiseError]).} =
await rdv.advertise(ns, ttl, rdv.peers)
proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
let
nsSalted = ns & rdv.salt
n = Moment.now()
try:
collect(newSeq()):
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].expiration > n:
let res = SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).valueOr:
continue
res.data
except KeyError as exc:
@[]
proc request*(
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int, peers: seq[PeerId]
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
var
s: Table[PeerId, (PeerRecord, Register)]
limit: uint64
d = Discover(ns: ns)
if l <= 0 or l > DiscoverLimit.int:
raise newException(AdvertiseError, "Invalid limit")
if ns.isSome() and ns.get().len > MaximumNamespaceLen:
raise newException(AdvertiseError, "Invalid namespace")
limit = l.uint64
proc requestPeer(
peer: PeerId
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
let conn = await rdv.switch.dial(peer, RendezVousCodec)
defer:
await conn.close()
d.limit = Opt.some(limit)
d.cookie =
if ns.isSome():
try:
Opt.some(rdv.cookiesSaved[peer][ns.get()])
except KeyError, CatchableError:
Opt.none(seq[byte])
else:
Opt.none(seq[byte])
await conn.writeLp(
encode(Message(msgType: MessageType.Discover, discover: Opt.some(d))).buffer
)
let
buf = await conn.readLp(MaximumMessageLen)
msgRcv = Message.decode(buf).valueOr:
debug "Message undecodable"
return
if msgRcv.msgType != MessageType.DiscoverResponse:
debug "Unexpected discover response", msgType = msgRcv.msgType
return
let resp = msgRcv.discoverResponse.valueOr:
debug "Discover response is empty"
return
if resp.status != ResponseStatus.Ok:
trace "Cannot discover", ns, status = resp.status, text = resp.text
return
resp.cookie.withValue(cookie):
if ns.isSome:
let namespace = ns.get()
if cookie.len() < 1000 and
rdv.cookiesSaved.hasKeyOrPut(peer, {namespace: cookie}.toTable()):
try:
rdv.cookiesSaved[peer][namespace] = cookie
except KeyError:
raiseAssert "checked with hasKeyOrPut"
for r in resp.registrations:
if limit == 0:
return
let ttl = r.ttl.get(rdv.maxTTL + 1)
if ttl > rdv.maxTTL:
continue
let
spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr:
continue
pr = spr.data
if s.hasKey(pr.peerId):
let (prSaved, rSaved) =
try:
s[pr.peerId]
except KeyError:
raiseAssert "checked with hasKey"
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or
prSaved.seqNo < pr.seqNo:
s[pr.peerId] = (pr, r)
else:
s[pr.peerId] = (pr, r)
limit.dec()
if ns.isSome():
for (_, r) in s.values():
rdv.save(ns.get(), peer, r, false)
for peer in peers:
if limit == 0:
break
if RendezVousCodec notin rdv.switch.peerStore[ProtoBook][peer]:
continue
try:
trace "Send Request", peerId = peer, ns
await peer.requestPeer()
except CancelledError as e:
raise e
except DialFailedError as e:
trace "failed to dial a peer", description = e.msg
except LPStreamError as e:
trace "failed to communicate with a peer", description = e.msg
return toSeq(s.values()).mapIt(it[0])
proc request*(
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
await rdv.request(ns, l, rdv.peers)
proc request*(
rdv: RendezVous, l: int = DiscoverLimit.int
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
await rdv.request(Opt.none(string), l, rdv.peers)
proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
let nsSalted = ns & rdv.salt
try:
for index in rdv.namespaces[nsSalted]:
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
rdv.registered[index].expiration = rdv.expiredDT
except KeyError:
return
proc unsubscribe*(
rdv: RendezVous, ns: string, peerIds: seq[PeerId]
) {.async: (raises: [RendezVousError, CancelledError]).} =
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
raise newException(RendezVousError, "Invalid namespace")
let msg = encode(
Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns)))
)
proc unsubscribePeer(peerId: PeerId) {.async: (raises: []).} =
try:
let conn = await rdv.switch.dial(peerId, RendezVousCodec)
defer:
await conn.close()
await conn.writeLp(msg.buffer)
except CatchableError as exc:
trace "exception while unsubscribing", description = exc.msg
let futs = collect(newSeq()):
for peer in peerIds:
unsubscribePeer(peer)
await allFutures(futs)
proc unsubscribe*(
rdv: RendezVous, ns: string
) {.async: (raises: [RendezVousError, CancelledError]).} =
rdv.unsubscribeLocally(ns)
await rdv.unsubscribe(ns, rdv.peers)
proc setup*(rdv: RendezVous, switch: Switch) =
rdv.switch = switch
proc handlePeer(
peerId: PeerId, event: PeerEvent
) {.async: (raises: [CancelledError]).} =
if event.kind == PeerEventKind.Joined:
rdv.peers.add(peerId)
elif event.kind == PeerEventKind.Left:
rdv.peers.keepItIf(it != peerId)
rdv.switch.addPeerEventHandler(handlePeer, Joined)
rdv.switch.addPeerEventHandler(handlePeer, Left)
proc new*(
T: typedesc[RendezVous],
rng: ref HmacDrbgContext = newRng(),
minDuration = MinimumDuration,
maxDuration = MaximumDuration,
): T {.raises: [RendezVousError].} =
if minDuration < MinimumAcceptedDuration:
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
if maxDuration > MaximumDuration:
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
if minDuration >= maxDuration:
raise newException(RendezVousError, "Minimum TTL longer than maximum")
let
minTTL = minDuration.seconds.uint64
maxTTL = maxDuration.seconds.uint64
let rdv = T(
rng: rng,
salt: string.fromBytes(generateBytes(rng[], 8)),
registered: initOffsettedSeq[RegisteredData](),
expiredDT: Moment.now() - 1.days,
#registerEvent: newAsyncEvent(),
sema: newAsyncSemaphore(SemaphoreDefaultSize),
minDuration: minDuration,
maxDuration: maxDuration,
minTTL: minTTL,
maxTTL: maxTTL,
)
logScope:
topics = "libp2p discovery rendezvous"
proc handleStream(
conn: Connection, proto: string
) {.async: (raises: [CancelledError]).} =
try:
let
buf = await conn.readLp(4096)
msg = Message.decode(buf).tryGet()
case msg.msgType
of MessageType.Register:
await rdv.register(conn, msg.register.tryGet())
of MessageType.RegisterResponse:
trace "Got an unexpected Register Response", response = msg.registerResponse
of MessageType.Unregister:
rdv.unregister(conn, msg.unregister.tryGet())
of MessageType.Discover:
await rdv.discover(conn, msg.discover.tryGet())
of MessageType.DiscoverResponse:
trace "Got an unexpected Discover Response", response = msg.discoverResponse
except CancelledError as exc:
trace "cancelled rendezvous handler"
raise exc
except CatchableError as exc:
trace "exception in rendezvous handler", description = exc.msg
finally:
await conn.close()
rdv.handler = handleStream
rdv.codec = RendezVousCodec
return rdv
proc new*(
T: typedesc[RendezVous],
switch: Switch,
rng: ref HmacDrbgContext = newRng(),
minDuration = MinimumDuration,
maxDuration = MaximumDuration,
): T {.raises: [RendezVousError].} =
let rdv = T.new(rng, minDuration, maxDuration)
rdv.setup(switch)
return rdv
proc deletesRegister*(
rdv: RendezVous, interval = 1.minutes
) {.async: (raises: [CancelledError]).} =
heartbeat "Register timeout", interval:
let n = Moment.now()
var total = 0
rdv.registered.flushIfIt(it.expiration < n)
for data in rdv.namespaces.mvalues():
data.keepItIf(it >= rdv.registered.offset)
total += data.len
libp2p_rendezvous_registered.set(int64(total))
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
method start*(
rdv: RendezVous
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
let fut = newFuture[void]()
fut.complete()
if not rdv.registerDeletionLoop.isNil:
warn "Starting rendezvous twice"
return fut
rdv.registerDeletionLoop = rdv.deletesRegister()
rdv.started = true
fut
method stop*(rdv: RendezVous): Future[void] {.async: (raises: [], raw: true).} =
let fut = newFuture[void]()
fut.complete()
if rdv.registerDeletionLoop.isNil:
warn "Stopping rendezvous without starting it"
return fut
rdv.started = false
rdv.registerDeletionLoop.cancelSoon()
rdv.registerDeletionLoop = nil
fut

View File

@@ -195,7 +195,7 @@ type CertGenerator =
type QuicTransport* = ref object of Transport
listener: Listener
client: QuicClient
client: Opt[QuicClient]
privateKey: PrivateKey
connections: seq[P2PConnection]
rng: ref HmacDrbgContext
@@ -248,27 +248,33 @@ method handles*(transport: QuicTransport, address: MultiAddress): bool {.raises:
return false
QUIC_V1.match(address)
method start*(
self: QuicTransport, addrs: seq[MultiAddress]
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
doAssert self.listener.isNil, "start() already called"
#TODO handle multiple addr
proc makeConfig(self: QuicTransport): TLSConfig =
let pubkey = self.privateKey.getPublicKey().valueOr:
doAssert false, "could not obtain public key"
return
try:
if self.rng.isNil:
self.rng = newRng()
let cert = self.certGenerator(KeyPair(seckey: self.privateKey, pubkey: pubkey))
let tlsConfig = TLSConfig.init(
cert.certificate, cert.privateKey, @[alpn], Opt.some(makeCertificateVerifier())
)
return tlsConfig
let cert = self.certGenerator(KeyPair(seckey: self.privateKey, pubkey: pubkey))
let tlsConfig = TLSConfig.init(
cert.certificate, cert.privateKey, @[alpn], Opt.some(makeCertificateVerifier())
)
self.client = QuicClient.init(tlsConfig, rng = self.rng)
self.listener =
QuicServer.init(tlsConfig, rng = self.rng).listen(initTAddress(addrs[0]).tryGet)
proc getRng(self: QuicTransport): ref HmacDrbgContext =
if self.rng.isNil:
self.rng = newRng()
return self.rng
method start*(
self: QuicTransport, addrs: seq[MultiAddress]
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
doAssert self.listener.isNil, "start() already called"
# TODO(#1663): handle multiple addr
try:
self.listener = QuicServer.init(self.makeConfig(), rng = self.getRng()).listen(
initTAddress(addrs[0]).tryGet
)
await procCall Transport(self).start(addrs)
self.addrs[0] =
MultiAddress.init(self.listener.localAddress(), IPPROTO_UDP).tryGet() &
@@ -289,19 +295,21 @@ method start*(
self.running = true
method stop*(transport: QuicTransport) {.async: (raises: []).} =
if transport.running:
let conns = transport.connections[0 .. ^1]
for c in conns:
await c.close()
await procCall Transport(transport).stop()
let conns = transport.connections[0 .. ^1]
for c in conns:
await c.close()
if not transport.listener.isNil:
try:
await transport.listener.stop()
except CatchableError as exc:
trace "Error shutting down Quic transport", description = exc.msg
transport.listener.destroy()
transport.running = false
transport.listener = nil
transport.client = Opt.none(QuicClient)
await procCall Transport(transport).stop()
proc wrapConnection(
transport: QuicTransport, connection: QuicConnection
): QuicSession {.raises: [TransportOsError, MaError].} =
@@ -365,7 +373,11 @@ method dial*(
async: (raises: [transport.TransportError, CancelledError])
.} =
try:
let quicConnection = await self.client.dial(initTAddress(address).tryGet)
if not self.client.isSome:
self.client = Opt.some(QuicClient.init(self.makeConfig(), rng = self.getRng()))
let client = self.client.get()
let quicConnection = await client.dial(initTAddress(address).tryGet)
return self.wrapConnection(quicConnection)
except CancelledError as e:
raise e

View File

@@ -1,3 +1,4 @@
{.used.}
import testdiscoverymngr, testrendezvous, testrendezvousinterface
import
testdiscoverymngr, testrendezvous, testrendezvousprotobuf, testrendezvousinterface

View File

@@ -0,0 +1,256 @@
{.used.}
# Nim-Libp2p
# Copyright (c) 2025 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import ../../libp2p/[protocols/rendezvous/protobuf]
import ../../libp2p/protobuf/minprotobuf
import ../helpers
suite "RendezVous Protobuf":
teardown:
checkTrackers()
const namespace = "ns"
test "Cookie roundtrip with namespace":
let originalCookie = Cookie(offset: 42'u64, ns: Opt.some(namespace))
let decodeResult = Cookie.decode(originalCookie.encode().buffer)
check decodeResult.isSome()
let decodedCookie = decodeResult.get()
check:
decodedCookie.offset == originalCookie.offset
decodedCookie.ns.get() == originalCookie.ns.get()
decodedCookie.encode().buffer == originalCookie.encode().buffer # roundtrip again
test "Cookie roundtrip without namespace":
let originalCookie = Cookie(offset: 7'u64)
let decodeResult = Cookie.decode(originalCookie.encode().buffer)
check decodeResult.isSome()
let decodedCookie = decodeResult.get()
check:
decodedCookie.offset == originalCookie.offset
decodedCookie.ns.isNone()
decodedCookie.encode().buffer == originalCookie.encode().buffer
test "Cookie decode fails when missing offset":
var emptyCookieBuffer = initProtoBuffer()
check Cookie.decode(emptyCookieBuffer.buffer).isNone()
test "Register roundtrip with ttl":
let originalRegister =
Register(ns: namespace, signedPeerRecord: @[byte 1, 2, 3], ttl: Opt.some(60'u64))
let decodeResult = Register.decode(originalRegister.encode().buffer)
check decodeResult.isSome()
let decodedRegister = decodeResult.get()
check:
decodedRegister.ns == originalRegister.ns
decodedRegister.signedPeerRecord == originalRegister.signedPeerRecord
decodedRegister.ttl.get() == originalRegister.ttl.get()
decodedRegister.encode().buffer == originalRegister.encode().buffer
test "Register roundtrip without ttl":
let originalRegister = Register(ns: namespace, signedPeerRecord: @[byte 4, 5])
let decodeResult = Register.decode(originalRegister.encode().buffer)
check decodeResult.isSome()
let decodedRegister = decodeResult.get()
check:
decodedRegister.ns == originalRegister.ns
decodedRegister.signedPeerRecord == originalRegister.signedPeerRecord
decodedRegister.ttl.isNone()
decodedRegister.encode().buffer == originalRegister.encode().buffer
test "Register decode fails when missing namespace":
var bufferMissingNamespace = initProtoBuffer()
bufferMissingNamespace.write(2, @[byte 1])
check Register.decode(bufferMissingNamespace.buffer).isNone()
test "Register decode fails when missing signedPeerRecord":
var bufferMissingSignedPeerRecord = initProtoBuffer()
bufferMissingSignedPeerRecord.write(1, namespace)
check Register.decode(bufferMissingSignedPeerRecord.buffer).isNone()
test "RegisterResponse roundtrip successful":
let originalResponse = RegisterResponse(
status: ResponseStatus.Ok, text: Opt.some("ok"), ttl: Opt.some(10'u64)
)
let decodeResult = RegisterResponse.decode(originalResponse.encode().buffer)
check decodeResult.isSome()
let decodedResponse = decodeResult.get()
check:
decodedResponse.status == originalResponse.status
decodedResponse.text.get() == originalResponse.text.get()
decodedResponse.ttl.get() == originalResponse.ttl.get()
decodedResponse.encode().buffer == originalResponse.encode().buffer
test "RegisterResponse roundtrip failed":
let originalResponse = RegisterResponse(status: ResponseStatus.InvalidNamespace)
let decodeResult = RegisterResponse.decode(originalResponse.encode().buffer)
check decodeResult.isSome()
let decodedResponse = decodeResult.get()
check:
decodedResponse.status == originalResponse.status
decodedResponse.text.isNone()
decodedResponse.ttl.isNone()
decodedResponse.encode().buffer == originalResponse.encode().buffer
test "RegisterResponse decode fails invalid status enum":
var bufferInvalidStatusValue = initProtoBuffer()
bufferInvalidStatusValue.write(1, uint(999))
check RegisterResponse.decode(bufferInvalidStatusValue.buffer).isNone()
test "RegisterResponse decode fails missing status":
var bufferMissingStatusField = initProtoBuffer()
bufferMissingStatusField.write(2, "msg")
check RegisterResponse.decode(bufferMissingStatusField.buffer).isNone()
test "Unregister roundtrip":
let originalUnregister = Unregister(ns: namespace)
let decodeResult = Unregister.decode(originalUnregister.encode().buffer)
check decodeResult.isSome()
let decodedUnregister = decodeResult.get()
check:
decodedUnregister.ns == originalUnregister.ns
decodedUnregister.encode().buffer == originalUnregister.encode().buffer
test "Unregister decode fails when missing namespace":
var bufferMissingUnregisterNamespace = initProtoBuffer()
check Unregister.decode(bufferMissingUnregisterNamespace.buffer).isNone()
test "Discover roundtrip with all optional fields":
let originalDiscover = Discover(
ns: Opt.some(namespace), limit: Opt.some(5'u64), cookie: Opt.some(@[byte 1, 2, 3])
)
let decodeResult = Discover.decode(originalDiscover.encode().buffer)
check decodeResult.isSome()
let decodedDiscover = decodeResult.get()
check:
decodedDiscover.ns.get() == originalDiscover.ns.get()
decodedDiscover.limit.get() == originalDiscover.limit.get()
decodedDiscover.cookie.get() == originalDiscover.cookie.get()
decodedDiscover.encode().buffer == originalDiscover.encode().buffer
test "Discover decode empty buffer yields empty object":
var emptyDiscoverBuffer: seq[byte] = @[]
let decodeResult = Discover.decode(emptyDiscoverBuffer)
check decodeResult.isSome()
let decodedDiscover = decodeResult.get()
check:
decodedDiscover.ns.isNone()
decodedDiscover.limit.isNone()
decodedDiscover.cookie.isNone()
test "DiscoverResponse roundtrip with registration":
let registrationEntry = Register(ns: namespace, signedPeerRecord: @[byte 9])
let originalResponse = DiscoverResponse(
registrations: @[registrationEntry],
cookie: Opt.some(@[byte 0xAA]),
status: ResponseStatus.Ok,
text: Opt.some("t"),
)
let decodeResult = DiscoverResponse.decode(originalResponse.encode().buffer)
check decodeResult.isSome()
let decodedResponse = decodeResult.get()
check:
decodedResponse.status == originalResponse.status
decodedResponse.registrations.len == 1
decodedResponse.registrations[0].ns == registrationEntry.ns
decodedResponse.cookie.get() == originalResponse.cookie.get()
decodedResponse.text.get() == originalResponse.text.get()
decodedResponse.encode().buffer == originalResponse.encode().buffer
test "DiscoverResponse roundtrip failed":
let originalResponse = DiscoverResponse(status: ResponseStatus.InternalError)
let decodeResult = DiscoverResponse.decode(originalResponse.encode().buffer)
check decodeResult.isSome()
let decodedResponse = decodeResult.get()
check:
decodedResponse.status == originalResponse.status
decodedResponse.registrations.len == 0
decodedResponse.cookie.isNone()
decodedResponse.text.isNone()
decodedResponse.encode().buffer == originalResponse.encode().buffer
test "DiscoverResponse decode fails with invalid registration":
var bufferInvalidRegistration = initProtoBuffer()
bufferInvalidRegistration.write(1, @[byte 0x00, 0xFF])
bufferInvalidRegistration.write(3, ResponseStatus.Ok.uint)
check DiscoverResponse.decode(bufferInvalidRegistration.buffer).isNone()
test "DiscoverResponse decode fails missing status":
var bufferMissingDiscoverResponseStatus = initProtoBuffer()
check DiscoverResponse.decode(bufferMissingDiscoverResponseStatus.buffer).isNone()
test "Message roundtrip Register variant":
let registerPayload = Register(ns: namespace, signedPeerRecord: @[byte 1])
let originalMessage =
Message(msgType: MessageType.Register, register: Opt.some(registerPayload))
let decodeResult = Message.decode(originalMessage.encode().buffer)
check decodeResult.isSome()
let decodedMessage = decodeResult.get()
check:
decodedMessage.msgType == originalMessage.msgType
decodedMessage.register.get().ns == registerPayload.ns
test "Message roundtrip RegisterResponse variant":
let registerResponsePayload = RegisterResponse(status: ResponseStatus.Ok)
let originalMessage = Message(
msgType: MessageType.RegisterResponse,
registerResponse: Opt.some(registerResponsePayload),
)
let decodeResult = Message.decode(originalMessage.encode().buffer)
check decodeResult.isSome()
let decodedMessage = decodeResult.get()
check:
decodedMessage.msgType == originalMessage.msgType
decodedMessage.registerResponse.get().status == registerResponsePayload.status
test "Message roundtrip Unregister variant":
let unregisterPayload = Unregister(ns: namespace)
let originalMessage =
Message(msgType: MessageType.Unregister, unregister: Opt.some(unregisterPayload))
let decodeResult = Message.decode(originalMessage.encode().buffer)
check decodeResult.isSome()
let decodedMessage = decodeResult.get()
check:
decodedMessage.unregister.get().ns == unregisterPayload.ns
test "Message roundtrip Discover variant":
let discoverPayload = Discover(limit: Opt.some(1'u64))
let originalMessage =
Message(msgType: MessageType.Discover, discover: Opt.some(discoverPayload))
let decodeResult = Message.decode(originalMessage.encode().buffer)
check decodeResult.isSome()
let decodedMessage = decodeResult.get()
check:
decodedMessage.discover.get().limit.get() == discoverPayload.limit.get()
test "Message roundtrip DiscoverResponse variant":
let discoverResponsePayload = DiscoverResponse(status: ResponseStatus.Unavailable)
let originalMessage = Message(
msgType: MessageType.DiscoverResponse,
discoverResponse: Opt.some(discoverResponsePayload),
)
let decodeResult = Message.decode(originalMessage.encode().buffer)
check decodeResult.isSome()
let decodedMessage = decodeResult.get()
check:
decodedMessage.discoverResponse.get().status == discoverResponsePayload.status
test "Message decode header only":
var headerOnlyMessage = Message(msgType: MessageType.Register)
let decodeResult = Message.decode(headerOnlyMessage.encode().buffer)
check decodeResult.isSome()
check:
decodeResult.get().register.isNone()
test "Message decode fails invalid msgType":
var bufferInvalidMessageType = initProtoBuffer()
bufferInvalidMessageType.write(1, uint(999))
check Message.decode(bufferInvalidMessageType.buffer).isNone()

View File

@@ -311,7 +311,8 @@ suite "GossipSub":
check gossipSub.mcache.msgs.len == 0
asyncTest "rpcHandler - subscription limits":
let gossipSub = TestGossipSub.init(newStandardSwitch())
let gossipSub =
TestGossipSub.init(newStandardSwitch(transport = TransportType.QUIC))
gossipSub.topicsHigh = 10
var tooManyTopics: seq[string]
@@ -333,7 +334,8 @@ suite "GossipSub":
await conn.close()
asyncTest "rpcHandler - invalid message bytes":
let gossipSub = TestGossipSub.init(newStandardSwitch())
let gossipSub =
TestGossipSub.init(newStandardSwitch(transport = TransportType.QUIC))
let peerId = randomPeerId()
let peer = gossipSub.getPubSubPeer(peerId)

View File

@@ -87,7 +87,7 @@ proc setupGossipSubWithPeers*(
populateMesh: bool = false,
populateFanout: bool = false,
): (TestGossipSub, seq[Connection], seq[PubSubPeer]) =
let gossipSub = TestGossipSub.init(newStandardSwitch())
let gossipSub = TestGossipSub.init(newStandardSwitch(transport = TransportType.QUIC))
for topic in topics:
gossipSub.subscribe(topic, voidTopicHandler)
@@ -195,7 +195,9 @@ proc generateNodes*(
): seq[PubSub] =
for i in 0 ..< num:
let switch = newStandardSwitch(
secureManagers = secureManagers, sendSignedPeerRecord = sendSignedPeerRecord
secureManagers = secureManagers,
sendSignedPeerRecord = sendSignedPeerRecord,
transport = TransportType.Memory,
)
let pubsub =
if gossip:

View File

@@ -18,6 +18,7 @@ import
upgrademngrs/upgrade,
builders,
protocols/connectivity/autonatv2/types,
protocols/connectivity/autonatv2/server,
protocols/connectivity/autonatv2/utils,
],
./helpers
@@ -27,6 +28,10 @@ proc checkEncodeDecode[T](msg: T) =
# check msg == DialBack.decode(msg.encode()).get()
check msg == T.decode(msg.encode()).get()
proc newAutonatV2ServerSwitch(config: AutonatV2Config = AutonatV2Config.new()): Switch =
var builder = newStandardSwitchBuilder().withAutonatV2(config = config)
return builder.build()
suite "AutonatV2":
teardown:
checkTrackers()
@@ -149,3 +154,15 @@ suite "AutonatV2":
AutonatV2Response(
reachability: Reachable, dialResp: correctDialResp, addrs: Opt.some(addrs[0])
)
asyncTest "Instantiate server":
let serverSwitch = newAutonatV2ServerSwitch()
await serverSwitch.start()
await serverSwitch.stop()
asyncTest "Instantiate server with config":
let serverSwitch = newAutonatV2ServerSwitch(
config = AutonatV2Config.new(allowPrivateAddresses = true)
)
await serverSwitch.start()
await serverSwitch.stop()

View File

@@ -60,15 +60,19 @@ proc invalidCertGenerator(
except ResultError[crypto.CryptoError]:
raiseAssert "private key should be set"
proc createTransport(withInvalidCert: bool = false): Future[QuicTransport] {.async.} =
let ma = @[MultiAddress.init("/ip4/127.0.0.1/udp/0/quic-v1").tryGet()]
proc createTransport(
isServer: bool = false, withInvalidCert: bool = false
): Future[QuicTransport] {.async.} =
let privateKey = PrivateKey.random(ECDSA, (newRng())[]).tryGet()
let trans =
if withInvalidCert:
QuicTransport.new(Upgrade(), privateKey, invalidCertGenerator)
else:
QuicTransport.new(Upgrade(), privateKey)
await trans.start(ma)
if isServer: # servers are started because they need to listen
let ma = @[MultiAddress.init("/ip4/127.0.0.1/udp/0/quic-v1").tryGet()]
await trans.start(ma)
return trans
@@ -77,12 +81,46 @@ suite "Quic transport":
checkTrackers()
asyncTest "can handle local address":
let trans = await createTransport()
check trans.handles(trans.addrs[0])
await trans.stop()
let server = await createTransport(isServer = true)
check server.handles(server.addrs[0])
await server.stop()
asyncTest "handle accept cancellation":
let server = await createTransport(isServer = true)
let acceptFut = server.accept()
await acceptFut.cancelAndWait()
check acceptFut.cancelled
await server.stop()
asyncTest "handle dial cancellation":
let server = await createTransport(isServer = true)
let client = await createTransport(isServer = false)
let connFut = client.dial(server.addrs[0])
await connFut.cancelAndWait()
check connFut.cancelled
await client.stop()
await server.stop()
asyncTest "stopping transport kills connections":
let server = await createTransport(isServer = true)
let client = await createTransport(isServer = false)
let acceptFut = server.accept()
let conn = await client.dial(server.addrs[0])
let serverConn = await acceptFut
await client.stop()
await server.stop()
check serverConn.closed()
check conn.closed()
asyncTest "transport e2e":
let server = await createTransport()
let server = await createTransport(isServer = true)
asyncSpawn createServerAcceptConn(server)()
defer:
await server.stop()
@@ -101,7 +139,7 @@ suite "Quic transport":
await runClient()
asyncTest "transport e2e - invalid cert - server":
let server = await createTransport(true)
let server = await createTransport(isServer = true, withInvalidCert = true)
asyncSpawn createServerAcceptConn(server)()
defer:
await server.stop()
@@ -115,22 +153,27 @@ suite "Quic transport":
await runClient()
asyncTest "transport e2e - invalid cert - client":
let server = await createTransport()
let server = await createTransport(isServer = true)
asyncSpawn createServerAcceptConn(server)()
defer:
await server.stop()
proc runClient() {.async.} =
let client = await createTransport(true)
let client = await createTransport(withInvalidCert = true)
expect QuicTransportDialError:
discard await client.dial("", server.addrs[0])
await client.stop()
await runClient()
asyncTest "should allow multiple local addresses":
# TODO(#1663): handle multiple addr
# See test example in commonTransportTest
return
asyncTest "server not accepting":
let server = await createTransport()
# itentionally not calling createServerAcceptConn as server should not accept
let server = await createTransport(isServer = true)
# intentionally not calling createServerAcceptConn as server should not accept
defer:
await server.stop()
@@ -145,7 +188,7 @@ suite "Quic transport":
await runClient()
asyncTest "closing session should close all streams":
let server = await createTransport()
let server = await createTransport(isServer = true)
# because some clients will not write full message,
# it is expected for server to receive eof
asyncSpawn createServerAcceptConn(server, true)()
@@ -201,7 +244,7 @@ suite "Quic transport":
check (await stream.readLp(100)) == fromHex("5678")
await client.stop()
let server = await createTransport()
let server = await createTransport(isServer = true)
asyncSpawn serverHandler(server)
defer:
await server.stop()