mirror of
https://github.com/vacp2p/nim-libp2p.git
synced 2026-01-10 11:48:15 -05:00
Compare commits
9 Commits
fix/quic-t
...
test-quic-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a37c325f99 | ||
|
|
62388a7a20 | ||
|
|
27051164db | ||
|
|
f41009461b | ||
|
|
c3faabf522 | ||
|
|
10f7f5c68a | ||
|
|
f345026900 | ||
|
|
5d6578a06f | ||
|
|
871a5d047f |
2
.pinned
2
.pinned
@@ -8,7 +8,7 @@ json_serialization;https://github.com/status-im/nim-json-serialization@#2b1c5eb1
|
||||
metrics;https://github.com/status-im/nim-metrics@#6142e433fc8ea9b73379770a788017ac528d46ff
|
||||
ngtcp2;https://github.com/status-im/nim-ngtcp2@#9456daa178c655bccd4a3c78ad3b8cce1f0add73
|
||||
nimcrypto;https://github.com/cheatfate/nimcrypto@#19c41d6be4c00b4a2c8000583bd30cf8ceb5f4b1
|
||||
quic;https://github.com/vacp2p/nim-quic@#cae13c2d22ba2730c979486cf89b88927045c3ae
|
||||
quic;https://github.com/vacp2p/nim-quic@#9370190ded18d78a5a9990f57aa8cbbf947f3891
|
||||
results;https://github.com/arnetheduck/nim-results@#df8113dda4c2d74d460a8fa98252b0b771bf1f27
|
||||
secp256k1;https://github.com/status-im/nim-secp256k1@#f808ed5e7a7bfc42204ec7830f14b7a42b63c284
|
||||
serialization;https://github.com/status-im/nim-serialization@#548d0adc9797a10b2db7f788b804330306293088
|
||||
|
||||
@@ -10,7 +10,7 @@ skipDirs = @["tests", "examples", "Nim", "tools", "scripts", "docs"]
|
||||
requires "nim >= 2.0.0",
|
||||
"nimcrypto >= 0.6.0 & < 0.7.0", "dnsclient >= 0.3.0 & < 0.4.0", "bearssl >= 0.2.5",
|
||||
"chronicles >= 0.11.0 & < 0.12.0", "chronos >= 4.0.4", "metrics", "secp256k1",
|
||||
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.15",
|
||||
"stew >= 0.4.0", "websock >= 0.2.0", "unittest2", "results", "quic >= 0.2.16",
|
||||
"https://github.com/vacp2p/nim-jwt.git#18f8378de52b241f321c1f9ea905456e89b95c6f"
|
||||
|
||||
let nimc = getEnv("NIMC", "nim") # Which nim compiler to use
|
||||
|
||||
@@ -26,7 +26,8 @@ import
|
||||
transports/[transport, tcptransport, wstransport, memorytransport],
|
||||
muxers/[muxer, mplex/mplex, yamux/yamux],
|
||||
protocols/[identify, secure/secure, secure/noise, rendezvous],
|
||||
protocols/connectivity/[autonat/server, relay/relay, relay/client, relay/rtransport],
|
||||
protocols/connectivity/
|
||||
[autonat/server, autonatv2/server, relay/relay, relay/client, relay/rtransport],
|
||||
connmanager,
|
||||
upgrademngrs/muxedupgrade,
|
||||
observedaddrmanager,
|
||||
@@ -74,6 +75,8 @@ type
|
||||
nameResolver: NameResolver
|
||||
peerStoreCapacity: Opt[int]
|
||||
autonat: bool
|
||||
autonatV2: bool
|
||||
autonatV2Config: AutonatV2Config
|
||||
autotls: AutotlsService
|
||||
circuitRelay: Relay
|
||||
rdv: RendezVous
|
||||
@@ -280,6 +283,13 @@ proc withAutonat*(b: SwitchBuilder): SwitchBuilder =
|
||||
b.autonat = true
|
||||
b
|
||||
|
||||
proc withAutonatV2*(
|
||||
b: SwitchBuilder, config: AutonatV2Config = AutonatV2Config.new()
|
||||
): SwitchBuilder =
|
||||
b.autonatV2 = true
|
||||
b.autonatV2Config = config
|
||||
b
|
||||
|
||||
when defined(libp2p_autotls_support):
|
||||
proc withAutotls*(
|
||||
b: SwitchBuilder, config: AutotlsConfig = AutotlsConfig.new()
|
||||
@@ -379,7 +389,10 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
|
||||
|
||||
switch.mount(identify)
|
||||
|
||||
if b.autonat:
|
||||
if b.autonatV2:
|
||||
let autonatV2 = AutonatV2.new(switch, config = b.autonatV2Config)
|
||||
switch.mount(autonatV2)
|
||||
elif b.autonat:
|
||||
let autonat = Autonat.new(switch)
|
||||
switch.mount(autonat)
|
||||
|
||||
@@ -395,10 +408,75 @@ proc build*(b: SwitchBuilder): Switch {.raises: [LPError], public.} =
|
||||
|
||||
return switch
|
||||
|
||||
type TransportType* {.pure.} = enum
|
||||
QUIC
|
||||
TCP
|
||||
Memory
|
||||
|
||||
proc newStandardSwitchBuilder*(
|
||||
privKey = none(PrivateKey),
|
||||
addrs: MultiAddress | seq[MultiAddress] = newSeq[MultiAddress](),
|
||||
transport: TransportType = TransportType.TCP,
|
||||
secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
|
||||
transportFlags: set[ServerFlags] = {},
|
||||
rng = newRng(),
|
||||
inTimeout: Duration = 5.minutes,
|
||||
outTimeout: Duration = 5.minutes,
|
||||
maxConnections = MaxConnections,
|
||||
maxIn = -1,
|
||||
maxOut = -1,
|
||||
maxConnsPerPeer = MaxConnectionsPerPeer,
|
||||
nameResolver: NameResolver = nil,
|
||||
sendSignedPeerRecord = false,
|
||||
peerStoreCapacity = 1000,
|
||||
): SwitchBuilder {.raises: [LPError], public.} =
|
||||
## Helper for common switch configurations.
|
||||
var b = SwitchBuilder
|
||||
.new()
|
||||
.withRng(rng)
|
||||
.withSignedPeerRecord(sendSignedPeerRecord)
|
||||
.withMaxConnections(maxConnections)
|
||||
.withMaxIn(maxIn)
|
||||
.withMaxOut(maxOut)
|
||||
.withMaxConnsPerPeer(maxConnsPerPeer)
|
||||
.withPeerStore(capacity = peerStoreCapacity)
|
||||
.withNameResolver(nameResolver)
|
||||
.withNoise()
|
||||
|
||||
var addrs =
|
||||
when addrs is MultiAddress:
|
||||
@[addrs]
|
||||
else:
|
||||
addrs
|
||||
|
||||
case transport
|
||||
of TransportType.QUIC:
|
||||
when defined(libp2p_quic_support):
|
||||
if addrs.len == 0:
|
||||
addrs = @[MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet()]
|
||||
b = b.withQuicTransport().withAddresses(addrs)
|
||||
else:
|
||||
raiseAssert "QUIC not supported in this build"
|
||||
of TransportType.TCP:
|
||||
if addrs.len == 0:
|
||||
addrs = @[MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet()]
|
||||
b = b.withTcpTransport(transportFlags).withAddresses(addrs).withMplex(
|
||||
inTimeout, outTimeout
|
||||
)
|
||||
of TransportType.Memory:
|
||||
if addrs.len == 0:
|
||||
addrs = @[MultiAddress.init(MemoryAutoAddress).tryGet()]
|
||||
b = b.withMemoryTransport().withAddresses(addrs).withMplex(inTimeout, outTimeout)
|
||||
|
||||
privKey.withValue(pkey):
|
||||
b = b.withPrivateKey(pkey)
|
||||
|
||||
b
|
||||
|
||||
proc newStandardSwitch*(
|
||||
privKey = none(PrivateKey),
|
||||
addrs: MultiAddress | seq[MultiAddress] =
|
||||
MultiAddress.init("/ip4/127.0.0.1/tcp/0").expect("valid address"),
|
||||
addrs: MultiAddress | seq[MultiAddress] = newSeq[MultiAddress](),
|
||||
transport: TransportType = TransportType.TCP,
|
||||
secureManagers: openArray[SecureProtocol] = [SecureProtocol.Noise],
|
||||
transportFlags: set[ServerFlags] = {},
|
||||
rng = newRng(),
|
||||
@@ -412,28 +490,20 @@ proc newStandardSwitch*(
|
||||
sendSignedPeerRecord = false,
|
||||
peerStoreCapacity = 1000,
|
||||
): Switch {.raises: [LPError], public.} =
|
||||
## Helper for common switch configurations.
|
||||
let addrs =
|
||||
when addrs is MultiAddress:
|
||||
@[addrs]
|
||||
else:
|
||||
addrs
|
||||
var b = SwitchBuilder
|
||||
.new()
|
||||
.withAddresses(addrs)
|
||||
.withRng(rng)
|
||||
.withSignedPeerRecord(sendSignedPeerRecord)
|
||||
.withMaxConnections(maxConnections)
|
||||
.withMaxIn(maxIn)
|
||||
.withMaxOut(maxOut)
|
||||
.withMaxConnsPerPeer(maxConnsPerPeer)
|
||||
.withPeerStore(capacity = peerStoreCapacity)
|
||||
.withMplex(inTimeout, outTimeout)
|
||||
.withTcpTransport(transportFlags)
|
||||
.withNameResolver(nameResolver)
|
||||
.withNoise()
|
||||
|
||||
privKey.withValue(pkey):
|
||||
b = b.withPrivateKey(pkey)
|
||||
|
||||
b.build()
|
||||
newStandardSwitchBuilder(
|
||||
privKey = privKey,
|
||||
addrs = addrs,
|
||||
secureManagers = secureManagers,
|
||||
transportFlags = transportFlags,
|
||||
rng = rng,
|
||||
inTimeout = inTimeout,
|
||||
outTimeout = outTimeout,
|
||||
maxConnections = maxConnections,
|
||||
maxIn = maxIn,
|
||||
maxOut = maxOut,
|
||||
maxConnsPerPeer = maxConnsPerPeer,
|
||||
nameResolver = nameResolver,
|
||||
sendSignedPeerRecord = sendSignedPeerRecord,
|
||||
peerStoreCapacity = peerStoreCapacity,
|
||||
)
|
||||
.build()
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
import chronos
|
||||
import results
|
||||
import peerid, stream/connection, transports/transport
|
||||
import peerid, stream/connection, transports/transport, muxers/muxer
|
||||
|
||||
export results
|
||||
|
||||
@@ -65,6 +65,23 @@ method dial*(
|
||||
method addTransport*(self: Dial, transport: Transport) {.base.} =
|
||||
doAssert(false, "[Dial.addTransport] abstract method not implemented!")
|
||||
|
||||
method dialAndUpgrade*(
|
||||
self: Dial, peerId: Opt[PeerId], address: MultiAddress, dir = Direction.Out
|
||||
): Future[Muxer] {.base, async: (raises: [CancelledError]).} =
|
||||
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")
|
||||
|
||||
method dialAndUpgrade*(
|
||||
self: Dial, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
|
||||
): Future[Muxer] {.
|
||||
base, async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
|
||||
.} =
|
||||
doAssert(false, "[Dial.dialAndUpgrade] abstract method not implemented!")
|
||||
|
||||
method negotiateStream*(
|
||||
self: Dial, conn: Connection, protos: seq[string]
|
||||
): Future[Connection] {.base, async: (raises: [CatchableError]).} =
|
||||
doAssert(false, "[Dial.negotiateStream] abstract method not implemented!")
|
||||
|
||||
method tryDial*(
|
||||
self: Dial, peerId: PeerId, addrs: seq[MultiAddress]
|
||||
): Future[Opt[MultiAddress]] {.
|
||||
|
||||
@@ -43,7 +43,7 @@ type Dialer* = ref object of Dial
|
||||
peerStore: PeerStore
|
||||
nameResolver: NameResolver
|
||||
|
||||
proc dialAndUpgrade(
|
||||
proc dialAndUpgrade*(
|
||||
self: Dialer,
|
||||
peerId: Opt[PeerId],
|
||||
hostname: string,
|
||||
@@ -139,7 +139,7 @@ proc expandDnsAddr(
|
||||
else:
|
||||
result.add((resolvedAddress, peerId))
|
||||
|
||||
proc dialAndUpgrade(
|
||||
proc dialAndUpgrade*(
|
||||
self: Dialer, peerId: Opt[PeerId], addrs: seq[MultiAddress], dir = Direction.Out
|
||||
): Future[Muxer] {.
|
||||
async: (raises: [CancelledError, MaError, TransportAddressError, LPError])
|
||||
@@ -284,7 +284,7 @@ method connect*(
|
||||
return
|
||||
(await self.internalConnect(Opt.none(PeerId), @[address], false)).connection.peerId
|
||||
|
||||
proc negotiateStream(
|
||||
proc negotiateStream*(
|
||||
self: Dialer, conn: Connection, protos: seq[string]
|
||||
): Future[Connection] {.async: (raises: [CatchableError]).} =
|
||||
trace "Negotiating stream", conn, protos
|
||||
|
||||
279
libp2p/protocols/connectivity/autonatv2/server.nim
Normal file
279
libp2p/protocols/connectivity/autonatv2/server.nim
Normal file
@@ -0,0 +1,279 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2025 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import results
|
||||
import chronos, chronicles
|
||||
import
|
||||
../../../../libp2p/[
|
||||
switch,
|
||||
muxers/muxer,
|
||||
dialer,
|
||||
multiaddress,
|
||||
transports/transport,
|
||||
multicodec,
|
||||
peerid,
|
||||
protobuf/minprotobuf,
|
||||
utils/ipaddr,
|
||||
],
|
||||
../../protocol,
|
||||
./types
|
||||
|
||||
logScope:
|
||||
topics = "libp2p autonat v2 server"
|
||||
|
||||
type AutonatV2Config* = object
|
||||
dialTimeout: Duration
|
||||
dialDataSize: uint64
|
||||
amplificationAttackTimeout: Duration
|
||||
allowPrivateAddresses: bool
|
||||
|
||||
type AutonatV2* = ref object of LPProtocol
|
||||
switch*: Switch
|
||||
config: AutonatV2Config
|
||||
|
||||
proc new*(
|
||||
T: typedesc[AutonatV2Config],
|
||||
dialTimeout: Duration = DefaultDialTimeout,
|
||||
dialDataSize: uint64 = DefaultDialDataSize,
|
||||
amplificationAttackTimeout: Duration = DefaultAmplificationAttackDialTimeout,
|
||||
allowPrivateAddresses: bool = false,
|
||||
): T =
|
||||
T(
|
||||
dialTimeout: dialTimeout,
|
||||
dialDataSize: dialDataSize,
|
||||
amplificationAttackTimeout: amplificationAttackTimeout,
|
||||
allowPrivateAddresses: allowPrivateAddresses,
|
||||
)
|
||||
|
||||
proc sendDialResponse(
|
||||
conn: Connection,
|
||||
status: ResponseStatus,
|
||||
addrIdx: Opt[AddrIdx] = Opt.none(AddrIdx),
|
||||
dialStatus: Opt[DialStatus] = Opt.none(DialStatus),
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
await conn.writeLp(
|
||||
AutonatV2Msg(
|
||||
msgType: MsgType.DialResponse,
|
||||
dialResp: DialResponse(status: status, addrIdx: addrIdx, dialStatus: dialStatus),
|
||||
).encode().buffer
|
||||
)
|
||||
|
||||
proc findObservedIPAddr*(
|
||||
conn: Connection, req: DialRequest
|
||||
): Future[Opt[MultiAddress]] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let observedAddr = conn.observedAddr.valueOr:
|
||||
await conn.sendDialResponse(ResponseStatus.EInternalError)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
let isRelayed = observedAddr.contains(multiCodec("p2p-circuit")).valueOr:
|
||||
error "Invalid observed address"
|
||||
await conn.sendDialResponse(ResponseStatus.EDialRefused)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
if isRelayed:
|
||||
error "Invalid observed address: relayed address"
|
||||
await conn.sendDialResponse(ResponseStatus.EDialRefused)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
let hostIp = observedAddr[0].valueOr:
|
||||
error "Invalid observed address"
|
||||
await conn.sendDialResponse(ResponseStatus.EInternalError)
|
||||
return Opt.none(MultiAddress)
|
||||
|
||||
return Opt.some(hostIp)
|
||||
|
||||
proc dialBack(
|
||||
conn: Connection, nonce: Nonce
|
||||
): Future[DialStatus] {.
|
||||
async: (raises: [CancelledError, DialFailedError, LPStreamError])
|
||||
.} =
|
||||
try:
|
||||
# send dial back
|
||||
await conn.writeLp(DialBack(nonce: nonce).encode().buffer)
|
||||
|
||||
# receive DialBackResponse
|
||||
let dialBackResp = DialBackResponse.decode(
|
||||
initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))
|
||||
).valueOr:
|
||||
error "DialBack failed, could not decode DialBackResponse"
|
||||
return DialStatus.EDialBackError
|
||||
except LPStreamRemoteClosedError as exc:
|
||||
# failed because of nonce error (remote reset the stream): EDialBackError
|
||||
error "DialBack failed, remote closed the connection", description = exc.msg
|
||||
return DialStatus.EDialBackError
|
||||
|
||||
# TODO: failed because of client or server resources: EDialError
|
||||
|
||||
trace "DialBack successful"
|
||||
return DialStatus.Ok
|
||||
|
||||
proc handleDialDataResponses(
|
||||
self: AutonatV2, conn: Connection
|
||||
) {.async: (raises: [CancelledError, AutonatV2Error, LPStreamError]).} =
|
||||
var dataReceived: uint64 = 0
|
||||
|
||||
while dataReceived < self.config.dialDataSize:
|
||||
let msg = AutonatV2Msg.decode(
|
||||
initProtoBuffer(await conn.readLp(AutonatV2DialDataResponseLpSize))
|
||||
).valueOr:
|
||||
raise newException(AutonatV2Error, "Received malformed message")
|
||||
debug "Received message", msgType = $msg.msgType
|
||||
if msg.msgType != MsgType.DialDataResponse:
|
||||
raise
|
||||
newException(AutonatV2Error, "Expecting DialDataResponse, got " & $msg.msgType)
|
||||
let resp = msg.dialDataResp
|
||||
dataReceived += resp.data.len.uint64
|
||||
debug "received data",
|
||||
dataReceived = resp.data.len.uint64, totalDataReceived = dataReceived
|
||||
|
||||
proc amplificationAttackPrevention(
|
||||
self: AutonatV2, conn: Connection, addrIdx: AddrIdx
|
||||
): Future[bool] {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
# send DialDataRequest
|
||||
await conn.writeLp(
|
||||
AutonatV2Msg(
|
||||
msgType: MsgType.DialDataRequest,
|
||||
dialDataReq: DialDataRequest(addrIdx: addrIdx, numBytes: self.config.dialDataSize),
|
||||
).encode().buffer
|
||||
)
|
||||
|
||||
# recieve DialDataResponses until we're satisfied
|
||||
try:
|
||||
if not await self.handleDialDataResponses(conn).withTimeout(self.config.dialTimeout):
|
||||
error "Amplification attack prevention timeout",
|
||||
timeout = self.config.amplificationAttackTimeout, peer = conn.peerId
|
||||
return false
|
||||
except AutonatV2Error as exc:
|
||||
error "Amplification attack prevention failed", description = exc.msg
|
||||
return false
|
||||
|
||||
return true
|
||||
|
||||
proc canDial(self: AutonatV2, addrs: MultiAddress): bool =
|
||||
let (ipv4Support, ipv6Support) = self.switch.peerInfo.listenAddrs.ipSupport()
|
||||
addrs[0].withValue(addrIp):
|
||||
if IP4.match(addrIp) and not ipv4Support:
|
||||
return false
|
||||
if IP6.match(addrIp) and not ipv6Support:
|
||||
return false
|
||||
try:
|
||||
if not self.config.allowPrivateAddresses and isPrivate($addrIp):
|
||||
return false
|
||||
except ValueError:
|
||||
warn "Unable to parse IP address, skipping", addrs = $addrIp
|
||||
return false
|
||||
for t in self.switch.transports:
|
||||
if t.handles(addrs):
|
||||
return true
|
||||
return false
|
||||
|
||||
proc forceNewConnection(
|
||||
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
|
||||
): Future[Opt[Connection]] {.async: (raises: [CancelledError]).} =
|
||||
## Bypasses connManager to force a new connection to ``pid``
|
||||
## instead of reusing a preexistent one
|
||||
try:
|
||||
let mux = await self.switch.dialer.dialAndUpgrade(Opt.some(pid), addrs)
|
||||
if mux.isNil():
|
||||
return Opt.none(Connection)
|
||||
return Opt.some(
|
||||
await self.switch.negotiateStream(
|
||||
await mux.newStream(), @[$AutonatV2Codec.DialBack]
|
||||
)
|
||||
)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError:
|
||||
return Opt.none(Connection)
|
||||
|
||||
proc chooseDialAddr(
|
||||
self: AutonatV2, pid: PeerId, addrs: seq[MultiAddress]
|
||||
): Future[Opt[(Connection, AddrIdx)]] {.async: (raises: [CancelledError]).} =
|
||||
for i, ma in addrs:
|
||||
if self.canDial(ma):
|
||||
debug "Trying to dial", chosenAddrs = ma, addrIdx = i
|
||||
let conn = (await self.forceNewConnection(pid, @[ma])).valueOr:
|
||||
return Opt.none((Connection, AddrIdx))
|
||||
return Opt.some((conn, i.AddrIdx))
|
||||
return Opt.none((Connection, AddrIdx))
|
||||
|
||||
proc handleDialRequest(
|
||||
self: AutonatV2, conn: Connection, req: DialRequest
|
||||
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let observedIPAddr = (await conn.findObservedIPAddr(req)).valueOr:
|
||||
error "Could not find observed IP address"
|
||||
return
|
||||
|
||||
let (dialBackConn, addrIdx) = (await self.chooseDialAddr(conn.peerId, req.addrs)).valueOr:
|
||||
error "No dialable addresses found"
|
||||
await conn.sendDialResponse(ResponseStatus.EDialRefused)
|
||||
return
|
||||
defer:
|
||||
await dialBackConn.close()
|
||||
|
||||
# if observed address for peer is not in address list to try
|
||||
# then we perform Amplification Attack Prevention
|
||||
if not ipAddrMatches(observedIPAddr, req.addrs):
|
||||
debug "Starting amplification attack prevention",
|
||||
observedIPAddr = observedIPAddr, testAddr = req.addrs[addrIdx]
|
||||
# send DialDataRequest and wait until dataReceived is enough
|
||||
if not await self.amplificationAttackPrevention(conn, addrIdx):
|
||||
return
|
||||
|
||||
debug "Sending DialBack",
|
||||
nonce = req.nonce, addrIdx = addrIdx, addr = req.addrs[addrIdx]
|
||||
|
||||
let dialStatus = await dialBackConn.dialBack(req.nonce)
|
||||
|
||||
await conn.sendDialResponse(
|
||||
ResponseStatus.Ok, addrIdx = Opt.some(addrIdx), dialStatus = Opt.some(dialStatus)
|
||||
)
|
||||
|
||||
proc new*(
|
||||
T: typedesc[AutonatV2],
|
||||
switch: Switch,
|
||||
config: AutonatV2Config = AutonatV2Config.new(),
|
||||
): T =
|
||||
let autonatV2 = T(switch: switch, config: config)
|
||||
proc handleStream(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
defer:
|
||||
await conn.close()
|
||||
|
||||
let msg =
|
||||
try:
|
||||
AutonatV2Msg.decode(initProtoBuffer(await conn.readLp(AutonatV2MsgLpSize))).valueOr:
|
||||
trace "Unable to decode AutonatV2Msg"
|
||||
return
|
||||
except LPStreamError as exc:
|
||||
debug "Could not receive AutonatV2Msg", description = exc.msg
|
||||
return
|
||||
|
||||
debug "Received message", msgType = $msg.msgType
|
||||
if msg.msgType != MsgType.DialRequest:
|
||||
debug "Expecting DialRequest", receivedMsgType = msg.msgType
|
||||
return
|
||||
|
||||
try:
|
||||
await autonatV2.handleDialRequest(conn, msg.dialReq)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except LPStreamRemoteClosedError as exc:
|
||||
debug "Connection closed by peer", description = exc.msg, peer = conn.peerId
|
||||
except LPStreamError as exc:
|
||||
debug "Stream Error", description = exc.msg
|
||||
except DialFailedError as exc:
|
||||
debug "Could not dial peer", description = exc.msg, peer = conn.peerId
|
||||
|
||||
autonatV2.handler = handleStream
|
||||
autonatV2.codec = $AutonatV2Codec.DialRequest
|
||||
autonatV2
|
||||
@@ -14,6 +14,14 @@ import
|
||||
../../../multiaddress, ../../../peerid, ../../../protobuf/minprotobuf, ../../../switch
|
||||
from ../autonat/types import NetworkReachability
|
||||
|
||||
const
|
||||
DefaultDialTimeout*: Duration = 15.seconds
|
||||
DefaultAmplificationAttackDialTimeout*: Duration = 3.seconds
|
||||
DefaultDialDataSize*: uint64 = 50 * 1024 # 50 KiB > 50 KB
|
||||
AutonatV2MsgLpSize*: int = 1024
|
||||
# readLp needs to receive more than 4096 bytes (since it's a DialDataResponse) + overhead
|
||||
AutonatV2DialDataResponseLpSize*: int = 5000
|
||||
|
||||
type
|
||||
AutonatV2Codec* {.pure.} = enum
|
||||
DialRequest = "/libp2p/autonat/2/dial-request"
|
||||
|
||||
@@ -478,19 +478,16 @@ iterator splitRPCMsg(
|
||||
## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send.
|
||||
## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes.
|
||||
|
||||
var currentRPCMsg = rpcMsg
|
||||
currentRPCMsg.messages = newSeq[Message]()
|
||||
|
||||
var currentSize = byteSize(currentRPCMsg)
|
||||
var currentRPCMsg = RPCMsg()
|
||||
var currentSize = 0
|
||||
|
||||
for msg in rpcMsg.messages:
|
||||
let msgSize = byteSize(msg)
|
||||
|
||||
# Check if adding the next message will exceed maxSize
|
||||
if float(currentSize + msgSize) * 1.1 > float(maxSize):
|
||||
# Guessing 10% protobuf overhead
|
||||
if currentRPCMsg.messages.len == 0:
|
||||
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
if currentSize + msgSize > maxSize:
|
||||
if msgSize > maxSize:
|
||||
warn "message too big to sent", peer, rpcMsg = shortLog(msg)
|
||||
continue # Skip this message
|
||||
|
||||
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
@@ -502,11 +499,9 @@ iterator splitRPCMsg(
|
||||
currentSize += msgSize
|
||||
|
||||
# Check if there is a non-empty currentRPCMsg left to be added
|
||||
if currentSize > 0 and currentRPCMsg.messages.len > 0:
|
||||
if currentRPCMsg.messages.len > 0:
|
||||
trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
yield encodeRpcMsg(currentRPCMsg, anonymize)
|
||||
else:
|
||||
trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg)
|
||||
|
||||
proc send*(
|
||||
p: PubSubPeer,
|
||||
@@ -542,8 +537,11 @@ proc send*(
|
||||
sendMetrics(msg)
|
||||
encodeRpcMsg(msg, anonymize)
|
||||
|
||||
if encoded.len > p.maxMessageSize and msg.messages.len > 1:
|
||||
for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize):
|
||||
# Messages should not exceed 90% of maxMessageSize. Guessing 10% protobuf overhead.
|
||||
let maxEncodedMsgSize = (p.maxMessageSize * 90) div 100
|
||||
|
||||
if encoded.len > maxEncodedMsgSize and msg.messages.len > 1:
|
||||
for encodedSplitMsg in splitRPCMsg(p, msg, maxEncodedMsgSize, anonymize):
|
||||
asyncSpawn p.sendEncoded(encodedSplitMsg, isHighPriority, useCustomConn)
|
||||
else:
|
||||
# If the message size is within limits, send it as is
|
||||
|
||||
@@ -1,851 +1,3 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
import ./rendezvous/rendezvous
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import tables, sequtils, sugar, sets
|
||||
import metrics except collect
|
||||
import chronos, chronicles, bearssl/rand, stew/[byteutils, objects]
|
||||
import
|
||||
./protocol,
|
||||
../protobuf/minprotobuf,
|
||||
../switch,
|
||||
../routing_record,
|
||||
../utils/heartbeat,
|
||||
../stream/connection,
|
||||
../utils/offsettedseq,
|
||||
../utils/semaphore,
|
||||
../discovery/discoverymngr
|
||||
|
||||
export chronicles
|
||||
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
|
||||
declareCounter(libp2p_rendezvous_register, "number of advertise requests")
|
||||
declareCounter(libp2p_rendezvous_discover, "number of discovery requests")
|
||||
declareGauge(libp2p_rendezvous_registered, "number of registered peers")
|
||||
declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
|
||||
|
||||
const
|
||||
RendezVousCodec* = "/rendezvous/1.0.0"
|
||||
# Default minimum TTL per libp2p spec
|
||||
MinimumDuration* = 2.hours
|
||||
# Lower validation limit to accommodate Waku requirements
|
||||
MinimumAcceptedDuration = 1.minutes
|
||||
MaximumDuration = 72.hours
|
||||
MaximumMessageLen = 1 shl 22 # 4MB
|
||||
MinimumNamespaceLen = 1
|
||||
MaximumNamespaceLen = 255
|
||||
RegistrationLimitPerPeer* = 1000
|
||||
DiscoverLimit = 1000'u64
|
||||
SemaphoreDefaultSize = 5
|
||||
|
||||
type
|
||||
MessageType {.pure.} = enum
|
||||
Register = 0
|
||||
RegisterResponse = 1
|
||||
Unregister = 2
|
||||
Discover = 3
|
||||
DiscoverResponse = 4
|
||||
|
||||
ResponseStatus = enum
|
||||
Ok = 0
|
||||
InvalidNamespace = 100
|
||||
InvalidSignedPeerRecord = 101
|
||||
InvalidTTL = 102
|
||||
InvalidCookie = 103
|
||||
NotAuthorized = 200
|
||||
InternalError = 300
|
||||
Unavailable = 400
|
||||
|
||||
Cookie = object
|
||||
offset: uint64
|
||||
ns: Opt[string]
|
||||
|
||||
Register = object
|
||||
ns: string
|
||||
signedPeerRecord: seq[byte]
|
||||
ttl*: Opt[uint64] # in seconds
|
||||
|
||||
RegisterResponse = object
|
||||
status: ResponseStatus
|
||||
text: Opt[string]
|
||||
ttl: Opt[uint64] # in seconds
|
||||
|
||||
Unregister = object
|
||||
ns: string
|
||||
|
||||
Discover = object
|
||||
ns: Opt[string]
|
||||
limit: Opt[uint64]
|
||||
cookie: Opt[seq[byte]]
|
||||
|
||||
DiscoverResponse = object
|
||||
registrations: seq[Register]
|
||||
cookie: Opt[seq[byte]]
|
||||
status: ResponseStatus
|
||||
text: Opt[string]
|
||||
|
||||
Message = object
|
||||
msgType: MessageType
|
||||
register: Opt[Register]
|
||||
registerResponse: Opt[RegisterResponse]
|
||||
unregister: Opt[Unregister]
|
||||
discover: Opt[Discover]
|
||||
discoverResponse: Opt[DiscoverResponse]
|
||||
|
||||
proc encode(c: Cookie): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, c.offset)
|
||||
if c.ns.isSome():
|
||||
result.write(2, c.ns.get())
|
||||
result.finish()
|
||||
|
||||
proc encode(r: Register): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, r.ns)
|
||||
result.write(2, r.signedPeerRecord)
|
||||
r.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode(rr: RegisterResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, rr.status.uint)
|
||||
rr.text.withValue(text):
|
||||
result.write(2, text)
|
||||
rr.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode(u: Unregister): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, u.ns)
|
||||
result.finish()
|
||||
|
||||
proc encode(d: Discover): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
if d.ns.isSome():
|
||||
result.write(1, d.ns.get())
|
||||
d.limit.withValue(limit):
|
||||
result.write(2, limit)
|
||||
d.cookie.withValue(cookie):
|
||||
result.write(3, cookie)
|
||||
result.finish()
|
||||
|
||||
proc encode(dr: DiscoverResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
for reg in dr.registrations:
|
||||
result.write(1, reg.encode())
|
||||
dr.cookie.withValue(cookie):
|
||||
result.write(2, cookie)
|
||||
result.write(3, dr.status.uint)
|
||||
dr.text.withValue(text):
|
||||
result.write(4, text)
|
||||
result.finish()
|
||||
|
||||
proc encode(msg: Message): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, msg.msgType.uint)
|
||||
msg.register.withValue(register):
|
||||
result.write(2, register.encode())
|
||||
msg.registerResponse.withValue(registerResponse):
|
||||
result.write(3, registerResponse.encode())
|
||||
msg.unregister.withValue(unregister):
|
||||
result.write(4, unregister.encode())
|
||||
msg.discover.withValue(discover):
|
||||
result.write(5, discover.encode())
|
||||
msg.discoverResponse.withValue(discoverResponse):
|
||||
result.write(6, discoverResponse.encode())
|
||||
result.finish()
|
||||
|
||||
proc decode(_: typedesc[Cookie], buf: seq[byte]): Opt[Cookie] =
|
||||
var
|
||||
c: Cookie
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, c.offset)
|
||||
r2 = pb.getField(2, ns)
|
||||
if r1.isErr() or r2.isErr():
|
||||
return Opt.none(Cookie)
|
||||
if r2.get(false):
|
||||
c.ns = Opt.some(ns)
|
||||
Opt.some(c)
|
||||
|
||||
proc decode(_: typedesc[Register], buf: seq[byte]): Opt[Register] =
|
||||
var
|
||||
r: Register
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, r.ns)
|
||||
r2 = pb.getRequiredField(2, r.signedPeerRecord)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr():
|
||||
return Opt.none(Register)
|
||||
if r3.get(false):
|
||||
r.ttl = Opt.some(ttl)
|
||||
Opt.some(r)
|
||||
|
||||
proc decode(_: typedesc[RegisterResponse], buf: seq[byte]): Opt[RegisterResponse] =
|
||||
var
|
||||
rr: RegisterResponse
|
||||
statusOrd: uint
|
||||
text: string
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, statusOrd)
|
||||
r2 = pb.getField(2, text)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr() or
|
||||
not checkedEnumAssign(rr.status, statusOrd):
|
||||
return Opt.none(RegisterResponse)
|
||||
if r2.get(false):
|
||||
rr.text = Opt.some(text)
|
||||
if r3.get(false):
|
||||
rr.ttl = Opt.some(ttl)
|
||||
Opt.some(rr)
|
||||
|
||||
proc decode(_: typedesc[Unregister], buf: seq[byte]): Opt[Unregister] =
|
||||
var u: Unregister
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, u.ns)
|
||||
if r1.isErr():
|
||||
return Opt.none(Unregister)
|
||||
Opt.some(u)
|
||||
|
||||
proc decode(_: typedesc[Discover], buf: seq[byte]): Opt[Discover] =
|
||||
var
|
||||
d: Discover
|
||||
limit: uint64
|
||||
cookie: seq[byte]
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getField(1, ns)
|
||||
r2 = pb.getField(2, limit)
|
||||
r3 = pb.getField(3, cookie)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr:
|
||||
return Opt.none(Discover)
|
||||
if r1.get(false):
|
||||
d.ns = Opt.some(ns)
|
||||
if r2.get(false):
|
||||
d.limit = Opt.some(limit)
|
||||
if r3.get(false):
|
||||
d.cookie = Opt.some(cookie)
|
||||
Opt.some(d)
|
||||
|
||||
proc decode(_: typedesc[DiscoverResponse], buf: seq[byte]): Opt[DiscoverResponse] =
|
||||
var
|
||||
dr: DiscoverResponse
|
||||
registrations: seq[seq[byte]]
|
||||
cookie: seq[byte]
|
||||
statusOrd: uint
|
||||
text: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRepeatedField(1, registrations)
|
||||
r2 = pb.getField(2, cookie)
|
||||
r3 = pb.getRequiredField(3, statusOrd)
|
||||
r4 = pb.getField(4, text)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or
|
||||
not checkedEnumAssign(dr.status, statusOrd):
|
||||
return Opt.none(DiscoverResponse)
|
||||
for reg in registrations:
|
||||
var r: Register
|
||||
let regOpt = Register.decode(reg).valueOr:
|
||||
return
|
||||
dr.registrations.add(regOpt)
|
||||
if r2.get(false):
|
||||
dr.cookie = Opt.some(cookie)
|
||||
if r4.get(false):
|
||||
dr.text = Opt.some(text)
|
||||
Opt.some(dr)
|
||||
|
||||
proc decode(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
|
||||
var
|
||||
msg: Message
|
||||
statusOrd: uint
|
||||
pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer
|
||||
let pb = initProtoBuffer(buf)
|
||||
|
||||
?pb.getRequiredField(1, statusOrd).toOpt
|
||||
if not checkedEnumAssign(msg.msgType, statusOrd):
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(2, pbr).optValue:
|
||||
msg.register = Register.decode(pbr.buffer)
|
||||
if msg.register.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(3, pbrr).optValue:
|
||||
msg.registerResponse = RegisterResponse.decode(pbrr.buffer)
|
||||
if msg.registerResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(4, pbu).optValue:
|
||||
msg.unregister = Unregister.decode(pbu.buffer)
|
||||
if msg.unregister.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(5, pbd).optValue:
|
||||
msg.discover = Discover.decode(pbd.buffer)
|
||||
if msg.discover.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(6, pbdr).optValue:
|
||||
msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer)
|
||||
if msg.discoverResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
Opt.some(msg)
|
||||
|
||||
type
|
||||
RendezVousError* = object of DiscoveryError
|
||||
RegisteredData = object
|
||||
expiration*: Moment
|
||||
peerId*: PeerId
|
||||
data*: Register
|
||||
|
||||
RendezVous* = ref object of LPProtocol
|
||||
# Registered needs to be an offsetted sequence
|
||||
# because we need stable index for the cookies.
|
||||
registered*: OffsettedSeq[RegisteredData]
|
||||
# Namespaces is a table whose key is a salted namespace and
|
||||
# the value is the index sequence corresponding to this
|
||||
# namespace in the offsettedqueue.
|
||||
namespaces*: Table[string, seq[int]]
|
||||
rng: ref HmacDrbgContext
|
||||
salt: string
|
||||
expiredDT: Moment
|
||||
registerDeletionLoop: Future[void]
|
||||
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
|
||||
# + make the heartbeat sleep duration "smarter"
|
||||
sema: AsyncSemaphore
|
||||
peers: seq[PeerId]
|
||||
cookiesSaved*: Table[PeerId, Table[string, seq[byte]]]
|
||||
switch: Switch
|
||||
minDuration: Duration
|
||||
maxDuration: Duration
|
||||
minTTL: uint64
|
||||
maxTTL: uint64
|
||||
|
||||
proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] =
|
||||
if spr.len == 0:
|
||||
return err("Empty peer record")
|
||||
let signedEnv = ?SignedPeerRecord.decode(spr).mapErr(x => $x)
|
||||
if signedEnv.data.peerId != peerId:
|
||||
return err("Bad Peer ID")
|
||||
return ok()
|
||||
|
||||
proc sendRegisterResponse(
|
||||
conn: Connection, ttl: uint64
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: Ok, ttl: Opt.some(ttl))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendRegisterResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponse(
|
||||
conn: Connection, s: seq[Register], cookie: Cookie
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(
|
||||
DiscoverResponse(
|
||||
status: Ok, registrations: s, cookie: Opt.some(cookie.encode().buffer)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(DiscoverResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
|
||||
for data in rdv.registered:
|
||||
if data.peerId == peerId:
|
||||
result.inc()
|
||||
|
||||
proc save(
|
||||
rdv: RendezVous, ns: string, peerId: PeerId, r: Register, update: bool = true
|
||||
) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
discard rdv.namespaces.hasKeyOrPut(nsSalted, newSeq[int]())
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == peerId:
|
||||
if update == false:
|
||||
return
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
rdv.registered.add(
|
||||
RegisteredData(
|
||||
peerId: peerId,
|
||||
expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds,
|
||||
data: r,
|
||||
)
|
||||
)
|
||||
rdv.namespaces[nsSalted].add(rdv.registered.high)
|
||||
# rdv.registerEvent.fire()
|
||||
except KeyError as e:
|
||||
doAssert false, "Should have key: " & e.msg
|
||||
|
||||
proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] =
|
||||
trace "Received Register", peerId = conn.peerId, ns = r.ns
|
||||
libp2p_rendezvous_register.inc()
|
||||
if r.ns.len < MinimumNamespaceLen or r.ns.len > MaximumNamespaceLen:
|
||||
return conn.sendRegisterResponseError(InvalidNamespace)
|
||||
let ttl = r.ttl.get(rdv.minTTL)
|
||||
if ttl < rdv.minTTL or ttl > rdv.maxTTL:
|
||||
return conn.sendRegisterResponseError(InvalidTTL)
|
||||
let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId)
|
||||
if pr.isErr():
|
||||
return conn.sendRegisterResponseError(InvalidSignedPeerRecord, pr.error())
|
||||
if rdv.countRegister(conn.peerId) >= RegistrationLimitPerPeer:
|
||||
return conn.sendRegisterResponseError(NotAuthorized, "Registration limit reached")
|
||||
rdv.save(r.ns, conn.peerId, r)
|
||||
libp2p_rendezvous_registered.inc()
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
conn.sendRegisterResponse(ttl)
|
||||
|
||||
proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
|
||||
trace "Received Unregister", peerId = conn.peerId, ns = u.ns
|
||||
let nsSalted = u.ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == conn.peerId:
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
libp2p_rendezvous_registered.dec()
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc discover(
|
||||
rdv: RendezVous, conn: Connection, d: Discover
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
trace "Received Discover", peerId = conn.peerId, ns = d.ns
|
||||
libp2p_rendezvous_discover.inc()
|
||||
if d.ns.isSome() and d.ns.get().len > MaximumNamespaceLen:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
var limit = min(DiscoverLimit, d.limit.get(DiscoverLimit))
|
||||
var cookie =
|
||||
if d.cookie.isSome():
|
||||
try:
|
||||
Cookie.decode(d.cookie.tryGet()).tryGet()
|
||||
except CatchableError:
|
||||
await conn.sendDiscoverResponseError(InvalidCookie)
|
||||
return
|
||||
else:
|
||||
# Start from the current lowest index (inclusive)
|
||||
Cookie(offset: rdv.registered.low().uint64)
|
||||
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get():
|
||||
# Namespace changed: start from the beginning of that namespace
|
||||
cookie = Cookie(offset: rdv.registered.low().uint64)
|
||||
elif cookie.offset < rdv.registered.low().uint64:
|
||||
# Cookie behind available range: reset to current low
|
||||
cookie.offset = rdv.registered.low().uint64
|
||||
elif cookie.offset > (rdv.registered.high() + 1).uint64:
|
||||
# Cookie ahead of available range: reset to one past current high (empty page)
|
||||
cookie.offset = (rdv.registered.high() + 1).uint64
|
||||
let namespaces =
|
||||
if d.ns.isSome():
|
||||
try:
|
||||
rdv.namespaces[d.ns.get() & rdv.salt]
|
||||
except KeyError:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
else:
|
||||
toSeq(max(cookie.offset.int, rdv.registered.offset) .. rdv.registered.high())
|
||||
if namespaces.len() == 0:
|
||||
await conn.sendDiscoverResponse(@[], Cookie())
|
||||
return
|
||||
var nextOffset = cookie.offset
|
||||
let n = Moment.now()
|
||||
var s = collect(newSeq()):
|
||||
for index in namespaces:
|
||||
var reg = rdv.registered[index]
|
||||
if limit == 0:
|
||||
break
|
||||
if reg.expiration < n or index.uint64 < cookie.offset:
|
||||
continue
|
||||
limit.dec()
|
||||
nextOffset = index.uint64 + 1
|
||||
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
|
||||
reg.data
|
||||
rdv.rng.shuffle(s)
|
||||
await conn.sendDiscoverResponse(s, Cookie(offset: nextOffset, ns: d.ns))
|
||||
|
||||
proc advertisePeer(
|
||||
rdv: RendezVous, peer: PeerId, msg: seq[byte]
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
proc advertiseWrap() {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg)
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msgRecv = Message.decode(buf).tryGet()
|
||||
if msgRecv.msgType != MessageType.RegisterResponse:
|
||||
trace "Unexpected register response", peer, msgType = msgRecv.msgType
|
||||
elif msgRecv.registerResponse.tryGet().status != ResponseStatus.Ok:
|
||||
trace "Refuse to register", peer, response = msgRecv.registerResponse
|
||||
else:
|
||||
trace "Successfully registered", peer, response = msgRecv.registerResponse
|
||||
except CatchableError as exc:
|
||||
trace "exception in the advertise", description = exc.msg
|
||||
finally:
|
||||
rdv.sema.release()
|
||||
|
||||
await rdv.sema.acquire()
|
||||
await advertiseWrap()
|
||||
|
||||
proc advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId]
|
||||
) {.async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
if ttl < rdv.minDuration or ttl > rdv.maxDuration:
|
||||
raise newException(AdvertiseError, "Invalid time to live: " & $ttl)
|
||||
|
||||
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
|
||||
raise newException(AdvertiseError, "Wrong Signed Peer Record")
|
||||
|
||||
let
|
||||
r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
|
||||
msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
|
||||
|
||||
rdv.save(ns, rdv.switch.peerInfo.peerId, r)
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peers:
|
||||
trace "Send Advertise", peerId = peer, ns
|
||||
rdv.advertisePeer(peer, msg.buffer).withTimeout(5.seconds)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
method advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration
|
||||
) {.base, async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
await rdv.advertise(ns, ttl, rdv.peers)
|
||||
|
||||
proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
|
||||
let
|
||||
nsSalted = ns & rdv.salt
|
||||
n = Moment.now()
|
||||
try:
|
||||
collect(newSeq()):
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].expiration > n:
|
||||
let res = SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).valueOr:
|
||||
continue
|
||||
res.data
|
||||
except KeyError as exc:
|
||||
@[]
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int, peers: seq[PeerId]
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
var
|
||||
s: Table[PeerId, (PeerRecord, Register)]
|
||||
limit: uint64
|
||||
d = Discover(ns: ns)
|
||||
|
||||
if l <= 0 or l > DiscoverLimit.int:
|
||||
raise newException(AdvertiseError, "Invalid limit")
|
||||
if ns.isSome() and ns.get().len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
limit = l.uint64
|
||||
proc requestPeer(
|
||||
peer: PeerId
|
||||
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
d.limit = Opt.some(limit)
|
||||
d.cookie =
|
||||
if ns.isSome():
|
||||
try:
|
||||
Opt.some(rdv.cookiesSaved[peer][ns.get()])
|
||||
except KeyError, CatchableError:
|
||||
Opt.none(seq[byte])
|
||||
else:
|
||||
Opt.none(seq[byte])
|
||||
await conn.writeLp(
|
||||
encode(Message(msgType: MessageType.Discover, discover: Opt.some(d))).buffer
|
||||
)
|
||||
let
|
||||
buf = await conn.readLp(MaximumMessageLen)
|
||||
msgRcv = Message.decode(buf).valueOr:
|
||||
debug "Message undecodable"
|
||||
return
|
||||
if msgRcv.msgType != MessageType.DiscoverResponse:
|
||||
debug "Unexpected discover response", msgType = msgRcv.msgType
|
||||
return
|
||||
let resp = msgRcv.discoverResponse.valueOr:
|
||||
debug "Discover response is empty"
|
||||
return
|
||||
if resp.status != ResponseStatus.Ok:
|
||||
trace "Cannot discover", ns, status = resp.status, text = resp.text
|
||||
return
|
||||
resp.cookie.withValue(cookie):
|
||||
if ns.isSome:
|
||||
let namespace = ns.get()
|
||||
if cookie.len() < 1000 and
|
||||
rdv.cookiesSaved.hasKeyOrPut(peer, {namespace: cookie}.toTable()):
|
||||
try:
|
||||
rdv.cookiesSaved[peer][namespace] = cookie
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKeyOrPut"
|
||||
for r in resp.registrations:
|
||||
if limit == 0:
|
||||
return
|
||||
let ttl = r.ttl.get(rdv.maxTTL + 1)
|
||||
if ttl > rdv.maxTTL:
|
||||
continue
|
||||
let
|
||||
spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr:
|
||||
continue
|
||||
pr = spr.data
|
||||
if s.hasKey(pr.peerId):
|
||||
let (prSaved, rSaved) =
|
||||
try:
|
||||
s[pr.peerId]
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKey"
|
||||
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or
|
||||
prSaved.seqNo < pr.seqNo:
|
||||
s[pr.peerId] = (pr, r)
|
||||
else:
|
||||
s[pr.peerId] = (pr, r)
|
||||
limit.dec()
|
||||
if ns.isSome():
|
||||
for (_, r) in s.values():
|
||||
rdv.save(ns.get(), peer, r, false)
|
||||
|
||||
for peer in peers:
|
||||
if limit == 0:
|
||||
break
|
||||
if RendezVousCodec notin rdv.switch.peerStore[ProtoBook][peer]:
|
||||
continue
|
||||
try:
|
||||
trace "Send Request", peerId = peer, ns
|
||||
await peer.requestPeer()
|
||||
except CancelledError as e:
|
||||
raise e
|
||||
except DialFailedError as e:
|
||||
trace "failed to dial a peer", description = e.msg
|
||||
except LPStreamError as e:
|
||||
trace "failed to communicate with a peer", description = e.msg
|
||||
return toSeq(s.values()).mapIt(it[0])
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(ns, l, rdv.peers)
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(Opt.none(string), l, rdv.peers)
|
||||
|
||||
proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string, peerIds: seq[PeerId]
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(RendezVousError, "Invalid namespace")
|
||||
|
||||
let msg = encode(
|
||||
Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns)))
|
||||
)
|
||||
|
||||
proc unsubscribePeer(peerId: PeerId) {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peerId, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg.buffer)
|
||||
except CatchableError as exc:
|
||||
trace "exception while unsubscribing", description = exc.msg
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peerIds:
|
||||
unsubscribePeer(peer)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
rdv.unsubscribeLocally(ns)
|
||||
|
||||
await rdv.unsubscribe(ns, rdv.peers)
|
||||
|
||||
proc setup*(rdv: RendezVous, switch: Switch) =
|
||||
rdv.switch = switch
|
||||
proc handlePeer(
|
||||
peerId: PeerId, event: PeerEvent
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
if event.kind == PeerEventKind.Joined:
|
||||
rdv.peers.add(peerId)
|
||||
elif event.kind == PeerEventKind.Left:
|
||||
rdv.peers.keepItIf(it != peerId)
|
||||
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Joined)
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Left)
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
if minDuration < MinimumAcceptedDuration:
|
||||
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
|
||||
|
||||
if maxDuration > MaximumDuration:
|
||||
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
|
||||
|
||||
if minDuration >= maxDuration:
|
||||
raise newException(RendezVousError, "Minimum TTL longer than maximum")
|
||||
|
||||
let
|
||||
minTTL = minDuration.seconds.uint64
|
||||
maxTTL = maxDuration.seconds.uint64
|
||||
|
||||
let rdv = T(
|
||||
rng: rng,
|
||||
salt: string.fromBytes(generateBytes(rng[], 8)),
|
||||
registered: initOffsettedSeq[RegisteredData](),
|
||||
expiredDT: Moment.now() - 1.days,
|
||||
#registerEvent: newAsyncEvent(),
|
||||
sema: newAsyncSemaphore(SemaphoreDefaultSize),
|
||||
minDuration: minDuration,
|
||||
maxDuration: maxDuration,
|
||||
minTTL: minTTL,
|
||||
maxTTL: maxTTL,
|
||||
)
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
proc handleStream(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msg = Message.decode(buf).tryGet()
|
||||
case msg.msgType
|
||||
of MessageType.Register:
|
||||
await rdv.register(conn, msg.register.tryGet())
|
||||
of MessageType.RegisterResponse:
|
||||
trace "Got an unexpected Register Response", response = msg.registerResponse
|
||||
of MessageType.Unregister:
|
||||
rdv.unregister(conn, msg.unregister.tryGet())
|
||||
of MessageType.Discover:
|
||||
await rdv.discover(conn, msg.discover.tryGet())
|
||||
of MessageType.DiscoverResponse:
|
||||
trace "Got an unexpected Discover Response", response = msg.discoverResponse
|
||||
except CancelledError as exc:
|
||||
trace "cancelled rendezvous handler"
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "exception in rendezvous handler", description = exc.msg
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
rdv.handler = handleStream
|
||||
rdv.codec = RendezVousCodec
|
||||
return rdv
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
switch: Switch,
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
let rdv = T.new(rng, minDuration, maxDuration)
|
||||
rdv.setup(switch)
|
||||
return rdv
|
||||
|
||||
proc deletesRegister*(
|
||||
rdv: RendezVous, interval = 1.minutes
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "Register timeout", interval:
|
||||
let n = Moment.now()
|
||||
var total = 0
|
||||
rdv.registered.flushIfIt(it.expiration < n)
|
||||
for data in rdv.namespaces.mvalues():
|
||||
data.keepItIf(it >= rdv.registered.offset)
|
||||
total += data.len
|
||||
libp2p_rendezvous_registered.set(int64(total))
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
|
||||
method start*(
|
||||
rdv: RendezVous
|
||||
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if not rdv.registerDeletionLoop.isNil:
|
||||
warn "Starting rendezvous twice"
|
||||
return fut
|
||||
rdv.registerDeletionLoop = rdv.deletesRegister()
|
||||
rdv.started = true
|
||||
fut
|
||||
|
||||
method stop*(rdv: RendezVous): Future[void] {.async: (raises: [], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if rdv.registerDeletionLoop.isNil:
|
||||
warn "Stopping rendezvous without starting it"
|
||||
return fut
|
||||
rdv.started = false
|
||||
rdv.registerDeletionLoop.cancelSoon()
|
||||
rdv.registerDeletionLoop = nil
|
||||
fut
|
||||
export rendezvous
|
||||
|
||||
275
libp2p/protocols/rendezvous/protobuf.nim
Normal file
275
libp2p/protocols/rendezvous/protobuf.nim
Normal file
@@ -0,0 +1,275 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import results
|
||||
import stew/objects
|
||||
import ../../protobuf/minprotobuf
|
||||
|
||||
type
|
||||
MessageType* {.pure.} = enum
|
||||
Register = 0
|
||||
RegisterResponse = 1
|
||||
Unregister = 2
|
||||
Discover = 3
|
||||
DiscoverResponse = 4
|
||||
|
||||
ResponseStatus* = enum
|
||||
Ok = 0
|
||||
InvalidNamespace = 100
|
||||
InvalidSignedPeerRecord = 101
|
||||
InvalidTTL = 102
|
||||
InvalidCookie = 103
|
||||
NotAuthorized = 200
|
||||
InternalError = 300
|
||||
Unavailable = 400
|
||||
|
||||
Cookie* = object
|
||||
offset*: uint64
|
||||
ns*: Opt[string]
|
||||
|
||||
Register* = object
|
||||
ns*: string
|
||||
signedPeerRecord*: seq[byte]
|
||||
ttl*: Opt[uint64] # in seconds
|
||||
|
||||
RegisterResponse* = object
|
||||
status*: ResponseStatus
|
||||
text*: Opt[string]
|
||||
ttl*: Opt[uint64] # in seconds
|
||||
|
||||
Unregister* = object
|
||||
ns*: string
|
||||
|
||||
Discover* = object
|
||||
ns*: Opt[string]
|
||||
limit*: Opt[uint64]
|
||||
cookie*: Opt[seq[byte]]
|
||||
|
||||
DiscoverResponse* = object
|
||||
registrations*: seq[Register]
|
||||
cookie*: Opt[seq[byte]]
|
||||
status*: ResponseStatus
|
||||
text*: Opt[string]
|
||||
|
||||
Message* = object
|
||||
msgType*: MessageType
|
||||
register*: Opt[Register]
|
||||
registerResponse*: Opt[RegisterResponse]
|
||||
unregister*: Opt[Unregister]
|
||||
discover*: Opt[Discover]
|
||||
discoverResponse*: Opt[DiscoverResponse]
|
||||
|
||||
proc encode*(c: Cookie): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, c.offset)
|
||||
if c.ns.isSome():
|
||||
result.write(2, c.ns.get())
|
||||
result.finish()
|
||||
|
||||
proc encode*(r: Register): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, r.ns)
|
||||
result.write(2, r.signedPeerRecord)
|
||||
r.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode*(rr: RegisterResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, rr.status.uint)
|
||||
rr.text.withValue(text):
|
||||
result.write(2, text)
|
||||
rr.ttl.withValue(ttl):
|
||||
result.write(3, ttl)
|
||||
result.finish()
|
||||
|
||||
proc encode*(u: Unregister): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, u.ns)
|
||||
result.finish()
|
||||
|
||||
proc encode*(d: Discover): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
if d.ns.isSome():
|
||||
result.write(1, d.ns.get())
|
||||
d.limit.withValue(limit):
|
||||
result.write(2, limit)
|
||||
d.cookie.withValue(cookie):
|
||||
result.write(3, cookie)
|
||||
result.finish()
|
||||
|
||||
proc encode*(dr: DiscoverResponse): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
for reg in dr.registrations:
|
||||
result.write(1, reg.encode())
|
||||
dr.cookie.withValue(cookie):
|
||||
result.write(2, cookie)
|
||||
result.write(3, dr.status.uint)
|
||||
dr.text.withValue(text):
|
||||
result.write(4, text)
|
||||
result.finish()
|
||||
|
||||
proc encode*(msg: Message): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
result.write(1, msg.msgType.uint)
|
||||
msg.register.withValue(register):
|
||||
result.write(2, register.encode())
|
||||
msg.registerResponse.withValue(registerResponse):
|
||||
result.write(3, registerResponse.encode())
|
||||
msg.unregister.withValue(unregister):
|
||||
result.write(4, unregister.encode())
|
||||
msg.discover.withValue(discover):
|
||||
result.write(5, discover.encode())
|
||||
msg.discoverResponse.withValue(discoverResponse):
|
||||
result.write(6, discoverResponse.encode())
|
||||
result.finish()
|
||||
|
||||
proc decode*(_: typedesc[Cookie], buf: seq[byte]): Opt[Cookie] =
|
||||
var
|
||||
c: Cookie
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, c.offset)
|
||||
r2 = pb.getField(2, ns)
|
||||
if r1.isErr() or r2.isErr():
|
||||
return Opt.none(Cookie)
|
||||
if r2.get(false):
|
||||
c.ns = Opt.some(ns)
|
||||
Opt.some(c)
|
||||
|
||||
proc decode*(_: typedesc[Register], buf: seq[byte]): Opt[Register] =
|
||||
var
|
||||
r: Register
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, r.ns)
|
||||
r2 = pb.getRequiredField(2, r.signedPeerRecord)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr():
|
||||
return Opt.none(Register)
|
||||
if r3.get(false):
|
||||
r.ttl = Opt.some(ttl)
|
||||
Opt.some(r)
|
||||
|
||||
proc decode*(_: typedesc[RegisterResponse], buf: seq[byte]): Opt[RegisterResponse] =
|
||||
var
|
||||
rr: RegisterResponse
|
||||
statusOrd: uint
|
||||
text: string
|
||||
ttl: uint64
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, statusOrd)
|
||||
r2 = pb.getField(2, text)
|
||||
r3 = pb.getField(3, ttl)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr() or
|
||||
not checkedEnumAssign(rr.status, statusOrd):
|
||||
return Opt.none(RegisterResponse)
|
||||
if r2.get(false):
|
||||
rr.text = Opt.some(text)
|
||||
if r3.get(false):
|
||||
rr.ttl = Opt.some(ttl)
|
||||
Opt.some(rr)
|
||||
|
||||
proc decode*(_: typedesc[Unregister], buf: seq[byte]): Opt[Unregister] =
|
||||
var u: Unregister
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRequiredField(1, u.ns)
|
||||
if r1.isErr():
|
||||
return Opt.none(Unregister)
|
||||
Opt.some(u)
|
||||
|
||||
proc decode*(_: typedesc[Discover], buf: seq[byte]): Opt[Discover] =
|
||||
var
|
||||
d: Discover
|
||||
limit: uint64
|
||||
cookie: seq[byte]
|
||||
ns: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getField(1, ns)
|
||||
r2 = pb.getField(2, limit)
|
||||
r3 = pb.getField(3, cookie)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr:
|
||||
return Opt.none(Discover)
|
||||
if r1.get(false):
|
||||
d.ns = Opt.some(ns)
|
||||
if r2.get(false):
|
||||
d.limit = Opt.some(limit)
|
||||
if r3.get(false):
|
||||
d.cookie = Opt.some(cookie)
|
||||
Opt.some(d)
|
||||
|
||||
proc decode*(_: typedesc[DiscoverResponse], buf: seq[byte]): Opt[DiscoverResponse] =
|
||||
var
|
||||
dr: DiscoverResponse
|
||||
registrations: seq[seq[byte]]
|
||||
cookie: seq[byte]
|
||||
statusOrd: uint
|
||||
text: string
|
||||
let
|
||||
pb = initProtoBuffer(buf)
|
||||
r1 = pb.getRepeatedField(1, registrations)
|
||||
r2 = pb.getField(2, cookie)
|
||||
r3 = pb.getRequiredField(3, statusOrd)
|
||||
r4 = pb.getField(4, text)
|
||||
if r1.isErr() or r2.isErr() or r3.isErr or r4.isErr() or
|
||||
not checkedEnumAssign(dr.status, statusOrd):
|
||||
return Opt.none(DiscoverResponse)
|
||||
for reg in registrations:
|
||||
var r: Register
|
||||
let regOpt = Register.decode(reg).valueOr:
|
||||
return
|
||||
dr.registrations.add(regOpt)
|
||||
if r2.get(false):
|
||||
dr.cookie = Opt.some(cookie)
|
||||
if r4.get(false):
|
||||
dr.text = Opt.some(text)
|
||||
Opt.some(dr)
|
||||
|
||||
proc decode*(_: typedesc[Message], buf: seq[byte]): Opt[Message] =
|
||||
var
|
||||
msg: Message
|
||||
statusOrd: uint
|
||||
pbr, pbrr, pbu, pbd, pbdr: ProtoBuffer
|
||||
let pb = initProtoBuffer(buf)
|
||||
|
||||
?pb.getRequiredField(1, statusOrd).toOpt
|
||||
if not checkedEnumAssign(msg.msgType, statusOrd):
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(2, pbr).optValue:
|
||||
msg.register = Register.decode(pbr.buffer)
|
||||
if msg.register.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(3, pbrr).optValue:
|
||||
msg.registerResponse = RegisterResponse.decode(pbrr.buffer)
|
||||
if msg.registerResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(4, pbu).optValue:
|
||||
msg.unregister = Unregister.decode(pbu.buffer)
|
||||
if msg.unregister.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(5, pbd).optValue:
|
||||
msg.discover = Discover.decode(pbd.buffer)
|
||||
if msg.discover.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
if ?pb.getField(6, pbdr).optValue:
|
||||
msg.discoverResponse = DiscoverResponse.decode(pbdr.buffer)
|
||||
if msg.discoverResponse.isNone():
|
||||
return Opt.none(Message)
|
||||
|
||||
Opt.some(msg)
|
||||
589
libp2p/protocols/rendezvous/rendezvous.nim
Normal file
589
libp2p/protocols/rendezvous/rendezvous.nim
Normal file
@@ -0,0 +1,589 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import tables, sequtils, sugar, sets
|
||||
import metrics except collect
|
||||
import chronos, chronicles, bearssl/rand, stew/[byteutils, objects]
|
||||
import
|
||||
./protobuf,
|
||||
../protocol,
|
||||
../../protobuf/minprotobuf,
|
||||
../../switch,
|
||||
../../routing_record,
|
||||
../../utils/heartbeat,
|
||||
../../stream/connection,
|
||||
../../utils/offsettedseq,
|
||||
../../utils/semaphore,
|
||||
../../discovery/discoverymngr
|
||||
|
||||
export chronicles
|
||||
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
|
||||
declareCounter(libp2p_rendezvous_register, "number of advertise requests")
|
||||
declareCounter(libp2p_rendezvous_discover, "number of discovery requests")
|
||||
declareGauge(libp2p_rendezvous_registered, "number of registered peers")
|
||||
declareGauge(libp2p_rendezvous_namespaces, "number of registered namespaces")
|
||||
|
||||
const
|
||||
RendezVousCodec* = "/rendezvous/1.0.0"
|
||||
# Default minimum TTL per libp2p spec
|
||||
MinimumDuration* = 2.hours
|
||||
# Lower validation limit to accommodate Waku requirements
|
||||
MinimumAcceptedDuration = 1.minutes
|
||||
MaximumDuration = 72.hours
|
||||
MaximumMessageLen = 1 shl 22 # 4MB
|
||||
MinimumNamespaceLen = 1
|
||||
MaximumNamespaceLen = 255
|
||||
RegistrationLimitPerPeer* = 1000
|
||||
DiscoverLimit = 1000'u64
|
||||
SemaphoreDefaultSize = 5
|
||||
|
||||
type
|
||||
RendezVousError* = object of DiscoveryError
|
||||
RegisteredData = object
|
||||
expiration*: Moment
|
||||
peerId*: PeerId
|
||||
data*: Register
|
||||
|
||||
RendezVous* = ref object of LPProtocol
|
||||
# Registered needs to be an offsetted sequence
|
||||
# because we need stable index for the cookies.
|
||||
registered*: OffsettedSeq[RegisteredData]
|
||||
# Namespaces is a table whose key is a salted namespace and
|
||||
# the value is the index sequence corresponding to this
|
||||
# namespace in the offsettedqueue.
|
||||
namespaces*: Table[string, seq[int]]
|
||||
rng: ref HmacDrbgContext
|
||||
salt: string
|
||||
expiredDT: Moment
|
||||
registerDeletionLoop: Future[void]
|
||||
#registerEvent: AsyncEvent # TODO: to raise during the heartbeat
|
||||
# + make the heartbeat sleep duration "smarter"
|
||||
sema: AsyncSemaphore
|
||||
peers: seq[PeerId]
|
||||
cookiesSaved*: Table[PeerId, Table[string, seq[byte]]]
|
||||
switch: Switch
|
||||
minDuration: Duration
|
||||
maxDuration: Duration
|
||||
minTTL: uint64
|
||||
maxTTL: uint64
|
||||
|
||||
proc checkPeerRecord(spr: seq[byte], peerId: PeerId): Result[void, string] =
|
||||
if spr.len == 0:
|
||||
return err("Empty peer record")
|
||||
let signedEnv = ?SignedPeerRecord.decode(spr).mapErr(x => $x)
|
||||
if signedEnv.data.peerId != peerId:
|
||||
return err("Bad Peer ID")
|
||||
return ok()
|
||||
|
||||
proc sendRegisterResponse(
|
||||
conn: Connection, ttl: uint64
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: Ok, ttl: Opt.some(ttl))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendRegisterResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(RegisterResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponse(
|
||||
conn: Connection, s: seq[Register], cookie: Cookie
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(
|
||||
DiscoverResponse(
|
||||
status: Ok, registrations: s, cookie: Opt.some(cookie.encode().buffer)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc sendDiscoverResponseError(
|
||||
conn: Connection, status: ResponseStatus, text: string = ""
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
let msg = encode(
|
||||
Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(DiscoverResponse(status: status, text: Opt.some(text))),
|
||||
)
|
||||
)
|
||||
await conn.writeLp(msg.buffer)
|
||||
|
||||
proc countRegister(rdv: RendezVous, peerId: PeerId): int =
|
||||
for data in rdv.registered:
|
||||
if data.peerId == peerId:
|
||||
result.inc()
|
||||
|
||||
proc save(
|
||||
rdv: RendezVous, ns: string, peerId: PeerId, r: Register, update: bool = true
|
||||
) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
discard rdv.namespaces.hasKeyOrPut(nsSalted, newSeq[int]())
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == peerId:
|
||||
if update == false:
|
||||
return
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
rdv.registered.add(
|
||||
RegisteredData(
|
||||
peerId: peerId,
|
||||
expiration: Moment.now() + r.ttl.get(rdv.minTTL).int64.seconds,
|
||||
data: r,
|
||||
)
|
||||
)
|
||||
rdv.namespaces[nsSalted].add(rdv.registered.high)
|
||||
# rdv.registerEvent.fire()
|
||||
except KeyError as e:
|
||||
doAssert false, "Should have key: " & e.msg
|
||||
|
||||
proc register(rdv: RendezVous, conn: Connection, r: Register): Future[void] =
|
||||
trace "Received Register", peerId = conn.peerId, ns = r.ns
|
||||
libp2p_rendezvous_register.inc()
|
||||
if r.ns.len < MinimumNamespaceLen or r.ns.len > MaximumNamespaceLen:
|
||||
return conn.sendRegisterResponseError(InvalidNamespace)
|
||||
let ttl = r.ttl.get(rdv.minTTL)
|
||||
if ttl < rdv.minTTL or ttl > rdv.maxTTL:
|
||||
return conn.sendRegisterResponseError(InvalidTTL)
|
||||
let pr = checkPeerRecord(r.signedPeerRecord, conn.peerId)
|
||||
if pr.isErr():
|
||||
return conn.sendRegisterResponseError(InvalidSignedPeerRecord, pr.error())
|
||||
if rdv.countRegister(conn.peerId) >= RegistrationLimitPerPeer:
|
||||
return conn.sendRegisterResponseError(NotAuthorized, "Registration limit reached")
|
||||
rdv.save(r.ns, conn.peerId, r)
|
||||
libp2p_rendezvous_registered.inc()
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
conn.sendRegisterResponse(ttl)
|
||||
|
||||
proc unregister(rdv: RendezVous, conn: Connection, u: Unregister) =
|
||||
trace "Received Unregister", peerId = conn.peerId, ns = u.ns
|
||||
let nsSalted = u.ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == conn.peerId:
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
libp2p_rendezvous_registered.dec()
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc discover(
|
||||
rdv: RendezVous, conn: Connection, d: Discover
|
||||
) {.async: (raises: [CancelledError, LPStreamError]).} =
|
||||
trace "Received Discover", peerId = conn.peerId, ns = d.ns
|
||||
libp2p_rendezvous_discover.inc()
|
||||
if d.ns.isSome() and d.ns.get().len > MaximumNamespaceLen:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
var limit = min(DiscoverLimit, d.limit.get(DiscoverLimit))
|
||||
var cookie =
|
||||
if d.cookie.isSome():
|
||||
try:
|
||||
Cookie.decode(d.cookie.tryGet()).tryGet()
|
||||
except CatchableError:
|
||||
await conn.sendDiscoverResponseError(InvalidCookie)
|
||||
return
|
||||
else:
|
||||
# Start from the current lowest index (inclusive)
|
||||
Cookie(offset: rdv.registered.low().uint64)
|
||||
if d.ns.isSome() and cookie.ns.isSome() and cookie.ns.get() != d.ns.get():
|
||||
# Namespace changed: start from the beginning of that namespace
|
||||
cookie = Cookie(offset: rdv.registered.low().uint64)
|
||||
elif cookie.offset < rdv.registered.low().uint64:
|
||||
# Cookie behind available range: reset to current low
|
||||
cookie.offset = rdv.registered.low().uint64
|
||||
elif cookie.offset > (rdv.registered.high() + 1).uint64:
|
||||
# Cookie ahead of available range: reset to one past current high (empty page)
|
||||
cookie.offset = (rdv.registered.high() + 1).uint64
|
||||
let namespaces =
|
||||
if d.ns.isSome():
|
||||
try:
|
||||
rdv.namespaces[d.ns.get() & rdv.salt]
|
||||
except KeyError:
|
||||
await conn.sendDiscoverResponseError(InvalidNamespace)
|
||||
return
|
||||
else:
|
||||
toSeq(max(cookie.offset.int, rdv.registered.offset) .. rdv.registered.high())
|
||||
if namespaces.len() == 0:
|
||||
await conn.sendDiscoverResponse(@[], Cookie())
|
||||
return
|
||||
var nextOffset = cookie.offset
|
||||
let n = Moment.now()
|
||||
var s = collect(newSeq()):
|
||||
for index in namespaces:
|
||||
var reg = rdv.registered[index]
|
||||
if limit == 0:
|
||||
break
|
||||
if reg.expiration < n or index.uint64 < cookie.offset:
|
||||
continue
|
||||
limit.dec()
|
||||
nextOffset = index.uint64 + 1
|
||||
reg.data.ttl = Opt.some((reg.expiration - Moment.now()).seconds.uint64)
|
||||
reg.data
|
||||
rdv.rng.shuffle(s)
|
||||
await conn.sendDiscoverResponse(s, Cookie(offset: nextOffset, ns: d.ns))
|
||||
|
||||
proc advertisePeer(
|
||||
rdv: RendezVous, peer: PeerId, msg: seq[byte]
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
proc advertiseWrap() {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg)
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msgRecv = Message.decode(buf).tryGet()
|
||||
if msgRecv.msgType != MessageType.RegisterResponse:
|
||||
trace "Unexpected register response", peer, msgType = msgRecv.msgType
|
||||
elif msgRecv.registerResponse.tryGet().status != ResponseStatus.Ok:
|
||||
trace "Refuse to register", peer, response = msgRecv.registerResponse
|
||||
else:
|
||||
trace "Successfully registered", peer, response = msgRecv.registerResponse
|
||||
except CatchableError as exc:
|
||||
trace "exception in the advertise", description = exc.msg
|
||||
finally:
|
||||
rdv.sema.release()
|
||||
|
||||
await rdv.sema.acquire()
|
||||
await advertiseWrap()
|
||||
|
||||
proc advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration, peers: seq[PeerId]
|
||||
) {.async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
if ttl < rdv.minDuration or ttl > rdv.maxDuration:
|
||||
raise newException(AdvertiseError, "Invalid time to live: " & $ttl)
|
||||
|
||||
let sprBuff = rdv.switch.peerInfo.signedPeerRecord.encode().valueOr:
|
||||
raise newException(AdvertiseError, "Wrong Signed Peer Record")
|
||||
|
||||
let
|
||||
r = Register(ns: ns, signedPeerRecord: sprBuff, ttl: Opt.some(ttl.seconds.uint64))
|
||||
msg = encode(Message(msgType: MessageType.Register, register: Opt.some(r)))
|
||||
|
||||
rdv.save(ns, rdv.switch.peerInfo.peerId, r)
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peers:
|
||||
trace "Send Advertise", peerId = peer, ns
|
||||
rdv.advertisePeer(peer, msg.buffer).withTimeout(5.seconds)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
method advertise*(
|
||||
rdv: RendezVous, ns: string, ttl: Duration = rdv.minDuration
|
||||
) {.base, async: (raises: [CancelledError, AdvertiseError]).} =
|
||||
await rdv.advertise(ns, ttl, rdv.peers)
|
||||
|
||||
proc requestLocally*(rdv: RendezVous, ns: string): seq[PeerRecord] =
|
||||
let
|
||||
nsSalted = ns & rdv.salt
|
||||
n = Moment.now()
|
||||
try:
|
||||
collect(newSeq()):
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].expiration > n:
|
||||
let res = SignedPeerRecord.decode(rdv.registered[index].data.signedPeerRecord).valueOr:
|
||||
continue
|
||||
res.data
|
||||
except KeyError as exc:
|
||||
@[]
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int, peers: seq[PeerId]
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
var
|
||||
s: Table[PeerId, (PeerRecord, Register)]
|
||||
limit: uint64
|
||||
d = Discover(ns: ns)
|
||||
|
||||
if l <= 0 or l > DiscoverLimit.int:
|
||||
raise newException(AdvertiseError, "Invalid limit")
|
||||
if ns.isSome() and ns.get().len > MaximumNamespaceLen:
|
||||
raise newException(AdvertiseError, "Invalid namespace")
|
||||
|
||||
limit = l.uint64
|
||||
proc requestPeer(
|
||||
peer: PeerId
|
||||
) {.async: (raises: [CancelledError, DialFailedError, LPStreamError]).} =
|
||||
let conn = await rdv.switch.dial(peer, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
d.limit = Opt.some(limit)
|
||||
d.cookie =
|
||||
if ns.isSome():
|
||||
try:
|
||||
Opt.some(rdv.cookiesSaved[peer][ns.get()])
|
||||
except KeyError, CatchableError:
|
||||
Opt.none(seq[byte])
|
||||
else:
|
||||
Opt.none(seq[byte])
|
||||
await conn.writeLp(
|
||||
encode(Message(msgType: MessageType.Discover, discover: Opt.some(d))).buffer
|
||||
)
|
||||
let
|
||||
buf = await conn.readLp(MaximumMessageLen)
|
||||
msgRcv = Message.decode(buf).valueOr:
|
||||
debug "Message undecodable"
|
||||
return
|
||||
if msgRcv.msgType != MessageType.DiscoverResponse:
|
||||
debug "Unexpected discover response", msgType = msgRcv.msgType
|
||||
return
|
||||
let resp = msgRcv.discoverResponse.valueOr:
|
||||
debug "Discover response is empty"
|
||||
return
|
||||
if resp.status != ResponseStatus.Ok:
|
||||
trace "Cannot discover", ns, status = resp.status, text = resp.text
|
||||
return
|
||||
resp.cookie.withValue(cookie):
|
||||
if ns.isSome:
|
||||
let namespace = ns.get()
|
||||
if cookie.len() < 1000 and
|
||||
rdv.cookiesSaved.hasKeyOrPut(peer, {namespace: cookie}.toTable()):
|
||||
try:
|
||||
rdv.cookiesSaved[peer][namespace] = cookie
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKeyOrPut"
|
||||
for r in resp.registrations:
|
||||
if limit == 0:
|
||||
return
|
||||
let ttl = r.ttl.get(rdv.maxTTL + 1)
|
||||
if ttl > rdv.maxTTL:
|
||||
continue
|
||||
let
|
||||
spr = SignedPeerRecord.decode(r.signedPeerRecord).valueOr:
|
||||
continue
|
||||
pr = spr.data
|
||||
if s.hasKey(pr.peerId):
|
||||
let (prSaved, rSaved) =
|
||||
try:
|
||||
s[pr.peerId]
|
||||
except KeyError:
|
||||
raiseAssert "checked with hasKey"
|
||||
if (prSaved.seqNo == pr.seqNo and rSaved.ttl.get(rdv.maxTTL) < ttl) or
|
||||
prSaved.seqNo < pr.seqNo:
|
||||
s[pr.peerId] = (pr, r)
|
||||
else:
|
||||
s[pr.peerId] = (pr, r)
|
||||
limit.dec()
|
||||
if ns.isSome():
|
||||
for (_, r) in s.values():
|
||||
rdv.save(ns.get(), peer, r, false)
|
||||
|
||||
for peer in peers:
|
||||
if limit == 0:
|
||||
break
|
||||
if RendezVousCodec notin rdv.switch.peerStore[ProtoBook][peer]:
|
||||
continue
|
||||
try:
|
||||
trace "Send Request", peerId = peer, ns
|
||||
await peer.requestPeer()
|
||||
except CancelledError as e:
|
||||
raise e
|
||||
except DialFailedError as e:
|
||||
trace "failed to dial a peer", description = e.msg
|
||||
except LPStreamError as e:
|
||||
trace "failed to communicate with a peer", description = e.msg
|
||||
return toSeq(s.values()).mapIt(it[0])
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, ns: Opt[string], l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(ns, l, rdv.peers)
|
||||
|
||||
proc request*(
|
||||
rdv: RendezVous, l: int = DiscoverLimit.int
|
||||
): Future[seq[PeerRecord]] {.async: (raises: [DiscoveryError, CancelledError]).} =
|
||||
await rdv.request(Opt.none(string), l, rdv.peers)
|
||||
|
||||
proc unsubscribeLocally*(rdv: RendezVous, ns: string) =
|
||||
let nsSalted = ns & rdv.salt
|
||||
try:
|
||||
for index in rdv.namespaces[nsSalted]:
|
||||
if rdv.registered[index].peerId == rdv.switch.peerInfo.peerId:
|
||||
rdv.registered[index].expiration = rdv.expiredDT
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string, peerIds: seq[PeerId]
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
if ns.len < MinimumNamespaceLen or ns.len > MaximumNamespaceLen:
|
||||
raise newException(RendezVousError, "Invalid namespace")
|
||||
|
||||
let msg = encode(
|
||||
Message(msgType: MessageType.Unregister, unregister: Opt.some(Unregister(ns: ns)))
|
||||
)
|
||||
|
||||
proc unsubscribePeer(peerId: PeerId) {.async: (raises: []).} =
|
||||
try:
|
||||
let conn = await rdv.switch.dial(peerId, RendezVousCodec)
|
||||
defer:
|
||||
await conn.close()
|
||||
await conn.writeLp(msg.buffer)
|
||||
except CatchableError as exc:
|
||||
trace "exception while unsubscribing", description = exc.msg
|
||||
|
||||
let futs = collect(newSeq()):
|
||||
for peer in peerIds:
|
||||
unsubscribePeer(peer)
|
||||
|
||||
await allFutures(futs)
|
||||
|
||||
proc unsubscribe*(
|
||||
rdv: RendezVous, ns: string
|
||||
) {.async: (raises: [RendezVousError, CancelledError]).} =
|
||||
rdv.unsubscribeLocally(ns)
|
||||
|
||||
await rdv.unsubscribe(ns, rdv.peers)
|
||||
|
||||
proc setup*(rdv: RendezVous, switch: Switch) =
|
||||
rdv.switch = switch
|
||||
proc handlePeer(
|
||||
peerId: PeerId, event: PeerEvent
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
if event.kind == PeerEventKind.Joined:
|
||||
rdv.peers.add(peerId)
|
||||
elif event.kind == PeerEventKind.Left:
|
||||
rdv.peers.keepItIf(it != peerId)
|
||||
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Joined)
|
||||
rdv.switch.addPeerEventHandler(handlePeer, Left)
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
if minDuration < MinimumAcceptedDuration:
|
||||
raise newException(RendezVousError, "TTL too short: 1 minute minimum")
|
||||
|
||||
if maxDuration > MaximumDuration:
|
||||
raise newException(RendezVousError, "TTL too long: 72 hours maximum")
|
||||
|
||||
if minDuration >= maxDuration:
|
||||
raise newException(RendezVousError, "Minimum TTL longer than maximum")
|
||||
|
||||
let
|
||||
minTTL = minDuration.seconds.uint64
|
||||
maxTTL = maxDuration.seconds.uint64
|
||||
|
||||
let rdv = T(
|
||||
rng: rng,
|
||||
salt: string.fromBytes(generateBytes(rng[], 8)),
|
||||
registered: initOffsettedSeq[RegisteredData](),
|
||||
expiredDT: Moment.now() - 1.days,
|
||||
#registerEvent: newAsyncEvent(),
|
||||
sema: newAsyncSemaphore(SemaphoreDefaultSize),
|
||||
minDuration: minDuration,
|
||||
maxDuration: maxDuration,
|
||||
minTTL: minTTL,
|
||||
maxTTL: maxTTL,
|
||||
)
|
||||
logScope:
|
||||
topics = "libp2p discovery rendezvous"
|
||||
proc handleStream(
|
||||
conn: Connection, proto: string
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
let
|
||||
buf = await conn.readLp(4096)
|
||||
msg = Message.decode(buf).tryGet()
|
||||
case msg.msgType
|
||||
of MessageType.Register:
|
||||
await rdv.register(conn, msg.register.tryGet())
|
||||
of MessageType.RegisterResponse:
|
||||
trace "Got an unexpected Register Response", response = msg.registerResponse
|
||||
of MessageType.Unregister:
|
||||
rdv.unregister(conn, msg.unregister.tryGet())
|
||||
of MessageType.Discover:
|
||||
await rdv.discover(conn, msg.discover.tryGet())
|
||||
of MessageType.DiscoverResponse:
|
||||
trace "Got an unexpected Discover Response", response = msg.discoverResponse
|
||||
except CancelledError as exc:
|
||||
trace "cancelled rendezvous handler"
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "exception in rendezvous handler", description = exc.msg
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
rdv.handler = handleStream
|
||||
rdv.codec = RendezVousCodec
|
||||
return rdv
|
||||
|
||||
proc new*(
|
||||
T: typedesc[RendezVous],
|
||||
switch: Switch,
|
||||
rng: ref HmacDrbgContext = newRng(),
|
||||
minDuration = MinimumDuration,
|
||||
maxDuration = MaximumDuration,
|
||||
): T {.raises: [RendezVousError].} =
|
||||
let rdv = T.new(rng, minDuration, maxDuration)
|
||||
rdv.setup(switch)
|
||||
return rdv
|
||||
|
||||
proc deletesRegister*(
|
||||
rdv: RendezVous, interval = 1.minutes
|
||||
) {.async: (raises: [CancelledError]).} =
|
||||
heartbeat "Register timeout", interval:
|
||||
let n = Moment.now()
|
||||
var total = 0
|
||||
rdv.registered.flushIfIt(it.expiration < n)
|
||||
for data in rdv.namespaces.mvalues():
|
||||
data.keepItIf(it >= rdv.registered.offset)
|
||||
total += data.len
|
||||
libp2p_rendezvous_registered.set(int64(total))
|
||||
libp2p_rendezvous_namespaces.set(int64(rdv.namespaces.len))
|
||||
|
||||
method start*(
|
||||
rdv: RendezVous
|
||||
): Future[void] {.async: (raises: [CancelledError], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if not rdv.registerDeletionLoop.isNil:
|
||||
warn "Starting rendezvous twice"
|
||||
return fut
|
||||
rdv.registerDeletionLoop = rdv.deletesRegister()
|
||||
rdv.started = true
|
||||
fut
|
||||
|
||||
method stop*(rdv: RendezVous): Future[void] {.async: (raises: [], raw: true).} =
|
||||
let fut = newFuture[void]()
|
||||
fut.complete()
|
||||
if rdv.registerDeletionLoop.isNil:
|
||||
warn "Stopping rendezvous without starting it"
|
||||
return fut
|
||||
rdv.started = false
|
||||
rdv.registerDeletionLoop.cancelSoon()
|
||||
rdv.registerDeletionLoop = nil
|
||||
fut
|
||||
@@ -195,7 +195,7 @@ type CertGenerator =
|
||||
|
||||
type QuicTransport* = ref object of Transport
|
||||
listener: Listener
|
||||
client: QuicClient
|
||||
client: Opt[QuicClient]
|
||||
privateKey: PrivateKey
|
||||
connections: seq[P2PConnection]
|
||||
rng: ref HmacDrbgContext
|
||||
@@ -248,27 +248,33 @@ method handles*(transport: QuicTransport, address: MultiAddress): bool {.raises:
|
||||
return false
|
||||
QUIC_V1.match(address)
|
||||
|
||||
method start*(
|
||||
self: QuicTransport, addrs: seq[MultiAddress]
|
||||
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
|
||||
doAssert self.listener.isNil, "start() already called"
|
||||
#TODO handle multiple addr
|
||||
|
||||
proc makeConfig(self: QuicTransport): TLSConfig =
|
||||
let pubkey = self.privateKey.getPublicKey().valueOr:
|
||||
doAssert false, "could not obtain public key"
|
||||
return
|
||||
|
||||
try:
|
||||
if self.rng.isNil:
|
||||
self.rng = newRng()
|
||||
let cert = self.certGenerator(KeyPair(seckey: self.privateKey, pubkey: pubkey))
|
||||
let tlsConfig = TLSConfig.init(
|
||||
cert.certificate, cert.privateKey, @[alpn], Opt.some(makeCertificateVerifier())
|
||||
)
|
||||
return tlsConfig
|
||||
|
||||
let cert = self.certGenerator(KeyPair(seckey: self.privateKey, pubkey: pubkey))
|
||||
let tlsConfig = TLSConfig.init(
|
||||
cert.certificate, cert.privateKey, @[alpn], Opt.some(makeCertificateVerifier())
|
||||
)
|
||||
self.client = QuicClient.init(tlsConfig, rng = self.rng)
|
||||
self.listener =
|
||||
QuicServer.init(tlsConfig, rng = self.rng).listen(initTAddress(addrs[0]).tryGet)
|
||||
proc getRng(self: QuicTransport): ref HmacDrbgContext =
|
||||
if self.rng.isNil:
|
||||
self.rng = newRng()
|
||||
|
||||
return self.rng
|
||||
|
||||
method start*(
|
||||
self: QuicTransport, addrs: seq[MultiAddress]
|
||||
) {.async: (raises: [LPError, transport.TransportError, CancelledError]).} =
|
||||
doAssert self.listener.isNil, "start() already called"
|
||||
# TODO(#1663): handle multiple addr
|
||||
|
||||
try:
|
||||
self.listener = QuicServer.init(self.makeConfig(), rng = self.getRng()).listen(
|
||||
initTAddress(addrs[0]).tryGet
|
||||
)
|
||||
await procCall Transport(self).start(addrs)
|
||||
self.addrs[0] =
|
||||
MultiAddress.init(self.listener.localAddress(), IPPROTO_UDP).tryGet() &
|
||||
@@ -289,19 +295,21 @@ method start*(
|
||||
self.running = true
|
||||
|
||||
method stop*(transport: QuicTransport) {.async: (raises: []).} =
|
||||
if transport.running:
|
||||
let conns = transport.connections[0 .. ^1]
|
||||
for c in conns:
|
||||
await c.close()
|
||||
await procCall Transport(transport).stop()
|
||||
let conns = transport.connections[0 .. ^1]
|
||||
for c in conns:
|
||||
await c.close()
|
||||
|
||||
if not transport.listener.isNil:
|
||||
try:
|
||||
await transport.listener.stop()
|
||||
except CatchableError as exc:
|
||||
trace "Error shutting down Quic transport", description = exc.msg
|
||||
transport.listener.destroy()
|
||||
transport.running = false
|
||||
transport.listener = nil
|
||||
|
||||
transport.client = Opt.none(QuicClient)
|
||||
await procCall Transport(transport).stop()
|
||||
|
||||
proc wrapConnection(
|
||||
transport: QuicTransport, connection: QuicConnection
|
||||
): QuicSession {.raises: [TransportOsError, MaError].} =
|
||||
@@ -365,7 +373,11 @@ method dial*(
|
||||
async: (raises: [transport.TransportError, CancelledError])
|
||||
.} =
|
||||
try:
|
||||
let quicConnection = await self.client.dial(initTAddress(address).tryGet)
|
||||
if not self.client.isSome:
|
||||
self.client = Opt.some(QuicClient.init(self.makeConfig(), rng = self.getRng()))
|
||||
|
||||
let client = self.client.get()
|
||||
let quicConnection = await client.dial(initTAddress(address).tryGet)
|
||||
return self.wrapConnection(quicConnection)
|
||||
except CancelledError as e:
|
||||
raise e
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{.used.}
|
||||
|
||||
import testdiscoverymngr, testrendezvous, testrendezvousinterface
|
||||
import
|
||||
testdiscoverymngr, testrendezvous, testrendezvousprotobuf, testrendezvousinterface
|
||||
|
||||
256
tests/discovery/testrendezvousprotobuf.nim
Normal file
256
tests/discovery/testrendezvousprotobuf.nim
Normal file
@@ -0,0 +1,256 @@
|
||||
{.used.}
|
||||
|
||||
# Nim-Libp2p
|
||||
# Copyright (c) 2025 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import ../../libp2p/[protocols/rendezvous/protobuf]
|
||||
import ../../libp2p/protobuf/minprotobuf
|
||||
import ../helpers
|
||||
|
||||
suite "RendezVous Protobuf":
|
||||
teardown:
|
||||
checkTrackers()
|
||||
|
||||
const namespace = "ns"
|
||||
|
||||
test "Cookie roundtrip with namespace":
|
||||
let originalCookie = Cookie(offset: 42'u64, ns: Opt.some(namespace))
|
||||
let decodeResult = Cookie.decode(originalCookie.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedCookie = decodeResult.get()
|
||||
check:
|
||||
decodedCookie.offset == originalCookie.offset
|
||||
decodedCookie.ns.get() == originalCookie.ns.get()
|
||||
decodedCookie.encode().buffer == originalCookie.encode().buffer # roundtrip again
|
||||
|
||||
test "Cookie roundtrip without namespace":
|
||||
let originalCookie = Cookie(offset: 7'u64)
|
||||
let decodeResult = Cookie.decode(originalCookie.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedCookie = decodeResult.get()
|
||||
check:
|
||||
decodedCookie.offset == originalCookie.offset
|
||||
decodedCookie.ns.isNone()
|
||||
decodedCookie.encode().buffer == originalCookie.encode().buffer
|
||||
|
||||
test "Cookie decode fails when missing offset":
|
||||
var emptyCookieBuffer = initProtoBuffer()
|
||||
check Cookie.decode(emptyCookieBuffer.buffer).isNone()
|
||||
|
||||
test "Register roundtrip with ttl":
|
||||
let originalRegister =
|
||||
Register(ns: namespace, signedPeerRecord: @[byte 1, 2, 3], ttl: Opt.some(60'u64))
|
||||
let decodeResult = Register.decode(originalRegister.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedRegister = decodeResult.get()
|
||||
check:
|
||||
decodedRegister.ns == originalRegister.ns
|
||||
decodedRegister.signedPeerRecord == originalRegister.signedPeerRecord
|
||||
decodedRegister.ttl.get() == originalRegister.ttl.get()
|
||||
decodedRegister.encode().buffer == originalRegister.encode().buffer
|
||||
|
||||
test "Register roundtrip without ttl":
|
||||
let originalRegister = Register(ns: namespace, signedPeerRecord: @[byte 4, 5])
|
||||
let decodeResult = Register.decode(originalRegister.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedRegister = decodeResult.get()
|
||||
check:
|
||||
decodedRegister.ns == originalRegister.ns
|
||||
decodedRegister.signedPeerRecord == originalRegister.signedPeerRecord
|
||||
decodedRegister.ttl.isNone()
|
||||
decodedRegister.encode().buffer == originalRegister.encode().buffer
|
||||
|
||||
test "Register decode fails when missing namespace":
|
||||
var bufferMissingNamespace = initProtoBuffer()
|
||||
bufferMissingNamespace.write(2, @[byte 1])
|
||||
check Register.decode(bufferMissingNamespace.buffer).isNone()
|
||||
|
||||
test "Register decode fails when missing signedPeerRecord":
|
||||
var bufferMissingSignedPeerRecord = initProtoBuffer()
|
||||
bufferMissingSignedPeerRecord.write(1, namespace)
|
||||
check Register.decode(bufferMissingSignedPeerRecord.buffer).isNone()
|
||||
|
||||
test "RegisterResponse roundtrip successful":
|
||||
let originalResponse = RegisterResponse(
|
||||
status: ResponseStatus.Ok, text: Opt.some("ok"), ttl: Opt.some(10'u64)
|
||||
)
|
||||
let decodeResult = RegisterResponse.decode(originalResponse.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedResponse = decodeResult.get()
|
||||
check:
|
||||
decodedResponse.status == originalResponse.status
|
||||
decodedResponse.text.get() == originalResponse.text.get()
|
||||
decodedResponse.ttl.get() == originalResponse.ttl.get()
|
||||
decodedResponse.encode().buffer == originalResponse.encode().buffer
|
||||
|
||||
test "RegisterResponse roundtrip failed":
|
||||
let originalResponse = RegisterResponse(status: ResponseStatus.InvalidNamespace)
|
||||
let decodeResult = RegisterResponse.decode(originalResponse.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedResponse = decodeResult.get()
|
||||
check:
|
||||
decodedResponse.status == originalResponse.status
|
||||
decodedResponse.text.isNone()
|
||||
decodedResponse.ttl.isNone()
|
||||
decodedResponse.encode().buffer == originalResponse.encode().buffer
|
||||
|
||||
test "RegisterResponse decode fails invalid status enum":
|
||||
var bufferInvalidStatusValue = initProtoBuffer()
|
||||
bufferInvalidStatusValue.write(1, uint(999))
|
||||
check RegisterResponse.decode(bufferInvalidStatusValue.buffer).isNone()
|
||||
|
||||
test "RegisterResponse decode fails missing status":
|
||||
var bufferMissingStatusField = initProtoBuffer()
|
||||
bufferMissingStatusField.write(2, "msg")
|
||||
check RegisterResponse.decode(bufferMissingStatusField.buffer).isNone()
|
||||
|
||||
test "Unregister roundtrip":
|
||||
let originalUnregister = Unregister(ns: namespace)
|
||||
let decodeResult = Unregister.decode(originalUnregister.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedUnregister = decodeResult.get()
|
||||
check:
|
||||
decodedUnregister.ns == originalUnregister.ns
|
||||
decodedUnregister.encode().buffer == originalUnregister.encode().buffer
|
||||
|
||||
test "Unregister decode fails when missing namespace":
|
||||
var bufferMissingUnregisterNamespace = initProtoBuffer()
|
||||
check Unregister.decode(bufferMissingUnregisterNamespace.buffer).isNone()
|
||||
|
||||
test "Discover roundtrip with all optional fields":
|
||||
let originalDiscover = Discover(
|
||||
ns: Opt.some(namespace), limit: Opt.some(5'u64), cookie: Opt.some(@[byte 1, 2, 3])
|
||||
)
|
||||
let decodeResult = Discover.decode(originalDiscover.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedDiscover = decodeResult.get()
|
||||
check:
|
||||
decodedDiscover.ns.get() == originalDiscover.ns.get()
|
||||
decodedDiscover.limit.get() == originalDiscover.limit.get()
|
||||
decodedDiscover.cookie.get() == originalDiscover.cookie.get()
|
||||
decodedDiscover.encode().buffer == originalDiscover.encode().buffer
|
||||
|
||||
test "Discover decode empty buffer yields empty object":
|
||||
var emptyDiscoverBuffer: seq[byte] = @[]
|
||||
let decodeResult = Discover.decode(emptyDiscoverBuffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedDiscover = decodeResult.get()
|
||||
check:
|
||||
decodedDiscover.ns.isNone()
|
||||
decodedDiscover.limit.isNone()
|
||||
decodedDiscover.cookie.isNone()
|
||||
|
||||
test "DiscoverResponse roundtrip with registration":
|
||||
let registrationEntry = Register(ns: namespace, signedPeerRecord: @[byte 9])
|
||||
let originalResponse = DiscoverResponse(
|
||||
registrations: @[registrationEntry],
|
||||
cookie: Opt.some(@[byte 0xAA]),
|
||||
status: ResponseStatus.Ok,
|
||||
text: Opt.some("t"),
|
||||
)
|
||||
let decodeResult = DiscoverResponse.decode(originalResponse.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedResponse = decodeResult.get()
|
||||
check:
|
||||
decodedResponse.status == originalResponse.status
|
||||
decodedResponse.registrations.len == 1
|
||||
decodedResponse.registrations[0].ns == registrationEntry.ns
|
||||
decodedResponse.cookie.get() == originalResponse.cookie.get()
|
||||
decodedResponse.text.get() == originalResponse.text.get()
|
||||
decodedResponse.encode().buffer == originalResponse.encode().buffer
|
||||
|
||||
test "DiscoverResponse roundtrip failed":
|
||||
let originalResponse = DiscoverResponse(status: ResponseStatus.InternalError)
|
||||
let decodeResult = DiscoverResponse.decode(originalResponse.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedResponse = decodeResult.get()
|
||||
check:
|
||||
decodedResponse.status == originalResponse.status
|
||||
decodedResponse.registrations.len == 0
|
||||
decodedResponse.cookie.isNone()
|
||||
decodedResponse.text.isNone()
|
||||
decodedResponse.encode().buffer == originalResponse.encode().buffer
|
||||
|
||||
test "DiscoverResponse decode fails with invalid registration":
|
||||
var bufferInvalidRegistration = initProtoBuffer()
|
||||
bufferInvalidRegistration.write(1, @[byte 0x00, 0xFF])
|
||||
bufferInvalidRegistration.write(3, ResponseStatus.Ok.uint)
|
||||
check DiscoverResponse.decode(bufferInvalidRegistration.buffer).isNone()
|
||||
|
||||
test "DiscoverResponse decode fails missing status":
|
||||
var bufferMissingDiscoverResponseStatus = initProtoBuffer()
|
||||
check DiscoverResponse.decode(bufferMissingDiscoverResponseStatus.buffer).isNone()
|
||||
|
||||
test "Message roundtrip Register variant":
|
||||
let registerPayload = Register(ns: namespace, signedPeerRecord: @[byte 1])
|
||||
let originalMessage =
|
||||
Message(msgType: MessageType.Register, register: Opt.some(registerPayload))
|
||||
let decodeResult = Message.decode(originalMessage.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedMessage = decodeResult.get()
|
||||
check:
|
||||
decodedMessage.msgType == originalMessage.msgType
|
||||
decodedMessage.register.get().ns == registerPayload.ns
|
||||
|
||||
test "Message roundtrip RegisterResponse variant":
|
||||
let registerResponsePayload = RegisterResponse(status: ResponseStatus.Ok)
|
||||
let originalMessage = Message(
|
||||
msgType: MessageType.RegisterResponse,
|
||||
registerResponse: Opt.some(registerResponsePayload),
|
||||
)
|
||||
let decodeResult = Message.decode(originalMessage.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedMessage = decodeResult.get()
|
||||
check:
|
||||
decodedMessage.msgType == originalMessage.msgType
|
||||
decodedMessage.registerResponse.get().status == registerResponsePayload.status
|
||||
|
||||
test "Message roundtrip Unregister variant":
|
||||
let unregisterPayload = Unregister(ns: namespace)
|
||||
let originalMessage =
|
||||
Message(msgType: MessageType.Unregister, unregister: Opt.some(unregisterPayload))
|
||||
let decodeResult = Message.decode(originalMessage.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedMessage = decodeResult.get()
|
||||
check:
|
||||
decodedMessage.unregister.get().ns == unregisterPayload.ns
|
||||
|
||||
test "Message roundtrip Discover variant":
|
||||
let discoverPayload = Discover(limit: Opt.some(1'u64))
|
||||
let originalMessage =
|
||||
Message(msgType: MessageType.Discover, discover: Opt.some(discoverPayload))
|
||||
let decodeResult = Message.decode(originalMessage.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedMessage = decodeResult.get()
|
||||
check:
|
||||
decodedMessage.discover.get().limit.get() == discoverPayload.limit.get()
|
||||
|
||||
test "Message roundtrip DiscoverResponse variant":
|
||||
let discoverResponsePayload = DiscoverResponse(status: ResponseStatus.Unavailable)
|
||||
let originalMessage = Message(
|
||||
msgType: MessageType.DiscoverResponse,
|
||||
discoverResponse: Opt.some(discoverResponsePayload),
|
||||
)
|
||||
let decodeResult = Message.decode(originalMessage.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
let decodedMessage = decodeResult.get()
|
||||
check:
|
||||
decodedMessage.discoverResponse.get().status == discoverResponsePayload.status
|
||||
|
||||
test "Message decode header only":
|
||||
var headerOnlyMessage = Message(msgType: MessageType.Register)
|
||||
let decodeResult = Message.decode(headerOnlyMessage.encode().buffer)
|
||||
check decodeResult.isSome()
|
||||
check:
|
||||
decodeResult.get().register.isNone()
|
||||
|
||||
test "Message decode fails invalid msgType":
|
||||
var bufferInvalidMessageType = initProtoBuffer()
|
||||
bufferInvalidMessageType.write(1, uint(999))
|
||||
check Message.decode(bufferInvalidMessageType.buffer).isNone()
|
||||
@@ -311,7 +311,8 @@ suite "GossipSub":
|
||||
check gossipSub.mcache.msgs.len == 0
|
||||
|
||||
asyncTest "rpcHandler - subscription limits":
|
||||
let gossipSub = TestGossipSub.init(newStandardSwitch())
|
||||
let gossipSub =
|
||||
TestGossipSub.init(newStandardSwitch(transport = TransportType.QUIC))
|
||||
gossipSub.topicsHigh = 10
|
||||
|
||||
var tooManyTopics: seq[string]
|
||||
@@ -333,7 +334,8 @@ suite "GossipSub":
|
||||
await conn.close()
|
||||
|
||||
asyncTest "rpcHandler - invalid message bytes":
|
||||
let gossipSub = TestGossipSub.init(newStandardSwitch())
|
||||
let gossipSub =
|
||||
TestGossipSub.init(newStandardSwitch(transport = TransportType.QUIC))
|
||||
|
||||
let peerId = randomPeerId()
|
||||
let peer = gossipSub.getPubSubPeer(peerId)
|
||||
|
||||
@@ -87,7 +87,7 @@ proc setupGossipSubWithPeers*(
|
||||
populateMesh: bool = false,
|
||||
populateFanout: bool = false,
|
||||
): (TestGossipSub, seq[Connection], seq[PubSubPeer]) =
|
||||
let gossipSub = TestGossipSub.init(newStandardSwitch())
|
||||
let gossipSub = TestGossipSub.init(newStandardSwitch(transport = TransportType.QUIC))
|
||||
|
||||
for topic in topics:
|
||||
gossipSub.subscribe(topic, voidTopicHandler)
|
||||
@@ -195,7 +195,9 @@ proc generateNodes*(
|
||||
): seq[PubSub] =
|
||||
for i in 0 ..< num:
|
||||
let switch = newStandardSwitch(
|
||||
secureManagers = secureManagers, sendSignedPeerRecord = sendSignedPeerRecord
|
||||
secureManagers = secureManagers,
|
||||
sendSignedPeerRecord = sendSignedPeerRecord,
|
||||
transport = TransportType.Memory,
|
||||
)
|
||||
let pubsub =
|
||||
if gossip:
|
||||
|
||||
@@ -18,6 +18,7 @@ import
|
||||
upgrademngrs/upgrade,
|
||||
builders,
|
||||
protocols/connectivity/autonatv2/types,
|
||||
protocols/connectivity/autonatv2/server,
|
||||
protocols/connectivity/autonatv2/utils,
|
||||
],
|
||||
./helpers
|
||||
@@ -27,6 +28,10 @@ proc checkEncodeDecode[T](msg: T) =
|
||||
# check msg == DialBack.decode(msg.encode()).get()
|
||||
check msg == T.decode(msg.encode()).get()
|
||||
|
||||
proc newAutonatV2ServerSwitch(config: AutonatV2Config = AutonatV2Config.new()): Switch =
|
||||
var builder = newStandardSwitchBuilder().withAutonatV2(config = config)
|
||||
return builder.build()
|
||||
|
||||
suite "AutonatV2":
|
||||
teardown:
|
||||
checkTrackers()
|
||||
@@ -149,3 +154,15 @@ suite "AutonatV2":
|
||||
AutonatV2Response(
|
||||
reachability: Reachable, dialResp: correctDialResp, addrs: Opt.some(addrs[0])
|
||||
)
|
||||
|
||||
asyncTest "Instantiate server":
|
||||
let serverSwitch = newAutonatV2ServerSwitch()
|
||||
await serverSwitch.start()
|
||||
await serverSwitch.stop()
|
||||
|
||||
asyncTest "Instantiate server with config":
|
||||
let serverSwitch = newAutonatV2ServerSwitch(
|
||||
config = AutonatV2Config.new(allowPrivateAddresses = true)
|
||||
)
|
||||
await serverSwitch.start()
|
||||
await serverSwitch.stop()
|
||||
|
||||
@@ -60,15 +60,19 @@ proc invalidCertGenerator(
|
||||
except ResultError[crypto.CryptoError]:
|
||||
raiseAssert "private key should be set"
|
||||
|
||||
proc createTransport(withInvalidCert: bool = false): Future[QuicTransport] {.async.} =
|
||||
let ma = @[MultiAddress.init("/ip4/127.0.0.1/udp/0/quic-v1").tryGet()]
|
||||
proc createTransport(
|
||||
isServer: bool = false, withInvalidCert: bool = false
|
||||
): Future[QuicTransport] {.async.} =
|
||||
let privateKey = PrivateKey.random(ECDSA, (newRng())[]).tryGet()
|
||||
let trans =
|
||||
if withInvalidCert:
|
||||
QuicTransport.new(Upgrade(), privateKey, invalidCertGenerator)
|
||||
else:
|
||||
QuicTransport.new(Upgrade(), privateKey)
|
||||
await trans.start(ma)
|
||||
|
||||
if isServer: # servers are started because they need to listen
|
||||
let ma = @[MultiAddress.init("/ip4/127.0.0.1/udp/0/quic-v1").tryGet()]
|
||||
await trans.start(ma)
|
||||
|
||||
return trans
|
||||
|
||||
@@ -77,12 +81,46 @@ suite "Quic transport":
|
||||
checkTrackers()
|
||||
|
||||
asyncTest "can handle local address":
|
||||
let trans = await createTransport()
|
||||
check trans.handles(trans.addrs[0])
|
||||
await trans.stop()
|
||||
let server = await createTransport(isServer = true)
|
||||
check server.handles(server.addrs[0])
|
||||
await server.stop()
|
||||
|
||||
asyncTest "handle accept cancellation":
|
||||
let server = await createTransport(isServer = true)
|
||||
|
||||
let acceptFut = server.accept()
|
||||
await acceptFut.cancelAndWait()
|
||||
check acceptFut.cancelled
|
||||
|
||||
await server.stop()
|
||||
|
||||
asyncTest "handle dial cancellation":
|
||||
let server = await createTransport(isServer = true)
|
||||
let client = await createTransport(isServer = false)
|
||||
|
||||
let connFut = client.dial(server.addrs[0])
|
||||
await connFut.cancelAndWait()
|
||||
check connFut.cancelled
|
||||
|
||||
await client.stop()
|
||||
await server.stop()
|
||||
|
||||
asyncTest "stopping transport kills connections":
|
||||
let server = await createTransport(isServer = true)
|
||||
let client = await createTransport(isServer = false)
|
||||
|
||||
let acceptFut = server.accept()
|
||||
let conn = await client.dial(server.addrs[0])
|
||||
let serverConn = await acceptFut
|
||||
|
||||
await client.stop()
|
||||
await server.stop()
|
||||
|
||||
check serverConn.closed()
|
||||
check conn.closed()
|
||||
|
||||
asyncTest "transport e2e":
|
||||
let server = await createTransport()
|
||||
let server = await createTransport(isServer = true)
|
||||
asyncSpawn createServerAcceptConn(server)()
|
||||
defer:
|
||||
await server.stop()
|
||||
@@ -101,7 +139,7 @@ suite "Quic transport":
|
||||
await runClient()
|
||||
|
||||
asyncTest "transport e2e - invalid cert - server":
|
||||
let server = await createTransport(true)
|
||||
let server = await createTransport(isServer = true, withInvalidCert = true)
|
||||
asyncSpawn createServerAcceptConn(server)()
|
||||
defer:
|
||||
await server.stop()
|
||||
@@ -115,22 +153,27 @@ suite "Quic transport":
|
||||
await runClient()
|
||||
|
||||
asyncTest "transport e2e - invalid cert - client":
|
||||
let server = await createTransport()
|
||||
let server = await createTransport(isServer = true)
|
||||
asyncSpawn createServerAcceptConn(server)()
|
||||
defer:
|
||||
await server.stop()
|
||||
|
||||
proc runClient() {.async.} =
|
||||
let client = await createTransport(true)
|
||||
let client = await createTransport(withInvalidCert = true)
|
||||
expect QuicTransportDialError:
|
||||
discard await client.dial("", server.addrs[0])
|
||||
await client.stop()
|
||||
|
||||
await runClient()
|
||||
|
||||
asyncTest "should allow multiple local addresses":
|
||||
# TODO(#1663): handle multiple addr
|
||||
# See test example in commonTransportTest
|
||||
return
|
||||
|
||||
asyncTest "server not accepting":
|
||||
let server = await createTransport()
|
||||
# itentionally not calling createServerAcceptConn as server should not accept
|
||||
let server = await createTransport(isServer = true)
|
||||
# intentionally not calling createServerAcceptConn as server should not accept
|
||||
defer:
|
||||
await server.stop()
|
||||
|
||||
@@ -145,7 +188,7 @@ suite "Quic transport":
|
||||
await runClient()
|
||||
|
||||
asyncTest "closing session should close all streams":
|
||||
let server = await createTransport()
|
||||
let server = await createTransport(isServer = true)
|
||||
# because some clients will not write full message,
|
||||
# it is expected for server to receive eof
|
||||
asyncSpawn createServerAcceptConn(server, true)()
|
||||
@@ -201,7 +244,7 @@ suite "Quic transport":
|
||||
check (await stream.readLp(100)) == fromHex("5678")
|
||||
await client.stop()
|
||||
|
||||
let server = await createTransport()
|
||||
let server = await createTransport(isServer = true)
|
||||
asyncSpawn serverHandler(server)
|
||||
defer:
|
||||
await server.stop()
|
||||
|
||||
Reference in New Issue
Block a user