Mirror of https://github.com/vacp2p/nim-libp2p.git (synced 2026-01-10 13:58:17 -05:00)

Compare commits: dev/etan/e ... anyaddr-re
45 Commits
| Author | SHA1 | Date |
|---|---|---|
| | fdff1ebec5 | |
| | d34575675c | |
| | 03f72a8b5c | |
| | b11e2b0349 | |
| | 2fa2c4425f | |
| | e228981a11 | |
| | 907c41a491 | |
| | 03668a3e90 | |
| | 980950b147 | |
| | eb6a9f2ff1 | |
| | 7810dba9d6 | |
| | 6d327c37c4 | |
| | 21cb13a8ff | |
| | 1b2b009f79 | |
| | 3a659ffddb | |
| | ac994a8f15 | |
| | ee8318ec42 | |
| | 52a8870f78 | |
| | 0911cb20f4 | |
| | d7c0486968 | |
| | 63b6390d1a | |
| | cf7b77bf82 | |
| | 3ca49a2f40 | |
| | 1b91b97499 | |
| | 21cbe3a91a | |
| | 88e233db81 | |
| | 84659af45b | |
| | aef44ed1ce | |
| | 02c96fc003 | |
| | c4da9be32c | |
| | 2b5319622c | |
| | 5cbb473d1b | |
| | b30b2656d5 | |
| | 89cad5a3ba | |
| | 09b3e11956 | |
| | 03f67d3db5 | |
| | bb97a9de79 | |
| | 1a707e1264 | |
| | 458b0885dd | |
| | a2027003cd | |
| | c5db35d9b0 | |
| | d1e51beb7f | |
| | 275d649287 | |
| | 467b5b4f0c | |
| | fdf53d18cd | |
.github/workflows/bumper.yml (vendored, 3 changed lines)
@@ -2,8 +2,7 @@ name: Bumper
on:
  push:
    branches:
      - unstable
      - bumper
      - master
  workflow_dispatch:

jobs:
@@ -71,7 +71,6 @@ List of packages modules implemented in nim-libp2p:
| [libp2p-ws](libp2p/transports/wstransport.nim) | WebSocket & WebSocket Secure transport |
| [libp2p-tor](libp2p/transports/tortransport.nim) | Tor Transport |
| **Secure Channels** | |
| [libp2p-secio](libp2p/protocols/secure/secio.nim) | Secio secure channel |
| [libp2p-noise](libp2p/protocols/secure/noise.nim) | [Noise](https://docs.libp2p.io/concepts/secure-comm/noise/) secure channel |
| [libp2p-plaintext](libp2p/protocols/secure/plaintext.nim) | Plain Text for development purposes |
| **Stream Multiplexers** | |
@@ -7,7 +7,6 @@ if dirExists("nimbledeps/pkgs2"):
switch("warning", "CaseTransition:off")
switch("warning", "ObservableStores:off")
switch("warning", "LockLevel:off")
--define:chronosStrictException
--styleCheck:usages
switch("warningAsError", "UseBase:on")
--styleCheck:error
di/di.nim (new file, 24 lines)
@@ -0,0 +1,24 @@
import typetraits
import tables

type
  BindingKey = tuple[typeName: string, qualifier: string]

  Container* = ref object
    bindings*: Table[BindingKey, proc(): RootRef {.gcsafe, raises: [].}]

  BindingNotFoundError* = object of CatchableError

proc register*[T](c: Container, implementation: proc(): T {.gcsafe, raises: [].}, qualifier: string = "") =
  let key: BindingKey = (name(T), qualifier)
  proc p(): RootRef =
    let o: RootRef = implementation()
    return o
  c.bindings[key] = p

proc resolve*[T](c: Container, qualifier: string = ""): T {.raises: [BindingNotFoundError]} =
  let key: BindingKey = (name(T), qualifier)
  try:
    return cast[T](c.bindings[key]())
  except KeyError:
    raise newException(BindingNotFoundError, "Type not bound: " & name(T))
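A minimal usage sketch of the new container, assuming a caller-defined MyDep type that is not part of the diff:

import di/di

type MyDep = ref object of RootObj
  value: int

let container = Container()
# register a factory under the default qualifier
register[MyDep](container, proc(): MyDep {.gcsafe, raises: [].} = MyDep(value: 42))
# resolve invokes the registered factory and casts the result back to MyDep
let dep = resolve[MyDep](container)
doAssert dep.value == 42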
@@ -20,7 +20,7 @@ proc new(T: typedesc[TestProto]): T =
# We must close the connections ourselves when we're done with it
await conn.close()

return T(codecs: @[TestCodec], handler: handle)
return T.new(codecs = @[TestCodec], handler = handle)

##
# Helper to create a switch/node
@@ -19,7 +19,8 @@ runnableExamples:
{.push raises: [].}

import
options, tables, chronos, chronicles, sequtils,
options, tables, chronos, chronicles, sequtils
import
switch, peerid, peerinfo, stream/connection, multiaddress,
crypto/crypto, transports/[transport, tcptransport],
muxers/[muxer, mplex/mplex, yamux/yamux],

@@ -28,6 +29,8 @@ import
connmanager, upgrademngrs/muxedupgrade, observedaddrmanager,
nameresolving/nameresolver,
errors, utility
import services/wildcardresolverservice
import ../di/di

export
switch, peerid, peerinfo, connection, multiaddress, crypto, errors

@@ -36,8 +39,7 @@ type
TransportProvider* {.public.} = proc(upgr: Upgrade): Transport {.gcsafe, raises: [].}

SecureProtocol* {.pure.} = enum
Noise,
Secio {.deprecated.}
Noise

SwitchBuilder* = ref object
privKey: Option[PrivateKey]

@@ -60,6 +62,8 @@ type
rdv: RendezVous
services: seq[Service]
observedAddrManager: ObservedAddrManager
enableWildcardResolver: bool
container*: Container
proc new*(T: type[SwitchBuilder]): T {.public.} =
## Creates a SwitchBuilder

@@ -68,7 +72,7 @@ proc new*(T: type[SwitchBuilder]): T {.public.} =
.init("/ip4/127.0.0.1/tcp/0")
.expect("Should initialize to default")

SwitchBuilder(
let sb = SwitchBuilder(
privKey: none(PrivateKey),
addresses: @[address],
secureManagers: @[],

@@ -77,7 +81,12 @@ proc new*(T: type[SwitchBuilder]): T {.public.} =
maxOut: -1,
maxConnsPerPeer: MaxConnectionsPerPeer,
protoVersion: ProtoVersion,
agentVersion: AgentVersion)
agentVersion: AgentVersion,
container: Container())

register[NetworkInterfaceProvider](sb.container, networkInterfaceProvider)

sb
proc withPrivateKey*(b: SwitchBuilder, privateKey: PrivateKey): SwitchBuilder {.public.} =
## Set the private key of the switch. Will be used to

@@ -86,20 +95,19 @@ proc withPrivateKey*(b: SwitchBuilder, privateKey: PrivateKey): SwitchBuilder {.
b.privKey = some(privateKey)
b

proc withAddress*(b: SwitchBuilder, address: MultiAddress): SwitchBuilder {.public.} =
## | Set the listening address of the switch
## | Calling it multiple time will override the value

b.addresses = @[address]
b

proc withAddresses*(b: SwitchBuilder, addresses: seq[MultiAddress]): SwitchBuilder {.public.} =
proc withAddresses*(b: SwitchBuilder, addresses: seq[MultiAddress], enableWildcardResolver: bool = true): SwitchBuilder {.public.} =
## | Set the listening addresses of the switch
## | Calling it multiple time will override the value

b.addresses = addresses
b.enableWildcardResolver = enableWildcardResolver
b

proc withAddress*(b: SwitchBuilder, address: MultiAddress, enableWildcardResolver: bool = true): SwitchBuilder {.public.} =
## | Set the listening address of the switch
## | Calling it multiple time will override the value
b.withAddresses(@[address], enableWildcardResolver)

proc withSignedPeerRecord*(b: SwitchBuilder, sendIt = true): SwitchBuilder {.public.} =
b.sendSignedPeerRecord = sendIt
b
@@ -210,6 +218,10 @@ proc withObservedAddrManager*(b: SwitchBuilder, observedAddrManager: ObservedAdd
b.observedAddrManager = observedAddrManager
b

proc withBinding*[T](b: SwitchBuilder, binding: proc(): T {.gcsafe, raises: [].}): SwitchBuilder =
register[T](b.container, binding)
b
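As a sketch, the new withBinding hook lets an application override the NetworkInterfaceProvider factory that build() resolves; the factory body and the chained with* setters below are illustrative placeholders, not part of the diff:

proc myProvider(): NetworkInterfaceProvider {.gcsafe, raises: [].} =
  # placeholder: return whatever provider the application wants build() to resolve;
  # here it simply falls back to the default factory registered in SwitchBuilder.new()
  networkInterfaceProvider()

let switch = SwitchBuilder
  .new()
  .withRng(newRng())
  .withAddress(MultiAddress.init("/ip4/0.0.0.0/tcp/0").expect("valid address"))
  .withTcpTransport()
  .withYamux()
  .withNoise()
  .withBinding(myProvider)   # overwrites the registration made in SwitchBuilder.new()
  .build()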
proc build*(b: SwitchBuilder): Switch
{.raises: [LPError], public.} =

@@ -262,6 +274,12 @@ proc build*(b: SwitchBuilder): Switch
else:
PeerStore.new(identify)

try:
let networkInterfaceProvider = resolve[NetworkInterfaceProvider](b.container)
b.services.add(WildcardAddressResolverService.new(networkInterfaceProvider))
except BindingNotFoundError as e:
raise newException(LPError, "Cannot resolve NetworkInterfaceProvider", e)

let switch = newSwitch(
peerInfo = peerInfo,
transports = transports,
@@ -310,15 +328,10 @@ proc newStandardSwitch*(
peerStoreCapacity = 1000
): Switch {.raises: [LPError], public.} =
## Helper for common switch configurations.
{.push warning[Deprecated]:off.}
if SecureProtocol.Secio in secureManagers:
quit("Secio is deprecated!") # use of secio is unsafe
{.pop.}

let addrs = when addrs is MultiAddress: @[addrs] else: addrs
var b = SwitchBuilder
.new()
.withAddresses(addrs)
.withAddresses(addrs, true)
.withRng(rng)
.withSignedPeerRecord(sendSignedPeerRecord)
.withMaxConnections(maxConnections)
@@ -924,59 +924,6 @@ proc selectBest*(order: int, p1, p2: string): string =
if felement == selement:
return felement

proc createProposal*(nonce, pubkey: openArray[byte],
exchanges, ciphers, hashes: string): seq[byte] =
## Create SecIO proposal message using random ``nonce``, local public key
## ``pubkey``, comma-delimieted list of supported exchange schemes
## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and
## comma-delimeted list of supported hashes ``hashes``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(1, nonce)
msg.write(2, pubkey)
msg.write(3, exchanges)
msg.write(4, ciphers)
msg.write(5, hashes)
msg.finish()
msg.buffer

proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
exchanges, ciphers, hashes: var string): bool =
## Parse incoming proposal message and decode remote random nonce ``nonce``,
## remote public key ``pubkey``, comma-delimieted list of supported exchange
## schemes ``exchanges``, comma-delimeted list of supported ciphers
## ``ciphers`` and comma-delimeted list of supported hashes ``hashes``.
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
let r1 = pb.getField(1, nonce)
let r2 = pb.getField(2, pubkey)
let r3 = pb.getField(3, exchanges)
let r4 = pb.getField(4, ciphers)
let r5 = pb.getField(5, hashes)

r1.get(false) and r2.get(false) and r3.get(false) and
r4.get(false) and r5.get(false)

proc createExchange*(epubkey, signature: openArray[byte]): seq[byte] =
## Create SecIO exchange message using ephemeral public key ``epubkey`` and
## signature of proposal blocks ``signature``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(1, epubkey)
msg.write(2, signature)
msg.finish()
msg.buffer

proc decodeExchange*(message: seq[byte],
pubkey, signature: var seq[byte]): bool =
## Parse incoming exchange message and decode remote ephemeral public key
## ``pubkey`` and signature ``signature``.
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
let r1 = pb.getField(1, pubkey)
let r2 = pb.getField(2, signature)
r1.get(false) and r2.get(false)

## Serialization/Deserialization helpers

proc write*(vb: var VBuffer, pubkey: PublicKey) {.
@@ -81,16 +81,18 @@ proc dialAndUpgrade(
if dialed.dir != dir:
dialed.dir = dir
await transport.upgrade(dialed, peerId)
except CancelledError as exc:
await dialed.close()
raise exc
except CatchableError as exc:
# If we failed to establish the connection through one transport,
# we won't succeeded through another - no use in trying again
await dialed.close()
debug "Upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
if exc isnot CancelledError:
if dialed.dir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()
debug "Connection upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
if dialed.dir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()

# Try other address
return nil
@@ -44,12 +44,3 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
# We still don't abort but warn
debug "A future has failed, enable trace logging for details", error=exc.name
trace "Exception details", msg=exc.msg

template tryAndWarn*(message: static[string]; body: untyped): untyped =
try:
body
except CancelledError as exc:
raise exc
except CatchableError as exc:
debug "An exception has ocurred, enable trace logging for details", name = exc.name, msg = message
trace "Exception details", exc = exc.msg
@@ -13,12 +13,12 @@
{.push public.}

import pkg/chronos, chronicles
import std/[nativesockets, hashes]
import tables, strutils, sets, stew/shims/net
import std/[nativesockets, net, hashes]
import tables, strutils, sets
import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
protobuf/minprotobuf, errors, utility
import stew/[base58, base32, endians2, results]
export results, minprotobuf, vbuffer, utility
export results, minprotobuf, vbuffer, utility, multicodec

logScope:
topics = "libp2p multiaddress"
@@ -246,10 +246,14 @@ proc addHandler*(m: MultistreamSelect,
matcher: Matcher = nil) =
addHandler(m, @[codec], protocol, matcher)

proc addHandler*(m: MultistreamSelect,
codec: string,
handler: LPProtoHandler,
matcher: Matcher = nil) =
proc addHandler*[E](
m: MultistreamSelect,
codec: string,
handler: LPProtoHandler |
proc (
conn: Connection,
proto: string): InternalRaisesFuture[void, E],
matcher: Matcher = nil) =
## helper to allow registering pure handlers
trace "registering proto handler", proto = codec
let protocol = new LPProtocol

@@ -261,26 +265,28 @@ proc addHandler*(m: MultistreamSelect,
match: matcher))

proc start*(m: MultistreamSelect) {.async: (raises: [CancelledError]).} =
let
handlers = m.handlers
futs = handlers.mapIt(it.protocol.start())
# Nim 1.6.18: Using `mapIt` results in a seq of `.Raising([])`
# TODO https://github.com/nim-lang/Nim/issues/23445
var futs = newSeqOfCap[Future[void].Raising([CancelledError])](m.handlers.len)
for it in m.handlers:
futs.add it.protocol.start()
try:
await allFutures(futs)
for fut in futs:
await fut
except CancelledError as exc:
var pending: seq[Future[void].Raising([])]
doAssert m.handlers.len == futs.len, "Handlers modified while starting"
for i, fut in futs:
if not fut.finished:
pending.add noCancel fut.cancelAndWait()
pending.add fut.cancelAndWait()
elif fut.completed:
pending.add handlers[i].protocol.stop()
pending.add m.handlers[i].protocol.stop()
else:
static: doAssert typeof(fut).E is (CancelledError,)
await noCancel allFutures(pending)
raise exc

proc stop*(m: MultistreamSelect) {.async: (raises: []).} =
# Nim 1.6.18: Using `mapIt` results in a seq of `.Raising([CancelledError])`
var futs = newSeqOfCap[Future[void].Raising([])](m.handlers.len)
@@ -164,7 +164,6 @@ type
closedRemotely: Future[void].Raising([])
closedLocally: bool
receivedData: AsyncEvent
returnedEof: bool

proc `$`(channel: YamuxChannel): string =
result = if channel.conn.dir == Out: "=> " else: "<= "

@@ -204,8 +203,8 @@ proc remoteClosed(channel: YamuxChannel) {.async: (raises: []).} =

method closeImpl*(channel: YamuxChannel) {.async: (raises: []).} =
if not channel.closedLocally:
trace "Closing yamux channel locally", streamId = channel.id, conn = channel.conn
channel.closedLocally = true
channel.isEof = true

if not channel.isReset and channel.sendQueue.len == 0:
try: await channel.conn.write(YamuxHeader.data(channel.id, 0, {Fin}))

@@ -273,7 +272,7 @@ method readOnce*(
newLPStreamClosedError()
else:
newLPStreamConnDownError()
if channel.returnedEof:
if channel.isEof:
raise newLPStreamRemoteClosedError()
if channel.recvQueue.len == 0:
channel.receivedData.clear()

@@ -281,9 +280,8 @@ method readOnce*(
discard await race(channel.closedRemotely, channel.receivedData.wait())
except ValueError: raiseAssert("Futures list is not empty")
if channel.closedRemotely.completed() and channel.recvQueue.len == 0:
channel.returnedEof = true
channel.isEof = true
return 0
return 0 # we return 0 to indicate that the channel is closed for reading from now on

let toRead = min(channel.recvQueue.len, nbytes)

@@ -555,13 +553,13 @@ method handle*(m: Yamux) {.async: (raises: []).} =
if flushed[] < 0:
raise newException(YamuxError,
"Peer exhausted the recvWindow after reset")
if header.length > 0:
var buffer = newSeqUninitialized[byte](header.length)
await m.connection.readExactly(
addr buffer[0], int(header.length))
do:
raise newException(YamuxError,
"Unknown stream ID: " & $header.streamId)
if header.length > 0:
var buffer = newSeqUninitialized[byte](header.length)
await m.connection.readExactly(
addr buffer[0], int(header.length))
continue

let channel =
@@ -24,12 +24,16 @@ type
AddressMapper* =
proc(listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]]
{.gcsafe, raises: [].}
## A proc that expected to resolve the listen addresses into dialable addresses

PeerInfo* {.public.} = ref object
peerId*: PeerId
listenAddrs*: seq[MultiAddress]
## contains addresses the node listens on, which may include wildcard and private addresses (not directly reachable).
addrs: seq[MultiAddress]
## contains resolved addresses that other peers can use to connect, including public-facing NAT and port-forwarded addresses.
addressMappers*: seq[AddressMapper]
## contains a list of procs that can be used to resolve the listen addresses into dialable addresses.
protocols*: seq[string]
protoVersion*: string
agentVersion*: string
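A minimal AddressMapper sketch, assuming the node's public address is known out of band; publicAddr and peerInfo are placeholders, not part of the diff:

let publicAddr = MultiAddress.init("/ip4/203.0.113.7/tcp/4001").expect("valid address")

proc mapToPublic(listenAddrs: seq[MultiAddress]): Future[seq[MultiAddress]] {.async.} =
  # advertise the known public address instead of wildcard/private listen addresses
  return @[publicAddr]

peerInfo.addressMappers.add(mapToPublic)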
@@ -56,7 +56,7 @@ method init*(p: Ping) =
trace "handling ping", conn
var buf: array[PingSize, byte]
await conn.readExactly(addr buf[0], PingSize)
trace "echoing ping", conn
trace "echoing ping", conn, pingData = @buf
await conn.write(@buf)
if not isNil(p.pingHandler):
await p.pingHandler(conn.peerId)
@@ -19,14 +19,12 @@ const

type
LPProtoHandler* = proc (
conn: Connection,
proto: string):
Future[void]
{.gcsafe, raises: [].}
conn: Connection,
proto: string): Future[void] {.async.}

LPProtocol* = ref object of RootObj
codecs*: seq[string]
handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator
handlerImpl: LPProtoHandler ## invoked by the protocol negotiator
started*: bool
maxIncomingStreams: Opt[int]

@@ -52,7 +50,7 @@ proc `maxIncomingStreams=`*(p: LPProtocol, val: int) =
p.maxIncomingStreams = Opt.some(val)

func codec*(p: LPProtocol): string =
assert(p.codecs.len > 0, "Codecs sequence was empty!")
doAssert(p.codecs.len > 0, "Codecs sequence was empty!")
p.codecs[0]

func `codec=`*(p: LPProtocol, codec: string) =

@@ -60,15 +58,51 @@ func `codec=`*(p: LPProtocol, codec: string) =
# if we use this abstraction
p.codecs.insert(codec, 0)

template `handler`*(p: LPProtocol): LPProtoHandler =
p.handlerImpl

template `handler`*(
p: LPProtocol, conn: Connection, proto: string): Future[void] =
p.handlerImpl(conn, proto)

func `handler=`*(p: LPProtocol, handler: LPProtoHandler) =
p.handlerImpl = handler

# Callbacks that are annotated with `{.async: (raises).}` explicitly
# document the types of errors that they may raise, but are not compatible
# with `LPProtoHandler` and need to use a custom `proc` type.
# They are internally wrapped into a `LPProtoHandler`, but still allow the
# compiler to check that their `{.async: (raises).}` annotation is correct.
# https://github.com/nim-lang/Nim/issues/23432
func `handler=`*[E](
p: LPProtocol,
handler: proc (
conn: Connection,
proto: string): InternalRaisesFuture[void, E]) =
proc wrap(conn: Connection, proto: string): Future[void] {.async.} =
await handler(conn, proto)
p.handlerImpl = wrap

proc new*(
T: type LPProtocol,
codecs: seq[string],
handler: LPProtoHandler,
maxIncomingStreams: Opt[int] | int = Opt.none(int)): T =
T: type LPProtocol,
codecs: seq[string],
handler: LPProtoHandler,
maxIncomingStreams: Opt[int] | int = Opt.none(int)): T =
T(
codecs: codecs,
handler: handler,
handlerImpl: handler,
maxIncomingStreams:
when maxIncomingStreams is int: Opt.some(maxIncomingStreams)
else: maxIncomingStreams
)

proc new*[E](
T: type LPProtocol,
codecs: seq[string],
handler: proc (
conn: Connection,
proto: string): InternalRaisesFuture[void, E],
maxIncomingStreams: Opt[int] | int = Opt.none(int)): T =
proc wrap(conn: Connection, proto: string): Future[void] {.async.} =
await handler(conn, proto)
T.new(codec, wrap, maxIncomingStreams)
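A sketch of registering a handler with an explicit raises annotation, which the generic overloads above wrap into an LPProtoHandler; the codec string and handler body are illustrative only:

proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
  try:
    await conn.writeLp("hello")
  except CatchableError:
    discard # stream errors stay inside the handler; only cancellation may escape
  await conn.close()

let proto = LPProtocol.new(@["/hello/1.0.0"], handle)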
@@ -16,6 +16,7 @@ import ./pubsub,
./timedcache,
./peertable,
./rpc/[message, messages, protobuf],
nimcrypto/[hash, sha2],
../../crypto/crypto,
../../stream/connection,
../../peerid,
@@ -32,25 +33,34 @@ const FloodSubCodec* = "/floodsub/1.0.0"
type
FloodSub* {.public.} = ref object of PubSub
floodsub*: PeerTable # topic to remote peer map
seen*: TimedCache[MessageId] # message id:s already seen on the network
seenSalt*: seq[byte]
seen*: TimedCache[SaltedId]
# Early filter for messages recently observed on the network
# We use a salted id because the messages in this cache have not yet
# been validated meaning that an attacker has greater control over the
# hash key and therefore could poison the table
seenSalt*: sha256
# The salt in this case is a partially updated SHA256 context pre-seeded
# with some random data

proc hasSeen*(f: FloodSub, msgId: MessageId): bool =
f.seenSalt & msgId in f.seen
proc salt*(f: FloodSub, msgId: MessageId): SaltedId =
var tmp = f.seenSalt
tmp.update(msgId)
SaltedId(data: tmp.finish())

proc addSeen*(f: FloodSub, msgId: MessageId): bool =
# Salting the seen hash helps avoid attacks against the hash function used
# in the nim hash table
proc hasSeen*(f: FloodSub, saltedId: SaltedId): bool =
saltedId in f.seen

proc addSeen*(f: FloodSub, saltedId: SaltedId): bool =
# Return true if the message has already been seen
f.seen.put(f.seenSalt & msgId)
f.seen.put(saltedId)

proc firstSeen*(f: FloodSub, msgId: MessageId): Moment =
f.seen.addedAt(f.seenSalt & msgId)
proc firstSeen*(f: FloodSub, saltedId: SaltedId): Moment =
f.seen.addedAt(saltedId)
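The salting scheme in isolation: a sha256 context is seeded once with random bytes, then copied and finalized per message id. The zeroed seed below stands in for the hmacDrbgGenerate call used in the diff:

import nimcrypto/[hash, sha2]

var seenSalt: sha256
seenSalt.init()
var seed: array[32, byte]          # in the diff this is filled by hmacDrbgGenerate(rng, tmp)
seenSalt.update(seed)

proc saltedDigest(msgId: seq[byte]): MDigest[256] =
  var ctx = seenSalt               # copy the pre-seeded context; the original stays reusable
  ctx.update(msgId)
  ctx.finish()                     # digest depends on both the secret salt and the message id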
proc handleSubscribe*(f: FloodSub,
peer: PubSubPeer,
topic: string,
subscribe: bool) =
proc handleSubscribe(f: FloodSub,
peer: PubSubPeer,
topic: string,
subscribe: bool) =
logScope:
peer
topic

@@ -96,10 +106,9 @@ method unsubscribePeer*(f: FloodSub, peer: PeerId) =
method rpcHandler*(f: FloodSub,
peer: PubSubPeer,
data: seq[byte]) {.async.} =

var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
raise newException(CatchableError, "")
raise newException(CatchableError, "Peer msg couldn't be decoded")

trace "decoded msg from peer", peer, msg = rpcMsg.shortLog
# trigger hooks

@@ -117,9 +126,11 @@ method rpcHandler*(f: FloodSub,
# TODO: descore peers due to error during message validation (malicious?)
continue

let msgId = msgIdResult.get
let
msgId = msgIdResult.get
saltedId = f.salt(msgId)

if f.addSeen(msgId):
if f.addSeen(saltedId):
trace "Dropping already-seen message", msgId, peer
continue

@@ -148,12 +159,15 @@ method rpcHandler*(f: FloodSub,
discard

var toSendPeers = initHashSet[PubSubPeer]()
for t in msg.topicIds: # for every topic in the message
if t notin f.topics:
continue
f.floodsub.withValue(t, peers): toSendPeers.incl(peers[])
let topic = msg.topic
if topic notin f.topics:
debug "Dropping message due to topic not in floodsub topics", topic, msgId, peer
continue

await handleData(f, t, msg.data)
f.floodsub.withValue(topic, peers):
toSendPeers.incl(peers[])

await handleData(f, topic, msg.data)

# In theory, if topics are the same in all messages, we could batch - we'd
# also have to be careful to only include validated messages

@@ -213,7 +227,7 @@ method publish*(f: FloodSub,
trace "Created new message",
msg = shortLog(msg), peers = peers.len, topic, msgId

if f.addSeen(msgId):
if f.addSeen(f.salt(msgId)):
# custom msgid providers might cause this
trace "Dropping already-seen message", msgId, topic
return 0

@@ -231,8 +245,11 @@ method publish*(f: FloodSub,
method initPubSub*(f: FloodSub)
{.raises: [InitializationError].} =
procCall PubSub(f).initPubSub()
f.seen = TimedCache[MessageId].init(2.minutes)
f.seenSalt = newSeqUninitialized[byte](sizeof(Hash))
hmacDrbgGenerate(f.rng[], f.seenSalt)
f.seen = TimedCache[SaltedId].init(2.minutes)
f.seenSalt.init()

var tmp: array[32, byte]
hmacDrbgGenerate(f.rng[], tmp)
f.seenSalt.update(tmp)

f.init()
@@ -49,41 +49,79 @@ declareCounter(libp2p_gossipsub_received, "number of messages received (deduplic
when defined(libp2p_expensive_metrics):
declareCounter(libp2p_pubsub_received_messages, "number of messages received", labels = ["id", "topic"])

proc init*(_: type[GossipSubParams]): GossipSubParams =
proc init*(
_: type[GossipSubParams],
pruneBackoff = 1.minutes,
unsubscribeBackoff = 5.seconds,
floodPublish = true,
gossipFactor: float64 = 0.25,
d = GossipSubD,
dLow = GossipSubDlo,
dHigh = GossipSubDhi,
dScore = GossipSubDlo,
dOut = GossipSubDlo - 1, # DLow - 1
dLazy = GossipSubD, # Like D,
heartbeatInterval = GossipSubHeartbeatInterval,
historyLength = GossipSubHistoryLength,
historyGossip = GossipSubHistoryGossip,
fanoutTTL = GossipSubFanoutTTL,
seenTTL = 2.minutes,
gossipThreshold = -100.0,
publishThreshold = -1000.0,
graylistThreshold = -10000.0,
opportunisticGraftThreshold = 0.0,
decayInterval = 1.seconds,
decayToZero = 0.01,
retainScore = 2.minutes,
appSpecificWeight = 0.0,
ipColocationFactorWeight = 0.0,
ipColocationFactorThreshold = 1.0,
behaviourPenaltyWeight = -1.0,
behaviourPenaltyDecay = 0.999,
directPeers = initTable[PeerId, seq[MultiAddress]](),
disconnectBadPeers = false,
enablePX = false,
bandwidthEstimatebps = 100_000_000, # 100 Mbps or 12.5 MBps
overheadRateLimit = Opt.none(tuple[bytes: int, interval: Duration]),
disconnectPeerAboveRateLimit = false,
maxNumElementsInNonPriorityQueue = DefaultMaxNumElementsInNonPriorityQueue): GossipSubParams =

GossipSubParams(
explicit: true,
pruneBackoff: 1.minutes,
unsubscribeBackoff: 5.seconds,
floodPublish: true,
gossipFactor: 0.25,
d: GossipSubD,
dLow: GossipSubDlo,
dHigh: GossipSubDhi,
dScore: GossipSubDlo,
dOut: GossipSubDlo - 1, # DLow - 1
dLazy: GossipSubD, # Like D
heartbeatInterval: GossipSubHeartbeatInterval,
historyLength: GossipSubHistoryLength,
historyGossip: GossipSubHistoryGossip,
fanoutTTL: GossipSubFanoutTTL,
seenTTL: 2.minutes,
gossipThreshold: -100,
publishThreshold: -1000,
graylistThreshold: -10000,
opportunisticGraftThreshold: 0,
decayInterval: 1.seconds,
decayToZero: 0.01,
retainScore: 2.minutes,
appSpecificWeight: 0.0,
ipColocationFactorWeight: 0.0,
ipColocationFactorThreshold: 1.0,
behaviourPenaltyWeight: -1.0,
behaviourPenaltyDecay: 0.999,
disconnectBadPeers: false,
enablePX: false,
bandwidthEstimatebps: 100_000_000, # 100 Mbps or 12.5 MBps
overheadRateLimit: Opt.none(tuple[bytes: int, interval: Duration]),
disconnectPeerAboveRateLimit: false
pruneBackoff: pruneBackoff,
unsubscribeBackoff: unsubscribeBackoff,
floodPublish: floodPublish,
gossipFactor: gossipFactor,
d: d,
dLow: dLow,
dHigh: dHigh,
dScore: dScore,
dOut: dOut,
dLazy: dLazy,
heartbeatInterval: heartbeatInterval,
historyLength: historyLength,
historyGossip: historyGossip,
fanoutTTL: fanoutTTL,
seenTTL: seenTTL,
gossipThreshold: gossipThreshold,
publishThreshold: publishThreshold,
graylistThreshold: graylistThreshold,
opportunisticGraftThreshold: opportunisticGraftThreshold,
decayInterval: decayInterval,
decayToZero: decayToZero,
retainScore: retainScore,
appSpecificWeight: appSpecificWeight,
ipColocationFactorWeight: ipColocationFactorWeight,
ipColocationFactorThreshold: ipColocationFactorThreshold,
behaviourPenaltyWeight: behaviourPenaltyWeight,
behaviourPenaltyDecay: behaviourPenaltyDecay,
directPeers: directPeers,
disconnectBadPeers: disconnectBadPeers,
enablePX: enablePX,
bandwidthEstimatebps: bandwidthEstimatebps,
overheadRateLimit: overheadRateLimit,
disconnectPeerAboveRateLimit: disconnectPeerAboveRateLimit,
maxNumElementsInNonPriorityQueue: maxNumElementsInNonPriorityQueue
)
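With the defaulted parameters, callers can now override individual values and keep the rest; a small sketch with arbitrary values:

let params = GossipSubParams.init(
  floodPublish = false,
  seenTTL = 5.minutes,
  overheadRateLimit = Opt.some((bytes: 1 shl 20, interval: 1.seconds)))
doAssert params.validateParameters().isOk()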
proc validateParameters*(parameters: GossipSubParams): Result[void, cstring] =

@@ -114,6 +152,8 @@ proc validateParameters*(parameters: GossipSubParams): Result[void, cstring] =
err("gossipsub: behaviourPenaltyWeight parameter error, Must be negative")
elif parameters.behaviourPenaltyDecay < 0 or parameters.behaviourPenaltyDecay >= 1:
err("gossipsub: behaviourPenaltyDecay parameter error, Must be between 0 and 1")
elif parameters.maxNumElementsInNonPriorityQueue <= 0:
err("gossipsub: maxNumElementsInNonPriorityQueue parameter error, Must be > 0")
else:
ok()
@@ -172,10 +212,10 @@ method onNewPeer*(g: GossipSub, peer: PubSubPeer) =

method onPubSubPeerEvent*(p: GossipSub, peer: PubSubPeer, event: PubSubPeerEvent) {.gcsafe.} =
case event.kind
of PubSubPeerEventKind.Connected:
of PubSubPeerEventKind.StreamOpened:
discard
of PubSubPeerEventKind.Disconnected:
# If a send connection is lost, it's better to remove peer from the mesh -
of PubSubPeerEventKind.StreamClosed:
# If a send stream is lost, it's better to remove peer from the mesh -
# if it gets reestablished, the peer will be readded to the mesh, and if it
# doesn't, well.. then we hope the peer is going away!
for topic, peers in p.mesh.mpairs():

@@ -183,6 +223,8 @@ method onPubSubPeerEvent*(p: GossipSub, peer: PubSubPeer, event: PubSubPeerEvent
peers.excl(peer)
for _, peers in p.fanout.mpairs():
peers.excl(peer)
of PubSubPeerEventKind.DisconnectionRequested:
asyncSpawn p.disconnectPeer(peer) # this should unsubscribePeer the peer too

procCall FloodSub(p).onPubSubPeerEvent(peer, event)

@@ -224,10 +266,10 @@ method unsubscribePeer*(g: GossipSub, peer: PeerId) =

procCall FloodSub(g).unsubscribePeer(peer)

proc handleSubscribe*(g: GossipSub,
peer: PubSubPeer,
topic: string,
subscribe: bool) =
proc handleSubscribe(g: GossipSub,
peer: PubSubPeer,
topic: string,
subscribe: bool) =
logScope:
peer
topic

@@ -276,7 +318,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
var respControl: ControlMessage
g.handleIDontWant(peer, control.idontwant)
let iwant = g.handleIHave(peer, control.ihave)
if iwant.messageIds.len > 0:
if iwant.messageIDs.len > 0:
respControl.iwant.add(iwant)
respControl.prune.add(g.handleGraft(peer, control.graft))
let messages = g.handleIWant(peer, control.iwant)

@@ -292,8 +334,8 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =

if isPruneNotEmpty:
for prune in respControl.prune:
if g.knownTopics.contains(prune.topicId):
libp2p_pubsub_broadcast_prune.inc(labelValues = [prune.topicId])
if g.knownTopics.contains(prune.topicID):
libp2p_pubsub_broadcast_prune.inc(labelValues = [prune.topicID])
else:
libp2p_pubsub_broadcast_prune.inc(labelValues = ["generic"])

@@ -304,11 +346,11 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =

if messages.len > 0:
for smsg in messages:
for topic in smsg.topicIds:
if g.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_broadcast_messages.inc(labelValues = ["generic"])
let topic = smsg.topic
if g.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_broadcast_messages.inc(labelValues = ["generic"])

# iwant replies have lower priority
trace "sending iwant reply messages", peer
@@ -318,13 +360,13 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =

proc validateAndRelay(g: GossipSub,
msg: Message,
msgId, msgIdSalted: MessageId,
msgId: MessageId, saltedId: SaltedId,
peer: PubSubPeer) {.async.} =
try:
let validation = await g.validate(msg)

var seenPeers: HashSet[PubSubPeer]
discard g.validationSeen.pop(msgIdSalted, seenPeers)
discard g.validationSeen.pop(saltedId, seenPeers)
libp2p_gossipsub_duplicate_during_validation.inc(seenPeers.len.int64)
libp2p_gossipsub_saved_bytes.inc((msg.data.len * seenPeers.len).int64, labelValues = ["validation_duplicate"])

@@ -344,18 +386,16 @@ proc validateAndRelay(g: GossipSub,
# store in cache only after validation
g.mcache.put(msgId, msg)

g.rewardDelivered(peer, msg.topicIds, true)
let topic = msg.topic
g.rewardDelivered(peer, topic, true)

var toSendPeers = HashSet[PubSubPeer]()
for t in msg.topicIds: # for every topic in the message
if t notin g.topics:
continue
if topic notin g.topics:
return

g.floodsub.withValue(t, peers): toSendPeers.incl(peers[])
g.mesh.withValue(t, peers): toSendPeers.incl(peers[])

# add direct peers
toSendPeers.incl(g.subscribedDirectPeers.getOrDefault(t))
g.floodsub.withValue(topic, peers): toSendPeers.incl(peers[])
g.mesh.withValue(topic, peers): toSendPeers.incl(peers[])
g.subscribedDirectPeers.withValue(topic, peers): toSendPeers.incl(peers[])

# Don't send it to source peer, or peers that
# sent it during validation

@@ -366,66 +406,55 @@ proc validateAndRelay(g: GossipSub,
# bigger than the messageId
if msg.data.len > msgId.len * 10:
g.broadcast(toSendPeers, RPCMsg(control: some(ControlMessage(
idontwant: @[ControlIWant(messageIds: @[msgId])]
idontwant: @[ControlIWant(messageIDs: @[msgId])]
))), isHighPriority = true)

for peer in toSendPeers:
for heDontWant in peer.heDontWants:
if msgId in heDontWant:
if saltedId in heDontWant:
seenPeers.incl(peer)
libp2p_gossipsub_idontwant_saved_messages.inc
libp2p_gossipsub_saved_bytes.inc(msg.data.len.int64, labelValues = ["idontwant"])
break
toSendPeers.excl(seenPeers)

# In theory, if topics are the same in all messages, we could batch - we'd
# also have to be careful to only include validated messages
g.broadcast(toSendPeers, RPCMsg(messages: @[msg]), isHighPriority = false)
trace "forwarded message to peers", peers = toSendPeers.len, msgId, peer
for topic in msg.topicIds:
if topic notin g.topics: continue

if g.knownTopics.contains(topic):
libp2p_pubsub_messages_rebroadcasted.inc(toSendPeers.len.int64, labelValues = [topic])
else:
libp2p_pubsub_messages_rebroadcasted.inc(toSendPeers.len.int64, labelValues = ["generic"])
if g.knownTopics.contains(topic):
libp2p_pubsub_messages_rebroadcasted.inc(toSendPeers.len.int64, labelValues = [topic])
else:
libp2p_pubsub_messages_rebroadcasted.inc(toSendPeers.len.int64, labelValues = ["generic"])

await handleData(g, topic, msg.data)
await handleData(g, topic, msg.data)
except CatchableError as exc:
info "validateAndRelay failed", msg=exc.msg
proc dataAndTopicsIdSize(msgs: seq[Message]): int =
msgs.mapIt(it.data.len + it.topicIds.mapIt(it.len).foldl(a + b, 0)).foldl(a + b, 0)
msgs.mapIt(it.data.len + it.topic.len).foldl(a + b, 0)

proc rateLimit*(g: GossipSub, peer: PubSubPeer, rpcMsgOpt: Opt[RPCMsg], msgSize: int) {.async.} =
proc messageOverhead(g: GossipSub, msg: RPCMsg, msgSize: int): int =
# In this way we count even ignored fields by protobuf
let
payloadSize =
if g.verifySignature:
byteSize(msg.messages)
else:
dataAndTopicsIdSize(msg.messages)
controlSize = msg.control.withValue(control):
byteSize(control.ihave) + byteSize(control.iwant)
do: # no control message
0

var rmsg = rpcMsgOpt.valueOr:
peer.overheadRateLimitOpt.withValue(overheadRateLimit):
if not overheadRateLimit.tryConsume(msgSize):
libp2p_gossipsub_peers_rate_limit_hits.inc(labelValues = [peer.getAgent()]) # let's just measure at the beginning for test purposes.
debug "Peer sent a msg that couldn't be decoded and it's above rate limit.", peer, uselessAppBytesNum = msgSize
if g.parameters.disconnectPeerAboveRateLimit:
await g.disconnectPeer(peer)
raise newException(PeerRateLimitError, "Peer disconnected because it's above rate limit.")

raise newException(CatchableError, "Peer msg couldn't be decoded")

let usefulMsgBytesNum =
if g.verifySignature:
byteSize(rmsg.messages)
else:
dataAndTopicsIdSize(rmsg.messages)

var uselessAppBytesNum = msgSize - usefulMsgBytesNum
rmsg.control.withValue(control):
uselessAppBytesNum -= (byteSize(control.ihave) + byteSize(control.iwant))
msgSize - payloadSize - controlSize

proc rateLimit*(g: GossipSub, peer: PubSubPeer, overhead: int) {.async.} =
peer.overheadRateLimitOpt.withValue(overheadRateLimit):
if not overheadRateLimit.tryConsume(uselessAppBytesNum):
if not overheadRateLimit.tryConsume(overhead):
libp2p_gossipsub_peers_rate_limit_hits.inc(labelValues = [peer.getAgent()]) # let's just measure at the beginning for test purposes.
debug "Peer sent too much useless application data and it's above rate limit.", peer, msgSize, uselessAppBytesNum, rmsg
debug "Peer sent too much useless application data and it's above rate limit.", peer, overhead
if g.parameters.disconnectPeerAboveRateLimit:
await g.disconnectPeer(peer)
raise newException(PeerRateLimitError, "Peer disconnected because it's above rate limit.")
@@ -433,27 +462,31 @@ proc rateLimit*(g: GossipSub, peer: PubSubPeer, rpcMsgOpt: Opt[RPCMsg], msgSize:
method rpcHandler*(g: GossipSub,
peer: PubSubPeer,
data: seq[byte]) {.async.} =

let msgSize = data.len
var rpcMsg = decodeRpcMsg(data).valueOr:
debug "failed to decode msg from peer", peer, err = error
await rateLimit(g, peer, Opt.none(RPCMsg), msgSize)
return
await rateLimit(g, peer, msgSize)
# Raising in the handler closes the gossipsub connection (but doesn't
# disconnect the peer!)
# TODO evaluate behaviour penalty values
peer.behaviourPenalty += 0.1

raise newException(CatchableError, "Peer msg couldn't be decoded")

when defined(libp2p_expensive_metrics):
for m in rpcMsg.messages:
for t in m.topicIds:
libp2p_pubsub_received_messages.inc(labelValues = [$peer.peerId, t])
libp2p_pubsub_received_messages.inc(labelValues = [$peer.peerId, m.topic])

trace "decoded msg from peer", peer, msg = rpcMsg.shortLog
await rateLimit(g, peer, Opt.some(rpcMsg), msgSize)
await rateLimit(g, peer, g.messageOverhead(rpcMsg, msgSize))

# trigger hooks
# trigger hooks - these may modify the message
peer.recvObservers(rpcMsg)

if rpcMsg.ping.len in 1..<64 and peer.pingBudget > 0:
g.send(peer, RPCMsg(pong: rpcMsg.ping), isHighPriority = true)
peer.pingBudget.dec

for i in 0..<min(g.topicsHigh, rpcMsg.subscriptions.len):
template sub: untyped = rpcMsg.subscriptions[i]
g.handleSubscribe(peer, sub.topic, sub.subscribe)

@@ -473,16 +506,15 @@ method rpcHandler*(g: GossipSub,
if msgIdResult.isErr:
debug "Dropping message due to failed message id generation",
error = msgIdResult.error
# TODO: descore peers due to error during message validation (malicious?)
await g.punishInvalidMessage(peer, msg)
continue

let
msgId = msgIdResult.get
msgIdSalted = msgId & g.seenSalt
msgIdSalted = g.salt(msgId)
topic = msg.topic

# addSeen adds salt to msgId to avoid
# remote attacking the hash function
if g.addSeen(msgId):
if g.addSeen(msgIdSalted):
trace "Dropping already-seen message", msgId = shortLog(msgId), peer

var alreadyReceived = false

@@ -492,8 +524,8 @@ method rpcHandler*(g: GossipSub,
alreadyReceived = true

if not alreadyReceived:
let delay = Moment.now() - g.firstSeen(msgId)
g.rewardDelivered(peer, msg.topicIds, false, delay)
let delay = Moment.now() - g.firstSeen(msgIdSalted)
g.rewardDelivered(peer, topic, false, delay)

libp2p_gossipsub_duplicate.inc()

@@ -503,7 +535,7 @@ method rpcHandler*(g: GossipSub,
libp2p_gossipsub_received.inc()

# avoid processing messages we are not interested in
if msg.topicIds.allIt(it notin g.topics):
if topic notin g.topics:
debug "Dropping message of topic without subscription", msgId = shortLog(msgId), peer
continue
@@ -569,25 +601,24 @@ method onTopicSubscription*(g: GossipSub, topic: string, subscribed: bool) =

g.mesh.del(topic)

# Send unsubscribe (in reverse order to sub/graft)
procCall PubSub(g).onTopicSubscription(topic, subscribed)

method publish*(g: GossipSub,
topic: string,
data: seq[byte]): Future[int] {.async.} =
# base returns always 0
discard await procCall PubSub(g).publish(topic, data)

logScope:
topic

trace "Publishing message on topic", data = data.shortLog

if topic.len <= 0: # data could be 0/empty
debug "Empty topic, skipping publish"
return 0

# base returns always 0
discard await procCall PubSub(g).publish(topic, data)

trace "Publishing message on topic", data = data.shortLog

var peers: HashSet[PubSubPeer]

# add always direct peers
@@ -600,38 +631,39 @@ method publish*(g: GossipSub,
# With flood publishing enabled, the mesh is used when propagating messages from other peers,
# but a peer's own messages will always be published to all known peers in the topic, limited
# to the amount of peers we can send it to in one heartbeat
var maxPeersToFlodOpt: Opt[int64]
if g.parameters.bandwidthEstimatebps > 0:
let
bandwidth = (g.parameters.bandwidthEstimatebps) div 8 div 1000 # Divisions are to convert it to Bytes per ms TODO replace with bandwidth estimate
msToTransmit = max(data.len div bandwidth, 1)
maxPeersToFlodOpt = Opt.some(max(g.parameters.heartbeatInterval.milliseconds div msToTransmit, g.parameters.dLow))

let maxPeersToFlood =
if g.parameters.bandwidthEstimatebps > 0:
let
bandwidth = (g.parameters.bandwidthEstimatebps) div 8 div 1000 # Divisions are to convert it to Bytes per ms TODO replace with bandwidth estimate
msToTransmit = max(data.len div bandwidth, 1)
max(g.parameters.heartbeatInterval.milliseconds div msToTransmit, g.parameters.dLow)
else:
int.high() # unlimited

for peer in g.gossipsub.getOrDefault(topic):
maxPeersToFlodOpt.withValue(maxPeersToFlod):
if peers.len >= maxPeersToFlod: break
if peers.len >= maxPeersToFlood: break

if peer.score >= g.parameters.publishThreshold:
trace "publish: including flood/high score peer", peer
peers.incl(peer)

if peers.len < g.parameters.dLow:
# not subscribed, or bad mesh, send to fanout peers
var fanoutPeers = g.fanout.getOrDefault(topic).toSeq()
if fanoutPeers.len < g.parameters.dLow:
g.replenishFanout(topic)
fanoutPeers = g.fanout.getOrDefault(topic).toSeq()
elif peers.len < g.parameters.dLow:
# not subscribed or bad mesh, send to fanout peers
# when flood-publishing, fanout won't help since all potential peers have
# already been added

g.replenishFanout(topic) # Make sure fanout is populated

var fanoutPeers = g.fanout.getOrDefault(topic).toSeq()
g.rng.shuffle(fanoutPeers)

for fanPeer in fanoutPeers:
peers.incl(fanPeer)
if peers.len > g.parameters.d: break

# even if we couldn't publish,
# we still attempted to publish
# on the topic, so it makes sense
# to update the last topic publish
# time
# Attempting to publish counts as fanout send (even if the message
# ultimately is not sent)
g.lastFanoutPubSub[topic] = Moment.fromNow(g.parameters.fanoutTTL)

if peers.len == 0:
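For illustration, the flood-publish cap computed above with the default 100 Mbps estimate, a 1 MB message, and an assumed 1-second heartbeat and dLow of 4 (the last two values are assumptions, not taken from this diff):

let
  bandwidth = 100_000_000 div 8 div 1000            # 12_500 bytes per millisecond
  msToTransmit = max(1_000_000 div bandwidth, 1)    # 80 ms to transmit a 1 MB message
  maxPeersToFlood = max(1_000 div msToTransmit, 4)  # max(12, dLow) = 12 peers per publish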
@@ -659,8 +691,10 @@ method publish*(g: GossipSub,

trace "Created new message", msg = shortLog(msg), peers = peers.len

if g.addSeen(msgId):
# custom msgid providers might cause this
if g.addSeen(g.salt(msgId)):
# If the message was received or published recently, don't re-publish it -
# this might happen when not using sequence id:s and / or with a custom
# message id provider
trace "Dropping already-seen message"
return 0

@@ -748,7 +782,7 @@ method initPubSub*(g: GossipSub)
raise newException(InitializationError, $validationRes.error)

# init the floodsub stuff here, we customize timedcache in gossip!
g.seen = TimedCache[MessageId].init(g.parameters.seenTTL)
g.seen = TimedCache[SaltedId].init(g.parameters.seenTTL)

# init gossip stuff
g.mcache = MCache.init(g.parameters.historyGossip, g.parameters.historyLength)

@@ -761,4 +795,5 @@ method getOrCreatePeer*(
let peer = procCall PubSub(g).getOrCreatePeer(peerId, protos)
g.parameters.overheadRateLimit.withValue(overheadRateLimit):
peer.overheadRateLimitOpt = Opt.some(TokenBucket.new(overheadRateLimit.bytes, overheadRateLimit.interval))
peer.maxNumElementsInNonPriorityQueue = g.parameters.maxNumElementsInNonPriorityQueue
return peer
@@ -30,7 +30,7 @@ declareGauge(libp2p_gossipsub_healthy_peers_topics, "number of topics in mesh wi
declareCounter(libp2p_gossipsub_above_dhigh_condition, "number of above dhigh pruning branches ran", labels = ["topic"])
declareGauge(libp2p_gossipsub_received_iwants, "received iwants", labels = ["kind"])

proc grafted*(g: GossipSub, p: PubSubPeer, topic: string) {.raises: [].} =
proc grafted*(g: GossipSub, p: PubSubPeer, topic: string) =
g.withPeerStats(p.peerId) do (stats: var PeerStats):
var info = stats.topicInfos.getOrDefault(topic)
info.graftTime = Moment.now()

@@ -46,7 +46,7 @@ proc pruned*(g: GossipSub,
p: PubSubPeer,
topic: string,
setBackoff: bool = true,
backoff = none(Duration)) {.raises: [].} =
backoff = none(Duration)) =
if setBackoff:
let
backoffDuration = backoff.get(g.parameters.pruneBackoff)

@@ -70,7 +70,7 @@ proc pruned*(g: GossipSub,

trace "pruned", peer=p, topic

proc handleBackingOff*(t: var BackoffTable, topic: string) {.raises: [].} =
proc handleBackingOff*(t: var BackoffTable, topic: string) =
let now = Moment.now()
var expired = toSeq(t.getOrDefault(topic).pairs())
expired.keepIf do (pair: tuple[peer: PeerId, expire: Moment]) -> bool:

@@ -79,7 +79,7 @@ proc handleBackingOff*(t: var BackoffTable, topic: string) {.raises: [].} =
t.withValue(topic, v):
v[].del(peer)

proc peerExchangeList*(g: GossipSub, topic: string): seq[PeerInfoMsg] {.raises: [].} =
proc peerExchangeList*(g: GossipSub, topic: string): seq[PeerInfoMsg] =
if not g.parameters.enablePX:
return @[]
var peers = g.gossipsub.getOrDefault(topic, initHashSet[PubSubPeer]()).toSeq()
@@ -100,11 +100,11 @@ proc peerExchangeList*(g: GossipSub, topic: string): seq[PeerInfoMsg] {.raises:

proc handleGraft*(g: GossipSub,
peer: PubSubPeer,
grafts: seq[ControlGraft]): seq[ControlPrune] = # {.raises: [Defect].} TODO chronicles exception on windows
grafts: seq[ControlGraft]): seq[ControlPrune] =
var prunes: seq[ControlPrune]
for graft in grafts:
let topic = graft.topicId
trace "peer grafted topic", peer, topic
let topic = graft.topicID
trace "peer grafted topicID", peer, topic

# It is an error to GRAFT on a direct peer
if peer.peerId in g.parameters.directPeers:

@@ -204,12 +204,11 @@ proc getPeers(prune: ControlPrune, peer: PubSubPeer): seq[(PeerId, Option[PeerRe

routingRecords

proc handlePrune*(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) {.raises: [].} =
proc handlePrune*(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) =
for prune in prunes:
let topic = prune.topicId
let topic = prune.topicID

trace "peer pruned topic", peer, topic
trace "peer pruned topicID", peer, topic

# add peer backoff
if prune.backoff > 0:
@@ -239,7 +238,7 @@ proc handlePrune*(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) {.r

proc handleIHave*(g: GossipSub,
peer: PubSubPeer,
ihaves: seq[ControlIHave]): ControlIWant {.raises: [].} =
ihaves: seq[ControlIHave]): ControlIWant =
var res: ControlIWant
if peer.score < g.parameters.gossipThreshold:
trace "ihave: ignoring low score peer", peer, score = peer.score

@@ -248,33 +247,32 @@ proc handleIHave*(g: GossipSub,
else:
for ihave in ihaves:
trace "peer sent ihave",
peer, topic = ihave.topicId, msgs = ihave.messageIds
if ihave.topicId in g.topics:
for msgId in ihave.messageIds:
if not g.hasSeen(msgId):
peer, topicID = ihave.topicID, msgs = ihave.messageIDs
if ihave.topicID in g.topics:
for msgId in ihave.messageIDs:
if not g.hasSeen(g.salt(msgId)):
if peer.iHaveBudget <= 0:
break
elif msgId notin res.messageIds:
res.messageIds.add(msgId)
elif msgId notin res.messageIDs:
res.messageIDs.add(msgId)
dec peer.iHaveBudget
trace "requested message via ihave", messageID=msgId
# shuffling res.messageIDs before sending it out to increase the likelihood
# of getting an answer if the peer truncates the list due to internal size restrictions.
g.rng.shuffle(res.messageIds)
g.rng.shuffle(res.messageIDs)
return res

proc handleIDontWant*(g: GossipSub,
peer: PubSubPeer,
iDontWants: seq[ControlIWant]) =
for dontWant in iDontWants:
for messageId in dontWant.messageIds:
for messageId in dontWant.messageIDs:
if peer.heDontWants[^1].len > 1000: break
if messageId.len > 100: continue
peer.heDontWants[^1].incl(messageId)
peer.heDontWants[^1].incl(g.salt(messageId))

proc handleIWant*(g: GossipSub,
peer: PubSubPeer,
iwants: seq[ControlIWant]): seq[Message] {.raises: [].} =
iwants: seq[ControlIWant]): seq[Message] =
var
messages: seq[Message]
invalidRequests = 0
@@ -282,7 +280,7 @@ proc handleIWant*(g: GossipSub,
trace "iwant: ignoring low score peer", peer, score = peer.score
else:
for iwant in iwants:
for mid in iwant.messageIds:
for mid in iwant.messageIDs:
trace "peer sent iwant", peer, messageID = mid
# canAskIWant will only return true once for a specific message
if not peer.canAskIWant(mid):

@@ -300,7 +298,7 @@ proc handleIWant*(g: GossipSub,
messages.add(msg)
return messages

proc commitMetrics(metrics: var MeshMetrics) {.raises: [].} =
proc commitMetrics(metrics: var MeshMetrics) =
libp2p_gossipsub_low_peers_topics.set(metrics.lowPeersTopics)
libp2p_gossipsub_no_peers_topics.set(metrics.noPeersTopics)
libp2p_gossipsub_under_dout_topics.set(metrics.underDoutTopics)

@@ -309,7 +307,7 @@ proc commitMetrics(metrics: var MeshMetrics) {.raises: [].} =
libp2p_gossipsub_peers_per_topic_fanout.set(metrics.otherPeersPerTopicFanout, labelValues = ["other"])
libp2p_gossipsub_peers_per_topic_mesh.set(metrics.otherPeersPerTopicMesh, labelValues = ["other"])

proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil) {.raises: [].} =
proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil) =
logScope:
topic
mesh = g.mesh.peers(topic)
@@ -539,7 +537,7 @@ proc rebalanceMesh*(g: GossipSub, topic: string, metrics: ptr MeshMetrics = nil)
backoff: g.parameters.pruneBackoff.seconds.uint64)])))
g.broadcast(prunes, prune, isHighPriority = true)

proc dropFanoutPeers*(g: GossipSub) {.raises: [].} =
proc dropFanoutPeers*(g: GossipSub) =
# drop peers that we haven't published to in
# GossipSubFanoutTTL seconds
let now = Moment.now()

@@ -552,7 +550,7 @@ proc dropFanoutPeers*(g: GossipSub) {.raises: [].} =
for topic in drops:
g.lastFanoutPubSub.del topic

proc replenishFanout*(g: GossipSub, topic: string) {.raises: [].} =
proc replenishFanout*(g: GossipSub, topic: string) =
## get fanout peers for a topic
logScope: topic
trace "about to replenish fanout"

@@ -568,7 +566,7 @@ proc replenishFanout*(g: GossipSub, topic: string) {.raises: [].} =

trace "fanout replenished with peers", peers = g.fanout.peers(topic)

proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] {.raises: [].} =
proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] =
## gossip iHave messages to peers
##

@@ -579,7 +577,7 @@ proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] {.raises:
trace "getting gossip peers (iHave)", ntopics=topics.len
for topic in topics:
if topic notin g.gossipsub:
trace "topic not in gossip array, skipping", topicID = topic
trace "topic not in gossip array, skipping", topic = topic
continue

let mids = g.mcache.window(topic)
@@ -612,26 +610,25 @@ proc getGossipPeers*(g: GossipSub): Table[PubSubPeer, ControlMessage] {.raises:
|
||||
x notin gossipPeers and
|
||||
x.score >= g.parameters.gossipThreshold
|
||||
|
||||
var target = g.parameters.dLazy
|
||||
let factor = (g.parameters.gossipFactor.float * allPeers.len.float).int
|
||||
if factor > target:
|
||||
target = min(factor, allPeers.len)
|
||||
# https://github.com/libp2p/specs/blob/98c5aa9421703fc31b0833ad8860a55db15be063/pubsub/gossipsub/gossipsub-v1.1.md#adaptive-gossip-dissemination
|
||||
let
|
||||
factor = (g.parameters.gossipFactor.float * allPeers.len.float).int
|
||||
target = max(g.parameters.dLazy, factor)
|
||||
|
||||
if target < allPeers.len:
|
||||
g.rng.shuffle(allPeers)
|
||||
allPeers.setLen(target)
|
||||
|
||||
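Note: the rewritten block computes the lazy-gossip fan-out as max(dLazy, gossipFactor * candidates), per the linked adaptive-gossip-dissemination section of the spec. A small worked example with assumed values dLazy = 6 and gossipFactor = 0.25:

let
  dLazy = 6
  gossipFactor = 0.25

for nPeers in [10, 100]:
  let
    factor = int(gossipFactor * nPeers.float)
    target = max(dLazy, factor)
  echo nPeers, " candidates -> ", min(target, nPeers), " IHAVE targets"
  # 10 candidates  -> 6  (the dLazy floor dominates)
  # 100 candidates -> 25 (the 25% factor dominates)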
let msgIdsAsSet = ihave.messageIds.toHashSet()

for peer in allPeers:
control.mgetOrPut(peer, ControlMessage()).ihave.add(ihave)
peer.sentIHaves[^1].incl(msgIdsAsSet)
for msgId in ihave.messageIDs:
peer.sentIHaves[^1].incl(msgId)

libp2p_gossipsub_cache_window_size.set(cacheWindowSize.int64)

return control

proc onHeartbeat(g: GossipSub) {.raises: [].} =
proc onHeartbeat(g: GossipSub) =
# reset IWANT budget
# reset IHAVE cap
block:
@@ -639,7 +636,7 @@ proc onHeartbeat(g: GossipSub) {.raises: [].} =
peer.sentIHaves.addFirst(default(HashSet[MessageId]))
if peer.sentIHaves.len > g.parameters.historyLength:
discard peer.sentIHaves.popLast()
peer.heDontWants.addFirst(default(HashSet[MessageId]))
peer.heDontWants.addFirst(default(HashSet[SaltedId]))
if peer.heDontWants.len > g.parameters.historyLength:
discard peer.heDontWants.popLast()
peer.iHaveBudget = IHavePeerBudget
@@ -687,16 +684,14 @@ proc onHeartbeat(g: GossipSub) {.raises: [].} =
for peer, control in peers:
# only ihave from here
for ihave in control.ihave:
if g.knownTopics.contains(ihave.topicId):
libp2p_pubsub_broadcast_ihave.inc(labelValues = [ihave.topicId])
if g.knownTopics.contains(ihave.topicID):
libp2p_pubsub_broadcast_ihave.inc(labelValues = [ihave.topicID])
else:
libp2p_pubsub_broadcast_ihave.inc(labelValues = ["generic"])
g.send(peer, RPCMsg(control: some(control)), isHighPriority = true)

g.mcache.shift() # shift the cache

# {.pop.} # raises []

proc heartbeat*(g: GossipSub) {.async.} =
heartbeat "GossipSub", g.parameters.heartbeatInterval:
trace "running heartbeat", instance = cast[int](g)

@@ -87,8 +87,6 @@ proc colocationFactor(g: GossipSub, peer: PubSubPeer): float64 =
else:
0.0

{.pop.}

proc disconnectPeer*(g: GossipSub, peer: PubSubPeer) {.async.} =
try:
await g.switch.disconnect(peer.peerId)
@@ -250,43 +248,42 @@ proc punishInvalidMessage*(g: GossipSub, peer: PubSubPeer, msg: Message) {.async
await g.disconnectPeer(peer)
raise newException(PeerRateLimitError, "Peer disconnected because it's above rate limit.")

let topic = msg.topic
if topic notin g.topics:
return

for tt in msg.topicIds:
let t = tt
if t notin g.topics:
continue

let tt = t
# update stats
g.withPeerStats(peer.peerId) do (stats: var PeerStats):
stats.topicInfos.mgetOrPut(tt, TopicInfo()).invalidMessageDeliveries += 1
# update stats
g.withPeerStats(peer.peerId) do(stats: var PeerStats):
stats.topicInfos.mgetOrPut(topic, TopicInfo()).invalidMessageDeliveries += 1

proc addCapped*[T](stat: var T, diff, cap: T) =
stat += min(diff, cap - stat)

proc rewardDelivered*(
g: GossipSub, peer: PubSubPeer, topics: openArray[string], first: bool, delay = ZeroDuration) =
for tt in topics:
let t = tt
if t notin g.topics:
continue
g: GossipSub,
peer: PubSubPeer,
topic: string,
first: bool,
delay = ZeroDuration,
) =
if topic notin g.topics:
return

let tt = t
let topicParams = g.topicParams.mgetOrPut(t, TopicParams.init())
# if in mesh add more delivery score
let topicParams = g.topicParams.mgetOrPut(topic, TopicParams.init())
# if in mesh add more delivery score

if delay > topicParams.meshMessageDeliveriesWindow:
# Too old
continue
if delay > topicParams.meshMessageDeliveriesWindow:
# Too old
return

g.withPeerStats(peer.peerId) do (stats: var PeerStats):
stats.topicInfos.withValue(tt, tstats):
if first:
tstats[].firstMessageDeliveries.addCapped(
1, topicParams.firstMessageDeliveriesCap)
g.withPeerStats(peer.peerId) do (stats: var PeerStats):
stats.topicInfos.withValue(topic, tstats):
if first:
tstats[].firstMessageDeliveries.addCapped(
1, topicParams.firstMessageDeliveriesCap)

if tstats[].inMesh:
tstats[].meshMessageDeliveries.addCapped(
1, topicParams.meshMessageDeliveriesCap)
do: # make sure we don't loose this information
stats.topicInfos[tt] = TopicInfo(meshMessageDeliveries: 1)
if tstats[].inMesh:
tstats[].meshMessageDeliveries.addCapped(
1, topicParams.meshMessageDeliveriesCap)
do: # make sure we don't lose this information
stats.topicInfos[topic] = TopicInfo(meshMessageDeliveries: 1)
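Note: addCapped saturates a counter at cap; a quick numeric check of the proc shown above:

proc addCapped[T](stat: var T, diff, cap: T) =
  stat += min(diff, cap - stat)

var deliveries = 1.5
deliveries.addCapped(1.0, 2.0)   # 1.5 + min(1.0, 0.5) == 2.0
deliveries.addCapped(1.0, 2.0)   # stays at the 2.0 cap
doAssert deliveries == 2.0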
@@ -102,6 +102,11 @@ type
behaviourPenalty*: float64 # the eventual penalty score

GossipSubParams* {.public.} = object
# explicit is used to check if the GossipSubParams instance was created by the user either passing params to GossipSubParams(...)
# or GossipSubParams.init(...). In the first case explicit should be set to true when calling the Nim constructor.
# In the second case, the param isn't necessary and should be always be set to true by init.
# If none of those options were used, it means the instance was created using Nim default values.
# In this case, GossipSubParams.init() should be called when initing GossipSub to set the values to their default value defined by nim-libp2p.
explicit*: bool
pruneBackoff*: Duration
unsubscribeBackoff*: Duration
@@ -147,8 +152,11 @@ type
overheadRateLimit*: Opt[tuple[bytes: int, interval: Duration]]
disconnectPeerAboveRateLimit*: bool

# Max number of elements allowed in the non-priority queue. When this limit has been reached, the peer will be disconnected.
maxNumElementsInNonPriorityQueue*: int
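Note: as the comment on explicit describes, GossipSubParams.init() applies the nim-libp2p defaults and marks the instance explicit, while a bare Nim constructor leaves everything zeroed so GossipSub knows it still has to apply init() itself. An illustrative sketch of that contract (tweaking a field after init is an assumption about typical usage, not code from this diff):

var params = GossipSubParams.init()            # library defaults, explicit set by init
params.maxNumElementsInNonPriorityQueue = 2048 # adjust individual fields afterwards
doAssert params.explicit

let zeroed = GossipSubParams()                 # Nim zero defaults, explicit stays false
doAssert not zeroed.explicit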
BackoffTable* = Table[string, Table[PeerId, Moment]]
ValidationSeenTable* = Table[MessageId, HashSet[PubSubPeer]]
ValidationSeenTable* = Table[SaltedId, HashSet[PubSubPeer]]

RoutingRecordsPair* = tuple[id: PeerId, record: Option[PeerRecord]]
RoutingRecordsHandler* =
@@ -164,8 +172,6 @@ type
subscribedDirectPeers*: PeerTable # directpeers that we keep alive
backingOff*: BackoffTable # peers to backoff from when replenishing the mesh
lastFanoutPubSub*: Table[string, Moment] # last publish time for fanout topics
gossip*: Table[string, seq[ControlIHave]] # pending gossip
control*: Table[string, ControlMessage] # pending control messages
mcache*: MCache # messages cache
validationSeen*: ValidationSeenTable # peers who sent us message in validation
heartbeatFut*: Future[void] # cancellation future for heartbeat interval

@@ -9,52 +9,57 @@

{.push raises: [].}

import std/[sets, tables, options]
import std/[sets, tables]
import rpc/[messages]
import results

export sets, tables, messages, options
export sets, tables, messages, results

type
CacheEntry* = object
mid*: MessageId
topicIds*: seq[string]
msgId*: MessageId
topic*: string

MCache* = object of RootObj
msgs*: Table[MessageId, Message]
history*: seq[seq[CacheEntry]]
pos*: int
windowSize*: Natural

func get*(c: MCache, mid: MessageId): Option[Message] =
if mid in c.msgs:
try: some(c.msgs[mid])
func get*(c: MCache, msgId: MessageId): Opt[Message] =
if msgId in c.msgs:
try: Opt.some(c.msgs[msgId])
except KeyError: raiseAssert "checked"
else:
none(Message)
Opt.none(Message)

func contains*(c: MCache, mid: MessageId): bool =
mid in c.msgs
func contains*(c: MCache, msgId: MessageId): bool =
msgId in c.msgs

func put*(c: var MCache, msgId: MessageId, msg: Message) =
if not c.msgs.hasKeyOrPut(msgId, msg):
# Only add cache entry if the message was not already in the cache
c.history[0].add(CacheEntry(mid: msgId, topicIds: msg.topicIds))
c.history[c.pos].add(CacheEntry(msgId: msgId, topic: msg.topic))

func window*(c: MCache, topic: string): HashSet[MessageId] =
let
len = min(c.windowSize, c.history.len)

for i in 0..<len:
for entry in c.history[i]:
for t in entry.topicIds:
if t == topic:
result.incl(entry.mid)
break
# Work backwards from `pos` in the circular buffer
for entry in c.history[(c.pos + c.history.len - i) mod c.history.len]:
if entry.topic == topic:
result.incl(entry.msgId)

func shift*(c: var MCache) =
for entry in c.history.pop():
c.msgs.del(entry.mid)
# Shift circular buffer to write to a new position, clearing it from past
# iterations
c.pos = (c.pos + 1) mod c.history.len

c.history.insert(@[])
for entry in c.history[c.pos]:
c.msgs.del(entry.msgId)

reset(c.history[c.pos])

func init*(T: type MCache, window, history: Natural): T =
T(
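Note: MCache now reuses a fixed circular history indexed by pos instead of inserting a fresh seq on every shift. A short usage sketch against the API shown above (Message is constructed with only the fields this changeset defines; the module imports are assumed to be in scope):

var cache = MCache.init(window = 3, history = 5)
let msgId = @[1'u8, 2, 3]
cache.put(msgId, Message(topic: "demo", data: @[42'u8]))

doAssert msgId in cache
doAssert msgId in cache.window("demo")

for _ in 0 ..< 5:
  cache.shift()                    # once pos wraps around, the slot is cleared
doAssert msgId notin cache
doAssert cache.get(msgId).isNone()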
@@ -30,7 +30,6 @@ import ./errors as pubsub_errors,
../../errors,
../../utility

import metrics
import stew/results
export results

@@ -181,28 +180,28 @@ proc broadcast*(
libp2p_pubsub_broadcast_unsubscriptions.inc(npeers, labelValues = ["generic"])

for smsg in msg.messages:
for topic in smsg.topicIds:
if p.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = [topic])
else:
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = ["generic"])
let topic = smsg.topic
if p.knownTopics.contains(topic):
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = [topic])
else:
libp2p_pubsub_broadcast_messages.inc(npeers, labelValues = ["generic"])

msg.control.withValue(control):
libp2p_pubsub_broadcast_iwant.inc(npeers * control.iwant.len.int64)

for ihave in control.ihave:
if p.knownTopics.contains(ihave.topicId):
libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = [ihave.topicId])
if p.knownTopics.contains(ihave.topicID):
libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = [ihave.topicID])
else:
libp2p_pubsub_broadcast_ihave.inc(npeers, labelValues = ["generic"])
for graft in control.graft:
if p.knownTopics.contains(graft.topicId):
libp2p_pubsub_broadcast_graft.inc(npeers, labelValues = [graft.topicId])
if p.knownTopics.contains(graft.topicID):
libp2p_pubsub_broadcast_graft.inc(npeers, labelValues = [graft.topicID])
else:
libp2p_pubsub_broadcast_graft.inc(npeers, labelValues = ["generic"])
for prune in control.prune:
if p.knownTopics.contains(prune.topicId):
libp2p_pubsub_broadcast_prune.inc(npeers, labelValues = [prune.topicId])
if p.knownTopics.contains(prune.topicID):
libp2p_pubsub_broadcast_prune.inc(npeers, labelValues = [prune.topicID])
else:
libp2p_pubsub_broadcast_prune.inc(npeers, labelValues = ["generic"])

@@ -252,29 +251,27 @@ proc updateMetrics*(p: PubSub, rpcMsg: RPCMsg) =
libp2p_pubsub_received_unsubscriptions.inc(labelValues = ["generic"])

for i in 0..<rpcMsg.messages.len():
template smsg: untyped = rpcMsg.messages[i]
for j in 0..<smsg.topicIds.len():
template topic: untyped = smsg.topicIds[j]
if p.knownTopics.contains(topic):
libp2p_pubsub_received_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_received_messages.inc(labelValues = ["generic"])
let topic = rpcMsg.messages[i].topic
if p.knownTopics.contains(topic):
libp2p_pubsub_received_messages.inc(labelValues = [topic])
else:
libp2p_pubsub_received_messages.inc(labelValues = ["generic"])

rpcMsg.control.withValue(control):
libp2p_pubsub_received_iwant.inc(control.iwant.len.int64)
for ihave in control.ihave:
if p.knownTopics.contains(ihave.topicId):
libp2p_pubsub_received_ihave.inc(labelValues = [ihave.topicId])
if p.knownTopics.contains(ihave.topicID):
libp2p_pubsub_received_ihave.inc(labelValues = [ihave.topicID])
else:
libp2p_pubsub_received_ihave.inc(labelValues = ["generic"])
for graft in control.graft:
if p.knownTopics.contains(graft.topicId):
libp2p_pubsub_received_graft.inc(labelValues = [graft.topicId])
if p.knownTopics.contains(graft.topicID):
libp2p_pubsub_received_graft.inc(labelValues = [graft.topicID])
else:
libp2p_pubsub_received_graft.inc(labelValues = ["generic"])
for prune in control.prune:
if p.knownTopics.contains(prune.topicId):
libp2p_pubsub_received_prune.inc(labelValues = [prune.topicId])
if p.knownTopics.contains(prune.topicID):
libp2p_pubsub_received_prune.inc(labelValues = [prune.topicID])
else:
libp2p_pubsub_received_prune.inc(labelValues = ["generic"])

@@ -289,11 +286,14 @@ method onNewPeer(p: PubSub, peer: PubSubPeer) {.base, gcsafe.} = discard
method onPubSubPeerEvent*(p: PubSub, peer: PubSubPeer, event: PubSubPeerEvent) {.base, gcsafe.} =
# Peer event is raised for the send connection in particular
case event.kind
of PubSubPeerEventKind.Connected:
of PubSubPeerEventKind.StreamOpened:
if p.topics.len > 0:
p.sendSubs(peer, toSeq(p.topics.keys), true)
of PubSubPeerEventKind.Disconnected:
of PubSubPeerEventKind.StreamClosed:
discard
of PubSubPeerEventKind.DisconnectionRequested:
discard
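Note: Connected/Disconnected became StreamOpened/StreamClosed, and the new DisconnectionRequested kind asks gossipsub to drop the transport connection. A hedged sketch of what an override of this hook would now have to cover (MyPubSub and the log lines are illustrative only; the usual libp2p and chronicles imports are assumed):

type MyPubSub = ref object of PubSub

method onPubSubPeerEvent*(p: MyPubSub, peer: PubSubPeer, event: PubSubPeerEvent) {.gcsafe.} =
  case event.kind
  of PubSubPeerEventKind.StreamOpened:
    debug "send stream opened", peer
  of PubSubPeerEventKind.StreamClosed:
    debug "send stream closed", peer
  of PubSubPeerEventKind.DisconnectionRequested:
    debug "peer asked to be disconnected", peer
  procCall onPubSubPeerEvent(PubSub(p), peer, event)  # keep the base behaviour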

method getOrCreatePeer*(
p: PubSub,
@@ -517,7 +517,7 @@ method addValidator*(p: PubSub,
## will be sent to `hook`. `hook` can return either `Accept`,
## `Ignore` or `Reject` (which can descore the peer)
for t in topic:
trace "adding validator for topic", topicId = t
trace "adding validator for topic", topic = t
p.validators.mgetOrPut(t, HashSet[ValidatorHandler]()).incl(hook)

method removeValidator*(p: PubSub,
@@ -532,13 +532,13 @@ method removeValidator*(p: PubSub,
method validate*(p: PubSub, message: Message): Future[ValidationResult] {.async, base.} =
var pending: seq[Future[ValidationResult]]
trace "about to validate message"
for topic in message.topicIds:
trace "looking for validators on topic", topicId = topic,
registered = toSeq(p.validators.keys)
if topic in p.validators:
trace "running validators for topic", topicId = topic
for validator in p.validators[topic]:
pending.add(validator(topic, message))
let topic = message.topic
trace "looking for validators on topic",
topic = topic, registered = toSeq(p.validators.keys)
if topic in p.validators:
trace "running validators for topic", topic = topic
for validator in p.validators[topic]:
pending.add(validator(topic, message))

result = ValidationResult.Accept
let futs = await allFinished(pending)
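Note: with a single topic per message, each validator call now sees exactly one topic. A hedged usage sketch for registering one (the handler name, topic string and size rule are illustrative, and `pubsub` is assumed to be an already-constructed PubSub/GossipSub instance):

proc allowSmall(topic: string, msg: Message): Future[ValidationResult] {.async.} =
  # Accept short payloads on this topic, ignore everything else.
  if msg.data.len <= 1024:
    ValidationResult.Accept
  else:
    ValidationResult.Ignore

pubsub.addValidator("demo-topic", allowSmall)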
@@ -35,6 +35,11 @@ when defined(pubsubpeer_queue_metrics):
declareGauge(libp2p_gossipsub_priority_queue_size, "the number of messages in the priority queue", labels = ["id"])
declareGauge(libp2p_gossipsub_non_priority_queue_size, "the number of messages in the non-priority queue", labels = ["id"])

declareCounter(libp2p_pubsub_disconnects_over_non_priority_queue_limit, "number of peers disconnected due to over non-prio queue capacity")

const
DefaultMaxNumElementsInNonPriorityQueue* = 1024

type
PeerRateLimitError* = object of CatchableError

@@ -43,8 +48,9 @@ type
onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe, raises: [].}

PubSubPeerEventKind* {.pure.} = enum
Connected
Disconnected
StreamOpened
StreamClosed
DisconnectionRequested # tells gossipsub that the transport connection to the peer should be closed

PubSubPeerEvent* = object
kind*: PubSubPeerEventKind
@@ -74,7 +80,10 @@ type

score*: float64
sentIHaves*: Deque[HashSet[MessageId]]
heDontWants*: Deque[HashSet[MessageId]]
heDontWants*: Deque[HashSet[SaltedId]]
## IDONTWANT contains unvalidated message id:s which may be long and/or
## expensive to look up, so we apply the same salting to them as during
## unvalidated message processing
iHaveBudget*: int
pingBudget*: int
maxMessageSize: int
@@ -83,6 +92,8 @@ type
overheadRateLimitOpt*: Opt[TokenBucket]

rpcmessagequeue: RpcMessageQueue
maxNumElementsInNonPriorityQueue*: int # The max number of elements allowed in the non-priority queue.
disconnected: bool

RPCHandler* = proc(peer: PubSubPeer, data: seq[byte]): Future[void]
{.gcsafe, raises: [].}
@@ -181,6 +192,24 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
debug "exiting pubsub read loop",
conn, peer = p, closed = conn.closed

proc closeSendConn(p: PubSubPeer, event: PubSubPeerEventKind) {.async.} =
if p.sendConn != nil:
trace "Removing send connection", p, conn = p.sendConn
await p.sendConn.close()
p.sendConn = nil

if not p.connectedFut.finished:
p.connectedFut.complete()

try:
if p.onEvent != nil:
p.onEvent(p, PubSubPeerEvent(kind: event))
except CancelledError as exc:
raise exc
except CatchableError as exc:
debug "Errors during diconnection events", error = exc.msg
# don't cleanup p.address else we leak some gossip stat table

proc connectOnce(p: PubSubPeer): Future[void] {.async.} =
try:
if p.connectedFut.finished:
@@ -203,27 +232,11 @@ proc connectOnce(p: PubSubPeer): Future[void] {.async.} =
p.address = if p.sendConn.observedAddr.isSome: some(p.sendConn.observedAddr.get) else: none(MultiAddress)

if p.onEvent != nil:
p.onEvent(p, PubSubPeerEvent(kind: PubSubPeerEventKind.Connected))
p.onEvent(p, PubSubPeerEvent(kind: PubSubPeerEventKind.StreamOpened))

await handle(p, newConn)
finally:
if p.sendConn != nil:
trace "Removing send connection", p, conn = p.sendConn
await p.sendConn.close()
p.sendConn = nil

if not p.connectedFut.finished:
p.connectedFut.complete()

try:
if p.onEvent != nil:
p.onEvent(p, PubSubPeerEvent(kind: PubSubPeerEventKind.Disconnected))
except CancelledError as exc:
raise exc
except CatchableError as exc:
debug "Errors during diconnection events", error = exc.msg

# don't cleanup p.address else we leak some gossip stat table
await p.closeSendConn(PubSubPeerEventKind.StreamClosed)

proc connectImpl(p: PubSubPeer) {.async.} =
try:
@@ -231,6 +244,10 @@ proc connectImpl(p: PubSubPeer) {.async.} =
# send connection might get disconnected due to a timeout or an unrelated
# issue so we try to get a new on
while true:
if p.disconnected:
if not p.connectedFut.finished:
p.connectedFut.complete()
return
await connectOnce(p)
except CatchableError as exc: # never cancelled
debug "Could not establish send connection", msg = exc.msg
@@ -247,9 +264,8 @@ proc hasSendConn*(p: PubSubPeer): bool =
template sendMetrics(msg: RPCMsg): untyped =
when defined(libp2p_expensive_metrics):
for x in msg.messages:
for t in x.topicIds:
# metrics
libp2p_pubsub_sent_messages.inc(labelValues = [$p.peerId, t])
# metrics
libp2p_pubsub_sent_messages.inc(labelValues = [$p.peerId, x.topic])

proc clearSendPriorityQueue(p: PubSubPeer) =
if p.rpcmessagequeue.sendPriorityQueue.len == 0:
@@ -288,7 +304,7 @@ proc sendMsgSlow(p: PubSubPeer, msg: seq[byte]) {.async.} =
if p.sendConn == nil:
# Wait for a send conn to be setup. `connectOnce` will
# complete this even if the sendConn setup failed
await p.connectedFut
discard await race(p.connectedFut)

var conn = p.sendConn
if conn == nil or conn.closed():
@@ -323,14 +339,21 @@ proc sendEncoded*(p: PubSubPeer, msg: seq[byte], isHighPriority: bool): Future[v
## priority messages have been sent.
doAssert(not isNil(p), "pubsubpeer nil!")

p.clearSendPriorityQueue()

# When queues are empty, skipping the non-priority queue for low priority
# messages reduces latency
let emptyQueues =
(p.rpcmessagequeue.sendPriorityQueue.len() +
p.rpcmessagequeue.nonPriorityQueue.len()) == 0

if msg.len <= 0:
debug "empty message, skipping", p, msg = shortLog(msg)
Future[void].completed()
elif msg.len > p.maxMessageSize:
info "trying to send a msg too big for pubsub", maxSize=p.maxMessageSize, msgSize=msg.len
Future[void].completed()
elif isHighPriority:
p.clearSendPriorityQueue()
elif isHighPriority or emptyQueues:
let f = p.sendMsg(msg)
if not f.finished:
p.rpcmessagequeue.sendPriorityQueue.addLast(f)
@@ -338,10 +361,18 @@ proc sendEncoded*(p: PubSubPeer, msg: seq[byte], isHighPriority: bool): Future[v
libp2p_gossipsub_priority_queue_size.inc(labelValues = [$p.peerId])
f
else:
let f = p.rpcmessagequeue.nonPriorityQueue.addLast(msg)
when defined(pubsubpeer_queue_metrics):
libp2p_gossipsub_non_priority_queue_size.inc(labelValues = [$p.peerId])
f
if len(p.rpcmessagequeue.nonPriorityQueue) >= p.maxNumElementsInNonPriorityQueue:
if not p.disconnected:
p.disconnected = true
libp2p_pubsub_disconnects_over_non_priority_queue_limit.inc()
p.closeSendConn(PubSubPeerEventKind.DisconnectionRequested)
else:
Future[void].completed()
else:
let f = p.rpcmessagequeue.nonPriorityQueue.addLast(msg)
when defined(pubsubpeer_queue_metrics):
libp2p_gossipsub_non_priority_queue_size.inc(labelValues = [$p.peerId])
f
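Note: the new branch above enforces maxNumElementsInNonPriorityQueue: when the non-priority queue is full, the message is not queued, the disconnect counter is bumped and a DisconnectionRequested event is raised exactly once. A stand-alone sketch of that admission policy with simplified stand-in types (not the library's own):

import std/deques

type FakePeer = object
  queue: Deque[seq[byte]]
  maxQueueLen: int
  disconnected: bool

proc offer(p: var FakePeer, msg: seq[byte]): bool =
  ## Returns false when the message is dropped and a disconnect should be requested.
  if p.queue.len >= p.maxQueueLen:
    if not p.disconnected:
      p.disconnected = true        # request disconnection only once
    return false
  p.queue.addLast(msg)
  true

var peer = FakePeer(queue: initDeque[seq[byte]](), maxQueueLen: 2)
doAssert peer.offer(@[1'u8])
doAssert peer.offer(@[2'u8])
doAssert not peer.offer(@[3'u8])   # over capacity: dropped, disconnect requested
doAssert peer.disconnected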
iterator splitRPCMsg(peer: PubSubPeer, rpcMsg: RPCMsg, maxSize: int, anonymize: bool): seq[byte] =
## This iterator takes an `RPCMsg` and sequentially repackages its Messages into new `RPCMsg` instances.
@@ -431,7 +462,9 @@ proc sendNonPriorityTask(p: PubSubPeer) {.async.} =
# clearSendPriorityQueue ensures we're not waiting for an already-finished
# future
if p.rpcmessagequeue.sendPriorityQueue.len > 0:
await p.rpcmessagequeue.sendPriorityQueue[^1]
# `race` prevents `p.rpcmessagequeue.sendPriorityQueue[^1]` from being
# cancelled when this task is cancelled
discard await race(p.rpcmessagequeue.sendPriorityQueue[^1])
when defined(pubsubpeer_queue_metrics):
libp2p_gossipsub_non_priority_queue_size.dec(labelValues = [$p.peerId])
await p.sendMsg(msg)
@@ -456,7 +489,7 @@ proc stopSendNonPriorityTask*(p: PubSubPeer) =
proc new(T: typedesc[RpcMessageQueue]): T =
return T(
sendPriorityQueue: initDeque[Future[void]](),
nonPriorityQueue: newAsyncQueue[seq[byte]](),
nonPriorityQueue: newAsyncQueue[seq[byte]]()
)

proc new*(
@@ -466,6 +499,7 @@ proc new*(
onEvent: OnEvent,
codec: string,
maxMessageSize: int,
maxNumElementsInNonPriorityQueue: int = DefaultMaxNumElementsInNonPriorityQueue,
overheadRateLimitOpt: Opt[TokenBucket] = Opt.none(TokenBucket)): T =

result = T(
@@ -477,7 +511,8 @@ proc new*(
maxMessageSize: maxMessageSize,
overheadRateLimitOpt: overheadRateLimitOpt,
rpcmessagequeue: RpcMessageQueue.new(),
maxNumElementsInNonPriorityQueue: maxNumElementsInNonPriorityQueue
)
result.sentIHaves.addFirst(default(HashSet[MessageId]))
result.heDontWants.addFirst(default(HashSet[MessageId]))
result.heDontWants.addFirst(default(HashSet[SaltedId]))
result.startSendNonPriorityTask()

@@ -63,7 +63,7 @@ proc init*(
seqno: Option[uint64],
sign: bool = true): Message
{.gcsafe, raises: [LPError].} =
var msg = Message(data: data, topicIDs: @[topic])
var msg = Message(data: data, topic: topic)

# order matters, we want to include seqno in the signature
seqno.withValue(seqn):
@@ -87,7 +87,7 @@ proc init*(
topic: string,
seqno: Option[uint64]): Message
{.gcsafe, raises: [LPError].} =
var msg = Message(data: data, topicIDs: @[topic])
var msg = Message(data: data, topic: topic)
msg.fromPeer = peerId

seqno.withValue(seqn):

@@ -9,8 +9,8 @@

{.push raises: [].}

import options, sequtils, sugar
import "../../.."/[
import options, sequtils
import ../../../[
peerid,
routing_record,
utility
@@ -27,52 +27,58 @@ proc expectedFields[T](t: typedesc[T], existingFieldNames: seq[string]) {.raises
raise newException(CatchableError, $T & " fields changed, please search for and revise all relevant procs. New fields: " & $fieldNames)

type
PeerInfoMsg* = object
peerId*: PeerId
signedPeerRecord*: seq[byte]
PeerInfoMsg* = object
peerId*: PeerId
signedPeerRecord*: seq[byte]

SubOpts* = object
subscribe*: bool
topic*: string
SubOpts* = object
subscribe*: bool
topic*: string

MessageId* = seq[byte]
MessageId* = seq[byte]

Message* = object
fromPeer*: PeerId
data*: seq[byte]
seqno*: seq[byte]
topicIds*: seq[string]
signature*: seq[byte]
key*: seq[byte]
SaltedId* = object
# Salted hash of message ID - used instead of the ordinary message ID to
# avoid hash poisoning attacks and to make memory usage more predictable
# with respect to the variable-length message id
data*: MDigest[256]

ControlMessage* = object
ihave*: seq[ControlIHave]
iwant*: seq[ControlIWant]
graft*: seq[ControlGraft]
prune*: seq[ControlPrune]
idontwant*: seq[ControlIWant]
Message* = object
fromPeer*: PeerId
data*: seq[byte]
seqno*: seq[byte]
topic*: string
signature*: seq[byte]
key*: seq[byte]

ControlIHave* = object
topicId*: string
messageIds*: seq[MessageId]
ControlMessage* = object
ihave*: seq[ControlIHave]
iwant*: seq[ControlIWant]
graft*: seq[ControlGraft]
prune*: seq[ControlPrune]
idontwant*: seq[ControlIWant]

ControlIWant* = object
messageIds*: seq[MessageId]
ControlIHave* = object
topicID*: string
messageIDs*: seq[MessageId]

ControlGraft* = object
topicId*: string
ControlIWant* = object
messageIDs*: seq[MessageId]

ControlPrune* = object
topicId*: string
peers*: seq[PeerInfoMsg]
backoff*: uint64
ControlGraft* = object
topicID*: string

RPCMsg* = object
subscriptions*: seq[SubOpts]
messages*: seq[Message]
control*: Option[ControlMessage]
ping*: seq[byte]
pong*: seq[byte]
ControlPrune* = object
topicID*: string
peers*: seq[PeerInfoMsg]
backoff*: uint64

RPCMsg* = object
subscriptions*: seq[SubOpts]
messages*: seq[Message]
control*: Option[ControlMessage]
ping*: seq[byte]
pong*: seq[byte]
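Note: SaltedId replaces raw message ids as table/set keys so a sender who controls the id bytes cannot mount hash-poisoning attacks, and so key size stays fixed regardless of id length. The library's own salting helper is not part of this hunk; the following is only one illustrative way to derive such a digest with nimcrypto, using a hypothetical per-node salt:

import nimcrypto/[hash, sha2]

proc saltedIdSketch(salt: openArray[byte], msgId: openArray[byte]): MDigest[256] =
  # Hash the node-local random salt together with the variable-length id.
  var ctx: sha256
  ctx.init()
  ctx.update(salt)
  ctx.update(msgId)
  result = ctx.finish()
  ctx.clear()

let salt = [0x5a'u8, 0x11, 0x22, 0x33]          # per-node random value (assumed)
let digest = saltedIdSketch(salt, @[1'u8, 2, 3])
doAssert digest.data.len == 32                  # fixed-size key for tables and sets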
func withSubs*(
T: type RPCMsg, topics: openArray[string], subscribe: bool): T =
@@ -81,23 +87,23 @@ func withSubs*(

func shortLog*(s: ControlIHave): auto =
(
topicId: s.topicId.shortLog,
messageIds: mapIt(s.messageIds, it.shortLog)
topic: s.topicID.shortLog,
messageIDs: mapIt(s.messageIDs, it.shortLog)
)

func shortLog*(s: ControlIWant): auto =
(
messageIds: mapIt(s.messageIds, it.shortLog)
messageIDs: mapIt(s.messageIDs, it.shortLog)
)

func shortLog*(s: ControlGraft): auto =
(
topicId: s.topicId.shortLog
topic: s.topicID.shortLog
)

func shortLog*(s: ControlPrune): auto =
(
topicId: s.topicId.shortLog
topic: s.topicID.shortLog
)

func shortLog*(c: ControlMessage): auto =
@@ -113,7 +119,7 @@ func shortLog*(msg: Message): auto =
fromPeer: msg.fromPeer.shortLog,
data: msg.data.shortLog,
seqno: msg.seqno.shortLog,
topicIds: $msg.topicIds,
topic: msg.topic,
signature: msg.signature.shortLog,
key: msg.key.shortLog
)
@@ -133,35 +139,35 @@ static: expectedFields(SubOpts, @["subscribe", "topic"])
proc byteSize(subOpts: SubOpts): int =
1 + subOpts.topic.len # 1 byte for the bool

static: expectedFields(Message, @["fromPeer", "data", "seqno", "topicIds", "signature", "key"])
static: expectedFields(Message, @["fromPeer", "data", "seqno", "topic", "signature", "key"])
proc byteSize*(msg: Message): int =
msg.fromPeer.len + msg.data.len + msg.seqno.len +
msg.signature.len + msg.key.len + msg.topicIds.foldl(a + b.len, 0)
msg.fromPeer.len + msg.data.len + msg.seqno.len + msg.signature.len + msg.key.len +
msg.topic.len

proc byteSize*(msgs: seq[Message]): int =
msgs.foldl(a + b.byteSize, 0)

static: expectedFields(ControlIHave, @["topicId", "messageIds"])
static: expectedFields(ControlIHave, @["topicID", "messageIDs"])
proc byteSize(controlIHave: ControlIHave): int =
controlIHave.topicId.len + controlIHave.messageIds.foldl(a + b.len, 0)
controlIHave.topicID.len + controlIHave.messageIDs.foldl(a + b.len, 0)

proc byteSize*(ihaves: seq[ControlIHave]): int =
ihaves.foldl(a + b.byteSize, 0)

static: expectedFields(ControlIWant, @["messageIds"])
static: expectedFields(ControlIWant, @["messageIDs"])
proc byteSize(controlIWant: ControlIWant): int =
controlIWant.messageIds.foldl(a + b.len, 0)
controlIWant.messageIDs.foldl(a + b.len, 0)

proc byteSize*(iwants: seq[ControlIWant]): int =
iwants.foldl(a + b.byteSize, 0)

static: expectedFields(ControlGraft, @["topicId"])
static: expectedFields(ControlGraft, @["topicID"])
proc byteSize(controlGraft: ControlGraft): int =
controlGraft.topicId.len
controlGraft.topicID.len

static: expectedFields(ControlPrune, @["topicId", "peers", "backoff"])
static: expectedFields(ControlPrune, @["topicID", "peers", "backoff"])
proc byteSize(controlPrune: ControlPrune): int =
controlPrune.topicId.len + controlPrune.peers.foldl(a + b.byteSize, 0) + 8 # 8 bytes for uint64
controlPrune.topicID.len + controlPrune.peers.foldl(a + b.byteSize, 0) + 8 # 8 bytes for uint64

static: expectedFields(ControlMessage, @["ihave", "iwant", "graft", "prune", "idontwant"])
proc byteSize(control: ControlMessage): int =

@@ -29,7 +29,7 @@ when defined(libp2p_protobuf_metrics):

proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) =
var ipb = initProtoBuffer()
ipb.write(1, graft.topicId)
ipb.write(1, graft.topicID)
ipb.finish()
pb.write(field, ipb)

@@ -45,7 +45,7 @@ proc write*(pb: var ProtoBuffer, field: int, infoMsg: PeerInfoMsg) =

proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) =
var ipb = initProtoBuffer()
ipb.write(1, prune.topicId)
ipb.write(1, prune.topicID)
for peer in prune.peers:
ipb.write(2, peer)
ipb.write(3, prune.backoff)
@@ -57,8 +57,8 @@ proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) =

proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) =
var ipb = initProtoBuffer()
ipb.write(1, ihave.topicId)
for mid in ihave.messageIds:
ipb.write(1, ihave.topicID)
for mid in ihave.messageIDs:
ipb.write(2, mid)
ipb.finish()
pb.write(field, ipb)
@@ -68,7 +68,7 @@ proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) =

proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) =
var ipb = initProtoBuffer()
for mid in iwant.messageIds:
for mid in iwant.messageIDs:
ipb.write(1, mid)
if len(ipb.buffer) > 0:
ipb.finish()
@@ -110,8 +110,7 @@ proc encodeMessage*(msg: Message, anonymize: bool): seq[byte] =
pb.write(2, msg.data)
if len(msg.seqno) > 0 and not anonymize:
pb.write(3, msg.seqno)
for topic in msg.topicIds:
pb.write(4, topic)
pb.write(4, msg.topic)
if len(msg.signature) > 0 and not anonymize:
pb.write(5, msg.signature)
if len(msg.key) > 0 and not anonymize:
@@ -133,10 +132,10 @@ proc decodeGraft*(pb: ProtoBuffer): ProtoResult[ControlGraft] {.

trace "decodeGraft: decoding message"
var control = ControlGraft()
if ? pb.getField(1, control.topicId):
trace "decodeGraft: read topicId", topic_id = control.topicId
if ? pb.getField(1, control.topicID):
trace "decodeGraft: read topicID", topicID = control.topicID
else:
trace "decodeGraft: topicId is missing"
trace "decodeGraft: topicID is missing"
ok(control)

proc decodePeerInfoMsg*(pb: ProtoBuffer): ProtoResult[PeerInfoMsg] {.
@@ -160,10 +159,10 @@ proc decodePrune*(pb: ProtoBuffer): ProtoResult[ControlPrune] {.

trace "decodePrune: decoding message"
var control = ControlPrune()
if ? pb.getField(1, control.topicId):
trace "decodePrune: read topicId", topic_id = control.topicId
if ? pb.getField(1, control.topicID):
trace "decodePrune: read topicID", topic = control.topicID
else:
trace "decodePrune: topicId is missing"
trace "decodePrune: topicID is missing"
var bpeers: seq[seq[byte]]
if ? pb.getRepeatedField(2, bpeers):
for bpeer in bpeers:
@@ -179,12 +178,12 @@ proc decodeIHave*(pb: ProtoBuffer): ProtoResult[ControlIHave] {.

trace "decodeIHave: decoding message"
var control = ControlIHave()
if ? pb.getField(1, control.topicId):
trace "decodeIHave: read topicId", topic_id = control.topicId
if ? pb.getField(1, control.topicID):
trace "decodeIHave: read topicID", topic = control.topicID
else:
trace "decodeIHave: topicId is missing"
if ? pb.getRepeatedField(2, control.messageIds):
trace "decodeIHave: read messageIDs", message_ids = control.messageIds
trace "decodeIHave: topicID is missing"
if ? pb.getRepeatedField(2, control.messageIDs):
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
else:
trace "decodeIHave: no messageIDs"
ok(control)
@@ -195,8 +194,8 @@ proc decodeIWant*(pb: ProtoBuffer): ProtoResult[ControlIWant] {.inline.} =

trace "decodeIWant: decoding message"
var control = ControlIWant()
if ? pb.getRepeatedField(1, control.messageIds):
trace "decodeIWant: read messageIDs", message_ids = control.messageIds
if ? pb.getRepeatedField(1, control.messageIDs):
trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
else:
trace "decodeIWant: no messageIDs"
ok(control)
@@ -286,10 +285,11 @@ proc decodeMessage*(pb: ProtoBuffer): ProtoResult[Message] {.inline.} =
trace "decodeMessage: read seqno", seqno = msg.seqno
else:
trace "decodeMessage: seqno is missing"
if ? pb.getRepeatedField(4, msg.topicIds):
trace "decodeMessage: read topics", topic_ids = msg.topicIds
if ?pb.getField(4, msg.topic):
trace "decodeMessage: read topic", topic = msg.topic
else:
trace "decodeMessage: topics are missing"
trace "decodeMessage: topic is required"
return err(ProtoError.RequiredFieldMissing)
if ? pb.getField(5, msg.signature):
trace "decodeMessage: read signature", signature = msg.signature.shortLog()
else:

@@ -9,12 +9,13 @@

{.push raises: [].}

import std/[tables]

import std/[hashes, sets]
import chronos/timer, stew/results

import ../../utility

export results

const Timeout* = 10.seconds # default timeout in ms

type
@@ -26,20 +27,38 @@ type

TimedCache*[K] = object of RootObj
head, tail: TimedEntry[K] # nim linked list doesn't allow inserting at pos
entries: Table[K, TimedEntry[K]]
entries: HashSet[TimedEntry[K]]
timeout: Duration

func `==`*[E](a, b: TimedEntry[E]): bool =
if isNil(a) == isNil(b):
isNil(a) or a.key == b.key
else:
false

func hash*(a: TimedEntry): Hash =
if isNil(a):
default(Hash)
else:
hash(a[].key)

func expire*(t: var TimedCache, now: Moment = Moment.now()) =
while t.head != nil and t.head.expiresAt < now:
t.entries.del(t.head.key)
t.entries.excl(t.head)
t.head.prev = nil
t.head = t.head.next
if t.head == nil: t.tail = nil

func del*[K](t: var TimedCache[K], key: K): Opt[TimedEntry[K]] =
# Removes existing key from cache, returning the previous value if present
var item: TimedEntry[K]
if t.entries.pop(key, item):
let tmp = TimedEntry[K](key: key)
if tmp in t.entries:
let item = try:
t.entries[tmp] # use the shared instance in the set
except KeyError:
raiseAssert "just checked"
t.entries.excl(item)

if t.head == item: t.head = item.next
if t.tail == item: t.tail = item.prev

@@ -55,14 +74,14 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
# refreshed.
t.expire(now)

var previous = t.del(k) # Refresh existing item

var addedAt = now
previous.withValue(previous):
addedAt = previous.addedAt
let
previous = t.del(k) # Refresh existing item
addedAt = if previous.isSome():
previous[].addedAt
else:
now

let node = TimedEntry[K](key: k, addedAt: addedAt, expiresAt: now + t.timeout)

if t.head == nil:
t.tail = node
t.head = t.tail
@@ -83,16 +102,24 @@ func put*[K](t: var TimedCache[K], k: K, now = Moment.now()): bool =
if cur == t.tail:
t.tail = node

t.entries[k] = node
t.entries.incl(node)

previous.isSome()

func contains*[K](t: TimedCache[K], k: K): bool =
k in t.entries
let tmp = TimedEntry[K](key: k)
tmp in t.entries

func addedAt*[K](t: TimedCache[K], k: K): Moment =
t.entries.getOrDefault(k).addedAt
func addedAt*[K](t: var TimedCache[K], k: K): Moment =
let tmp = TimedEntry[K](key: k)
try:
if tmp in t.entries: # raising is slow
# Use shared instance from entries
return t.entries[tmp][].addedAt
except KeyError:
raiseAssert "just checked"

default(Moment)

func init*[K](T: type TimedCache[K], timeout: Duration = Timeout): T =
T(
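Note: with entries now kept in a HashSet keyed through the == and hash helpers above, the observable behaviour of TimedCache is unchanged: put returns whether the key was already present (and refreshes it), addedAt survives refreshes, and expire drops entries once they pass the timeout. A small usage sketch against this API (module imports assumed in scope):

import chronos/timer

var seen = TimedCache[string].init(5.seconds)
let t0 = Moment.now()

doAssert not seen.put("msg-1", t0)            # first sighting
doAssert seen.contains("msg-1")
doAssert seen.put("msg-1", t0 + 1.seconds)    # already known: entry refreshed
doAssert seen.addedAt("msg-1") == t0          # refresh keeps the original addedAt

seen.expire(t0 + 10.seconds)                  # well past the refreshed deadline
doAssert not seen.contains("msg-1")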
@@ -1,483 +0,0 @@
# Nim-LibP2P
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Licensed under either of
#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
#  * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

{.push raises: [].}

import std/[oids, strformat]
import bearssl/rand
import chronos, chronicles, stew/endians2
import nimcrypto/[hmac, sha2, sha, hash, rijndael, twofish, bcmode]
import secure,
../../stream/connection,
../../peerinfo,
../../crypto/crypto,
../../crypto/ecnist,
../../peerid,
../../utility,
../../errors

export hmac, sha2, sha, hash, rijndael, bcmode

logScope:
topics = "libp2p secio"

const
SecioCodec* = "/secio/1.0.0"
SecioMaxMessageSize = 8 * 1024 * 1024 ## 8mb
SecioMaxMacSize = sha512.sizeDigest
SecioNonceSize = 16
SecioExchanges = "P-256,P-384,P-521"
SecioCiphers = "TwofishCTR,AES-256,AES-128"
SecioHashes = "SHA256,SHA512"

type
Secio* = ref object of Secure
rng: ref HmacDrbgContext
localPrivateKey: PrivateKey
localPublicKey: PublicKey
remotePublicKey: PublicKey

SecureCipherType {.pure.} = enum
Aes128, Aes256, Twofish

SecureMacType {.pure.} = enum
Sha1, Sha256, Sha512

SecureCipher = object
case kind: SecureCipherType
of SecureCipherType.Aes128:
ctxaes128: CTR[aes128]
of SecureCipherType.Aes256:
ctxaes256: CTR[aes256]
of SecureCipherType.Twofish:
ctxtwofish256: CTR[twofish256]

SecureMac = object
case kind: SecureMacType
of SecureMacType.Sha256:
ctxsha256: HMAC[sha256]
of SecureMacType.Sha512:
ctxsha512: HMAC[sha512]
of SecureMacType.Sha1:
ctxsha1: HMAC[sha1]

SecioConn = ref object of SecureConn
writerMac: SecureMac
readerMac: SecureMac
writerCoder: SecureCipher
readerCoder: SecureCipher

SecioError* = object of LPStreamError

func shortLog*(conn: SecioConn): auto =
try:
if conn == nil: "SecioConn(nil)"
else: &"{shortLog(conn.peerId)}:{conn.oid}"
except ValueError as exc:
raiseAssert(exc.msg)

chronicles.formatIt(SecioConn): shortLog(it)

proc init(mac: var SecureMac, hash: string, key: openArray[byte]) =
if hash == "SHA256":
mac = SecureMac(kind: SecureMacType.Sha256)
mac.ctxsha256.init(key)
elif hash == "SHA512":
mac = SecureMac(kind: SecureMacType.Sha512)
mac.ctxsha512.init(key)
elif hash == "SHA1":
mac = SecureMac(kind: SecureMacType.Sha1)
mac.ctxsha1.init(key)

proc update(mac: var SecureMac, data: openArray[byte]) =
case mac.kind
of SecureMacType.Sha256:
update(mac.ctxsha256, data)
of SecureMacType.Sha512:
update(mac.ctxsha512, data)
of SecureMacType.Sha1:
update(mac.ctxsha1, data)

proc sizeDigest(mac: SecureMac): int {.inline.} =
case mac.kind
of SecureMacType.Sha256:
int(mac.ctxsha256.sizeDigest())
of SecureMacType.Sha512:
int(mac.ctxsha512.sizeDigest())
of SecureMacType.Sha1:
int(mac.ctxsha1.sizeDigest())

proc finish(mac: var SecureMac, data: var openArray[byte]) =
case mac.kind
of SecureMacType.Sha256:
discard finish(mac.ctxsha256, data)
of SecureMacType.Sha512:
discard finish(mac.ctxsha512, data)
of SecureMacType.Sha1:
discard finish(mac.ctxsha1, data)

proc reset(mac: var SecureMac) =
case mac.kind
of SecureMacType.Sha256:
reset(mac.ctxsha256)
of SecureMacType.Sha512:
reset(mac.ctxsha512)
of SecureMacType.Sha1:
reset(mac.ctxsha1)

proc init(sc: var SecureCipher, cipher: string, key: openArray[byte],
iv: openArray[byte]) {.inline.} =
if cipher == "AES-128":
sc = SecureCipher(kind: SecureCipherType.Aes128)
sc.ctxaes128.init(key, iv)
elif cipher == "AES-256":
sc = SecureCipher(kind: SecureCipherType.Aes256)
sc.ctxaes256.init(key, iv)
elif cipher == "TwofishCTR":
sc = SecureCipher(kind: SecureCipherType.Twofish)
sc.ctxtwofish256.init(key, iv)

proc encrypt(cipher: var SecureCipher, input: openArray[byte],
output: var openArray[byte]) {.inline.} =
case cipher.kind
of SecureCipherType.Aes128:
cipher.ctxaes128.encrypt(input, output)
of SecureCipherType.Aes256:
cipher.ctxaes256.encrypt(input, output)
of SecureCipherType.Twofish:
cipher.ctxtwofish256.encrypt(input, output)

proc decrypt(cipher: var SecureCipher, input: openArray[byte],
output: var openArray[byte]) {.inline.} =
case cipher.kind
of SecureCipherType.Aes128:
cipher.ctxaes128.decrypt(input, output)
of SecureCipherType.Aes256:
cipher.ctxaes256.decrypt(input, output)
of SecureCipherType.Twofish:
cipher.ctxtwofish256.decrypt(input, output)

proc macCheckAndDecode(sconn: SecioConn, data: var seq[byte]): bool =
## This procedure checks MAC of recieved message ``data`` and if message is
## authenticated, then decrypt message.
##
## Procedure returns ``false`` if message is too short or MAC verification
## failed.
var macData: array[SecioMaxMacSize, byte]
let macsize = sconn.readerMac.sizeDigest()
if len(data) < macsize:
trace "Message is shorter then MAC size", message_length = len(data),
mac_size = macsize
return false
let mark = len(data) - macsize
sconn.readerMac.update(data.toOpenArray(0, mark - 1))
sconn.readerMac.finish(macData)
sconn.readerMac.reset()
if not equalMem(addr data[mark], addr macData[0], macsize):
trace "Invalid MAC",
calculated = toHex(macData.toOpenArray(0, macsize - 1)),
stored = toHex(data.toOpenArray(mark, data.high))
return false

sconn.readerCoder.decrypt(data.toOpenArray(0, mark - 1),
data.toOpenArray(0, mark - 1))
data.setLen(mark)
true

proc readRawMessage(
conn: Connection
): Future[seq[byte]] {.async: (raises: [CancelledError, LPStreamError]).} =
while true: # Discard 0-length payloads
var lengthBuf: array[4, byte]
await conn.readExactly(addr lengthBuf[0], lengthBuf.len)
let length = uint32.fromBytesBE(lengthBuf)

trace "Recieved message header", header = lengthBuf.shortLog, length = length

if length > SecioMaxMessageSize: # Verify length before casting!
trace "Received size of message exceed limits", conn, length = length
raise (ref SecioError)(msg: "Message exceeds maximum length")

if length > 0:
var buf = newSeq[byte](int(length))
await conn.readExactly(addr buf[0], buf.len)
trace "Received message body",
conn, length = buf.len, buff = buf.shortLog
return buf

trace "Discarding 0-length payload", conn

method readMessage*(
sconn: SecioConn
): Future[seq[byte]] {.async: (raises: [CancelledError, LPStreamError]).} =
## Read message from channel secure connection ``sconn``.
when chronicles.enabledLogLevel == LogLevel.TRACE:
logScope:
stream_oid = $sconn.stream.oid
var buf = await sconn.stream.readRawMessage()
if sconn.macCheckAndDecode(buf):
buf
else:
trace "Message MAC verification failed", buf = buf.shortLog
raise (ref SecioError)(msg: "message failed MAC verification")

method write*(
sconn: SecioConn,
message: seq[byte]) {.async: (raises: [CancelledError, LPStreamError]).} =
## Write message ``message`` to secure connection ``sconn``.
if message.len == 0:
return

var
left = message.len
offset = 0
while left > 0:
let
chunkSize = min(left, SecioMaxMessageSize - 64)
macsize = sconn.writerMac.sizeDigest()
length = chunkSize + macsize

var msg = newSeq[byte](chunkSize + 4 + macsize)
msg[0..<4] = uint32(length).toBytesBE()

sconn.writerCoder.encrypt(
message.toOpenArray(offset, offset + chunkSize - 1),
msg.toOpenArray(4, 4 + chunkSize - 1))
left = left - chunkSize
offset = offset + chunkSize
let mo = 4 + chunkSize
sconn.writerMac.update(msg.toOpenArray(4, 4 + chunkSize - 1))
sconn.writerMac.finish(msg.toOpenArray(mo, mo + macsize - 1))
sconn.writerMac.reset()

trace "Writing message", message = msg.shortLog, left, offset
await sconn.stream.write(msg)
sconn.activity = true
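Note: every secio record written above is a 4-byte big-endian length followed by the CTR-encrypted chunk and the HMAC computed over that ciphertext. A reduced sketch of just the framing arithmetic (the ciphertext and mac bytes are placeholders):

import stew/endians2

proc frame(ciphertext, mac: seq[byte]): seq[byte] =
  # [ big-endian length | ciphertext | mac ]; length covers ciphertext + mac
  let length = uint32(ciphertext.len + mac.len)
  result = newSeq[byte](4 + ciphertext.len + mac.len)
  result[0 ..< 4] = length.toBytesBE()
  result[4 ..< 4 + ciphertext.len] = ciphertext
  result[4 + ciphertext.len .. ^1] = mac

let f = frame(@[1'u8, 2, 3], @[0xAA'u8, 0xBB])
doAssert f.len == 4 + 3 + 2
doAssert f[0 ..< 4] == @(5'u32.toBytesBE())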
proc newSecioConn(
conn: Connection,
hash: string,
cipher: string,
secrets: Secret,
order: int,
remotePubKey: PublicKey): SecioConn =
## Create new secure stream/lpstream, using specified hash algorithm ``hash``,
## cipher algorithm ``cipher``, stretched keys ``secrets`` and order
## ``order``.
result = SecioConn.new(conn, conn.peerId, conn.observedAddr)

let i0 = if order < 0: 1 else: 0
let i1 = if order < 0: 0 else: 1

trace "Writer credentials", mackey = secrets.macOpenArray(i0).shortLog,
enckey = secrets.keyOpenArray(i0).shortLog,
iv = secrets.ivOpenArray(i0).shortLog
trace "Reader credentials", mackey = secrets.macOpenArray(i1).shortLog,
enckey = secrets.keyOpenArray(i1).shortLog,
iv = secrets.ivOpenArray(i1).shortLog
result.writerMac.init(hash, secrets.macOpenArray(i0))
result.readerMac.init(hash, secrets.macOpenArray(i1))
result.writerCoder.init(cipher, secrets.keyOpenArray(i0),
secrets.ivOpenArray(i0))
result.readerCoder.init(cipher, secrets.keyOpenArray(i1),
secrets.ivOpenArray(i1))

proc transactMessage(
conn: Connection,
msg: seq[byte]
): Future[seq[byte]] {.async: (raises: [CancelledError, LPStreamError]).} =
trace "Sending message", message = msg.shortLog, length = len(msg)
await conn.write(msg)
await conn.readRawMessage()

method handshake*(
s: Secio,
conn: Connection,
initiator: bool,
peerId: Opt[PeerId]
): Future[SecureConn] {.async: (raises: [CancelledError, LPStreamError]).} =
let localBytesPubkey = s.localPublicKey.getBytes()
if localBytesPubkey.isErr():
raise (ref SecioError)(msg:
"Failed to get local public key bytes: " & $localBytesPubkey.error())

let localPeerId = PeerId.init(s.localPublicKey)
if localPeerId.isErr():
raise (ref SecioError)(msg:
"Failed to initialize local peer ID: " & $localPeerId.error())

var localNonce: array[SecioNonceSize, byte]
hmacDrbgGenerate(s.rng[], localNonce)

let request = createProposal(
localNonce, localBytesPubkey.get(),
SecioExchanges, SecioCiphers, SecioHashes)

trace "Local proposal", schemes = SecioExchanges,
ciphers = SecioCiphers,
hashes = SecioHashes,
pubkey = localBytesPubkey.get().shortLog,
peer = localPeerId.get()

let answer = await transactMessage(conn, request)
if len(answer) == 0:
trace "Proposal exchange failed", conn
raise (ref SecioError)(msg: "Proposal exchange failed")

var
remoteNonce: seq[byte]
remoteBytesPubkey: seq[byte]
remoteExchanges: string
remoteCiphers: string
remoteHashes: string
if not decodeProposal(
answer, remoteNonce, remoteBytesPubkey, remoteExchanges,
remoteCiphers, remoteHashes):
trace "Remote proposal decoding failed", conn
raise (ref SecioError)(msg: "Remote proposal decoding failed")

var remotePubkey: PublicKey
if not remotePubkey.init(remoteBytesPubkey):
trace "Remote public key incorrect or corrupted",
pubkey = remoteBytesPubkey.shortLog
raise (ref SecioError)(msg: "Remote public key incorrect or corrupted")

let remotePeerId = PeerId.init(remotePubkey)
if remotePeerId.isErr():
raise (ref SecioError)(msg:
"Failed to initialize remote peer ID: " & $remotePeerId.error())

peerId.withValue(targetPid):
if not targetPid.validate():
raise (ref SecioError)(msg: "Failed to validate expected peerId.")

if remotePeerId.get() != targetPid:
raise (ref SecioError)(msg: "Peer ids don't match!")
conn.peerId = remotePeerId.get()
let order = getOrder(
remoteBytesPubkey, localNonce, localBytesPubkey.get(), remoteNonce)
if order.isErr():
raise (ref SecioError)(msg: "Failed to get order: " & $order.error())
trace "Remote proposal", schemes = remoteExchanges, ciphers = remoteCiphers,
hashes = remoteHashes,
pubkey = remoteBytesPubkey.shortLog,
order = order.get(),
peer = remotePeerId.get()

let
scheme = selectBest(order.get(), SecioExchanges, remoteExchanges)
cipher = selectBest(order.get(), SecioCiphers, remoteCiphers)
hash = selectBest(order.get(), SecioHashes, remoteHashes)
if len(scheme) == 0 or len(cipher) == 0 or len(hash) == 0:
trace "No algorithms in common", peer = remotePeerId.get()
raise (ref SecioError)(msg: "No algorithms in common")

trace "Encryption scheme selected", scheme = scheme, cipher = cipher,
hash = hash

let ekeypair = ephemeral(scheme, s.rng[])
if ekeypair.isErr():
raise (ref SecioError)(msg:
"Failed to create ephemeral keypair: " & $ekeypair.error())
# We need EC public key in raw binary form
let epubkey = ekeypair.get().pubkey.getRawBytes()
if epubkey.isErr():
raise (ref SecioError)(msg:
"Failed to get ephemeral key bytes: " & $epubkey.error())
let
localCorpus = request[4..^1] & answer & epubkey.get()
signature = s.localPrivateKey.sign(localCorpus)
if signature.isErr():
raise (ref SecioError)(msg:
"Failed to sign local corpus: " & $signature.error())

let
localExchange = createExchange(epubkey.get(), signature.get().getBytes())
remoteExchange = await transactMessage(conn, localExchange)
if len(remoteExchange) == 0:
trace "Corpus exchange failed", conn
raise (ref SecioError)(msg: "Corpus exchange failed")

var
remoteEBytesPubkey: seq[byte]
remoteEBytesSig: seq[byte]
if not decodeExchange(remoteExchange, remoteEBytesPubkey, remoteEBytesSig):
trace "Remote exchange decoding failed", conn
raise (ref SecioError)(msg: "Remote exchange decoding failed")

var remoteESignature: Signature
if not remoteESignature.init(remoteEBytesSig):
trace "Remote signature incorrect or corrupted",
signature = remoteEBytesSig.shortLog
raise (ref SecioError)(msg: "Remote signature incorrect or corrupted")

let remoteCorpus = answer & request[4..^1] & remoteEBytesPubkey
if not remoteESignature.verify(remoteCorpus, remotePubkey):
trace "Signature verification failed", scheme = $remotePubkey.scheme,
signature = $remoteESignature,
pubkey = $remotePubkey,
corpus = $remoteCorpus
raise (ref SecioError)(msg: "Signature verification failed")

trace "Signature verified", scheme = remotePubkey.scheme

var remoteEPubkey: ecnist.EcPublicKey
if not remoteEPubkey.initRaw(remoteEBytesPubkey):
trace "Remote ephemeral public key incorrect or corrupted",
pubkey = toHex(remoteEBytesPubkey)
raise (ref SecioError)(msg:
"Remote ephemeral public key incorrect or corrupted")

let secret = getSecret(remoteEPubkey, ekeypair.get().seckey)
if len(secret) == 0:
trace "Shared secret could not be created"
raise (ref SecioError)(msg: "Shared secret could not be created")

trace "Shared secret calculated", secret = secret.shortLog

let keys = stretchKeys(cipher, hash, secret)

trace "Authenticated encryption parameters",
iv0 = toHex(keys.ivOpenArray(0)),
key0 = keys.keyOpenArray(0).shortLog,
mac0 = keys.macOpenArray(0).shortLog,
iv1 = keys.ivOpenArray(1).shortLog,
key1 = keys.keyOpenArray(1).shortLog,
mac1 = keys.macOpenArray(1).shortLog

# Perform Nonce exchange over encrypted channel.

let secioConn = newSecioConn(
conn, hash, cipher, keys, order.get(), remotePubkey)
|
||||
await secioConn.write(remoteNonce)
|
||||
var res = await secioConn.readMessage()
|
||||
|
||||
if res != @localNonce:
|
||||
trace "Nonce verification failed", receivedNonce = res.shortLog,
|
||||
localNonce = localNonce.shortLog
|
||||
raise (ref SecioError)(msg: "Nonce verification failed")
|
||||
else:
|
||||
trace "Secure handshake succeeded"
|
||||
secioConn
|
||||
|
||||
method init(s: Secio) {.gcsafe.} =
|
||||
procCall Secure(s).init()
|
||||
s.codec = SecioCodec
|
||||
|
||||
proc new*(
|
||||
T: typedesc[Secio],
|
||||
rng: ref HmacDrbgContext,
|
||||
localPrivateKey: PrivateKey): T =
|
||||
let secio = Secio(
|
||||
rng: rng,
|
||||
localPrivateKey: localPrivateKey,
|
||||
localPublicKey: localPrivateKey.getPublicKey().expect("Invalid private key")
|
||||
)
|
||||
secio.init()
|
||||
secio
|
||||
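For context, a hedged sketch (not part of the diff) of constructing the Secio secure channel through the `new` proc above; `rng` and `localPrivateKey` are placeholders assumed to exist already, e.g. the switch's own key material:

# Sketch under assumed surroundings: `rng: ref HmacDrbgContext` and
# `localPrivateKey: PrivateKey` come from elsewhere.
let secio = Secio.new(rng, localPrivateKey)
# `secio` would then be handed to the switch as one of its secure managers.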
213
libp2p/services/wildcardresolverservice.nim
Normal file
213
libp2p/services/wildcardresolverservice.nim
Normal file
@@ -0,0 +1,213 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import std/sequtils
|
||||
import stew/[byteutils, results, endians2]
|
||||
import chronos, chronos/transports/[osnet, ipnet], chronicles
|
||||
import ../[multiaddress, multicodec]
|
||||
import ../switch
|
||||
|
||||
logScope:
|
||||
topics = "libp2p wildcardresolverservice"
|
||||
|
||||
type
|
||||
WildcardAddressResolverService* = ref object of Service
|
||||
## Service used to resolve wildcard addresses of the type "0.0.0.0" for IPv4 or "::" for IPv6.
|
||||
## When used with a `Switch`, this service will be automatically set up and stopped
|
||||
## when the switch starts and stops. This is facilitated by adding the service to the switch's
|
||||
## list of services using the `.withServices(@[svc])` method in the `SwitchBuilder`.
|
||||
networkInterfaceProvider: NetworkInterfaceProvider
|
||||
## Provides a list of network addresses.
|
||||
addressMapper: AddressMapper
|
||||
## An implementation of an address mapper that takes a list of listen addresses and expands each wildcard address
|
||||
## to the respective list of interface addresses. As an example, if the listen address is 0.0.0.0:4001
|
||||
## and the machine has 2 interfaces with IPs 172.217.11.174 and 64.233.177.113, the address mapper will
|
||||
## expand the wildcard address to 172.217.11.174:4001 and 64.233.177.113:4001.
|
||||
|
||||
NetworkInterfaceProvider* = ref object of RootObj
|
||||
|
||||
proc isLoopbackOrUp(networkInterface: NetworkInterface): bool =
  (networkInterface.ifType == IfSoftwareLoopback) or
    (networkInterface.state == StatusUp)
|
||||
|
||||
proc networkInterfaceProvider*(): NetworkInterfaceProvider =
|
||||
## Returns a new instance of `NetworkInterfaceProvider`.
|
||||
return NetworkInterfaceProvider()
|
||||
|
||||
method getAddresses*(
|
||||
networkInterfaceProvider: NetworkInterfaceProvider, addrFamily: AddressFamily
|
||||
): seq[InterfaceAddress] {.base.} =
|
||||
## This method retrieves the addresses of network interfaces based on the specified address family.
|
||||
##
|
||||
## The `getAddresses` method filters the available network interfaces to include only
|
||||
## those that are either loopback or up. It then collects all the addresses from these
|
||||
## interfaces and filters them to match the provided address family.
|
||||
##
|
||||
## Parameters:
|
||||
## - `networkInterfaceProvider`: A provider that offers access to network interfaces.
|
||||
## - `addrFamily`: The address family to filter the network addresses (e.g., `AddressFamily.IPv4` or `AddressFamily.IPv6`).
|
||||
##
|
||||
## Returns:
|
||||
## - A sequence of `InterfaceAddress` objects that match the specified address family.
|
||||
trace "Getting addresses for address family", addrFamily
|
||||
let
|
||||
interfaces = getInterfaces().filterIt(it.isLoopbackOrUp())
|
||||
flatInterfaceAddresses = concat(interfaces.mapIt(it.addresses))
|
||||
filteredInterfaceAddresses =
|
||||
flatInterfaceAddresses.filterIt(it.host.family == addrFamily)
|
||||
return filteredInterfaceAddresses
|
||||
|
||||
proc new*(
|
||||
T: typedesc[WildcardAddressResolverService],
|
||||
networkInterfaceProvider: NetworkInterfaceProvider = new(NetworkInterfaceProvider),
|
||||
): T =
|
||||
## This procedure initializes a new `WildcardAddressResolverService` with the provided network interface provider.
|
||||
##
|
||||
## Parameters:
|
||||
## - `T`: The type descriptor for `WildcardAddressResolverService`.
|
||||
## - `networkInterfaceProvider`: A provider that offers access to network interfaces. Defaults to a new instance of `NetworkInterfaceProvider`.
|
||||
##
|
||||
## Returns:
|
||||
## - A new instance of `WildcardAddressResolverService`.
|
||||
return T(networkInterfaceProvider: networkInterfaceProvider)
|
||||
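As the type-level comment above suggests, the service is meant to be attached through the switch builder; a minimal hedged sketch, assuming the usual nim-libp2p `SwitchBuilder` chain:

# Sketch, not from the repository: attach the wildcard resolver so
# 0.0.0.0 / :: listen addresses get expanded to interface addresses.
let
  rng = newRng()
  svc = WildcardAddressResolverService.new()
  switch = SwitchBuilder.new()
    .withRng(rng)
    .withAddress(MultiAddress.init("/ip4/0.0.0.0/tcp/4001").tryGet())
    .withTcpTransport()
    .withMplex()
    .withNoise()
    .withServices(@[Service(svc)])
    .build()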
|
||||
proc getProtocolArgument*(ma: MultiAddress, codec: MultiCodec): MaResult[seq[byte]] =
  var buffer: seq[byte]
  for item in ma:
    let
      ritem = ?item
      code = ?ritem.protoCode()
    if code == codec:
      let arg = ?ritem.protoAddress()
      return ok(arg)

  err("Multiaddress codec has not been found")
|
||||
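For illustration, a hedged example of pulling the TCP port out of a multiaddress with this helper, mirroring how `expandWildcardAddresses` below decodes the port bytes; the address literal is made up:

# Sketch, not from the repository.
let
  ma = MultiAddress.init("/ip4/127.0.0.1/tcp/4001").tryGet()
  portBytes = ma.getProtocolArgument(multiCodec("tcp")).tryGet()
  port = Port(uint16.fromBytesBE(portBytes)) # Port(4001)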
|
||||
proc getWildcardMultiAddresses(
|
||||
interfaceAddresses: seq[InterfaceAddress], protocol: Protocol, port: Port
|
||||
): seq[MultiAddress] =
|
||||
var addresses: seq[MultiAddress]
|
||||
for ifaddr in interfaceAddresses:
|
||||
var address = ifaddr.host
|
||||
address.port = port
|
||||
MultiAddress.init(address, protocol).withValue(maddress):
|
||||
addresses.add(maddress)
|
||||
addresses
|
||||
|
||||
proc getWildcardAddress(
|
||||
maddress: MultiAddress,
|
||||
multiCodec: MultiCodec,
|
||||
anyAddr: openArray[uint8],
|
||||
addrFamily: AddressFamily,
|
||||
port: Port,
|
||||
networkInterfaceProvider: NetworkInterfaceProvider,
|
||||
): seq[MultiAddress] =
|
||||
var addresses: seq[MultiAddress]
|
||||
maddress.getProtocolArgument(multiCodec).withValue(address):
|
||||
if address == anyAddr:
|
||||
let filteredInterfaceAddresses = networkInterfaceProvider.getAddresses(addrFamily)
|
||||
addresses.add(
|
||||
getWildcardMultiAddresses(filteredInterfaceAddresses, IPPROTO_TCP, port)
|
||||
)
|
||||
else:
|
||||
addresses.add(maddress)
|
||||
return addresses
|
||||
|
||||
proc expandWildcardAddresses(
|
||||
networkInterfaceProvider: NetworkInterfaceProvider, listenAddrs: seq[MultiAddress]
|
||||
): seq[MultiAddress] =
|
||||
var addresses: seq[MultiAddress]
|
||||
# In this loop we expand bound addresses like `0.0.0.0` and `::` to the list of interface addresses.
|
||||
for listenAddr in listenAddrs:
|
||||
if TCP_IP.matchPartial(listenAddr):
|
||||
listenAddr.getProtocolArgument(multiCodec("tcp")).withValue(portArg):
|
||||
let port = Port(uint16.fromBytesBE(portArg))
|
||||
if IP4.matchPartial(listenAddr):
|
||||
let wildcardAddresses = getWildcardAddress(
|
||||
listenAddr,
|
||||
multiCodec("ip4"),
|
||||
AnyAddress.address_v4,
|
||||
AddressFamily.IPv4,
|
||||
port,
|
||||
networkInterfaceProvider,
|
||||
)
|
||||
addresses.add(wildcardAddresses)
|
||||
elif IP6.matchPartial(listenAddr):
|
||||
let wildcardAddresses = getWildcardAddress(
|
||||
listenAddr,
|
||||
multiCodec("ip6"),
|
||||
AnyAddress6.address_v6,
|
||||
AddressFamily.IPv6,
|
||||
port,
|
||||
networkInterfaceProvider,
|
||||
)
|
||||
addresses.add(wildcardAddresses)
|
||||
else:
|
||||
addresses.add(listenAddr)
|
||||
else:
|
||||
addresses.add(listenAddr)
|
||||
addresses
|
||||
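To make the expansion concrete, reusing the figures from the type-level comment above: a listen address of /ip4/0.0.0.0/tcp/4001 on a machine whose up interfaces hold 172.217.11.174 and 64.233.177.113 comes back as the two addresses sketched below, while non-wildcard listen addresses pass through unchanged:

# Sketch of the expected output, not from the repository.
@[
  MultiAddress.init("/ip4/172.217.11.174/tcp/4001").tryGet(),
  MultiAddress.init("/ip4/64.233.177.113/tcp/4001").tryGet()
]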
|
||||
method setup*(
|
||||
self: WildcardAddressResolverService, switch: Switch
|
||||
): Future[bool] {.async.} =
|
||||
## Sets up the `WildcardAddressResolverService`.
|
||||
##
|
||||
## This method adds the address mapper to the peer's list of address mappers.
|
||||
##
|
||||
## Parameters:
|
||||
## - `self`: The instance of `WildcardAddressResolverService` being set up.
|
||||
## - `switch`: The switch context in which the service operates.
|
||||
##
|
||||
## Returns:
|
||||
## - A `Future[bool]` that resolves to `true` if the setup was successful, otherwise `false`.
|
||||
self.addressMapper = proc(
|
||||
listenAddrs: seq[MultiAddress]
|
||||
): Future[seq[MultiAddress]] {.async.} =
|
||||
return expandWildcardAddresses(self.networkInterfaceProvider, listenAddrs)
|
||||
|
||||
debug "Setting up WildcardAddressResolverService"
|
||||
let hasBeenSetup = await procCall Service(self).setup(switch)
|
||||
if hasBeenSetup:
|
||||
switch.peerInfo.addressMappers.add(self.addressMapper)
|
||||
await self.run(switch)
|
||||
return hasBeenSetup
|
||||
|
||||
method run*(self: WildcardAddressResolverService, switch: Switch) {.async, public.} =
|
||||
## Runs the WildcardAddressResolverService for a given switch.
|
||||
##
|
||||
## It updates the peer information for the provided switch by running the registered address mapper. Any other
|
||||
## address mappers that are registered with the switch will also be run.
|
||||
##
|
||||
trace "Running WildcardAddressResolverService"
|
||||
await switch.peerInfo.update()
|
||||
|
||||
method stop*(
|
||||
self: WildcardAddressResolverService, switch: Switch
|
||||
): Future[bool] {.async, public.} =
|
||||
## Stops the WildcardAddressResolverService.
|
||||
##
|
||||
## Handles the shutdown process of the WildcardAddressResolverService for a given switch.
|
||||
## It removes the address mapper from the switch's list of address mappers.
|
||||
## It then updates the peer information for the provided switch. Any wildcard addresses won't be resolved anymore.
|
||||
##
|
||||
## Parameters:
|
||||
## - `self`: The instance of the WildcardAddressResolverService.
|
||||
## - `switch`: The Switch object associated with the service.
|
||||
##
|
||||
## Returns:
|
||||
## - A future that resolves to `true` if the service was successfully stopped, otherwise `false`.
|
||||
debug "Stopping WildcardAddressResolverService"
|
||||
let hasBeenStopped = await procCall Service(self).stop(switch)
|
||||
if hasBeenStopped:
|
||||
switch.peerInfo.addressMappers.keepItIf(it != self.addressMapper)
|
||||
await switch.peerInfo.update()
|
||||
return hasBeenStopped
|
||||
@@ -273,6 +273,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises
|
||||
except CancelledError as exc:
|
||||
trace "releasing semaphore on cancellation"
|
||||
upgrades.release() # always release the slot
|
||||
return
|
||||
except CatchableError as exc:
|
||||
error "Exception in accept loop, exiting", exc = exc.msg
|
||||
upgrades.release() # always release the slot
|
||||
@@ -288,6 +289,12 @@ proc stop*(s: Switch) {.async, public.} =
|
||||
|
||||
s.started = false
|
||||
|
||||
try:
|
||||
# Stop accepting incoming connections
|
||||
await allFutures(s.acceptFuts.mapIt(it.cancelAndWait())).wait(1.seconds)
|
||||
except CatchableError as exc:
|
||||
debug "Cannot cancel accepts", error = exc.msg
|
||||
|
||||
for service in s.services:
|
||||
discard await service.stop(s)
|
||||
|
||||
@@ -302,18 +309,6 @@ proc stop*(s: Switch) {.async, public.} =
|
||||
except CatchableError as exc:
|
||||
warn "error cleaning up transports", msg = exc.msg
|
||||
|
||||
try:
|
||||
await allFutures(s.acceptFuts)
|
||||
.wait(1.seconds)
|
||||
except CatchableError as exc:
|
||||
trace "Exception while stopping accept loops", exc = exc.msg
|
||||
|
||||
# check that all futures were properly
|
||||
# stopped and otherwise cancel them
|
||||
for a in s.acceptFuts:
|
||||
if not a.finished:
|
||||
a.cancel()
|
||||
|
||||
for service in s.services:
|
||||
discard await service.stop(s)
|
||||
|
||||
|
||||
@@ -12,260 +12,327 @@
|
||||
{.push raises: [].}
|
||||
|
||||
import std/[sequtils]
|
||||
import stew/results
|
||||
import chronos, chronicles
|
||||
import transport,
|
||||
../errors,
|
||||
../wire,
|
||||
../multicodec,
|
||||
../connmanager,
|
||||
../multiaddress,
|
||||
../stream/connection,
|
||||
../stream/chronosstream,
|
||||
../upgrademngrs/upgrade,
|
||||
../utility
|
||||
import
|
||||
./transport,
|
||||
../wire,
|
||||
../multiaddress,
|
||||
../stream/connection,
|
||||
../stream/chronosstream,
|
||||
../upgrademngrs/upgrade,
|
||||
../utility
|
||||
|
||||
logScope:
|
||||
topics = "libp2p tcptransport"
|
||||
|
||||
export transport, results
|
||||
export transport, connection, upgrade
|
||||
|
||||
const
|
||||
TcpTransportTrackerName* = "libp2p.tcptransport"
|
||||
const TcpTransportTrackerName* = "libp2p.tcptransport"
|
||||
|
||||
type
|
||||
AcceptFuture = typeof(default(StreamServer).accept())
|
||||
|
||||
TcpTransport* = ref object of Transport
|
||||
servers*: seq[StreamServer]
|
||||
clients: array[Direction, seq[StreamTransport]]
|
||||
flags: set[ServerFlags]
|
||||
clientFlags: set[SocketFlags]
|
||||
acceptFuts: seq[Future[StreamTransport]]
|
||||
acceptFuts: seq[AcceptFuture]
|
||||
connectionsTimeout: Duration
|
||||
stopping: bool
|
||||
|
||||
TcpTransportError* = object of transport.TransportError
|
||||
|
||||
proc connHandler*(self: TcpTransport,
|
||||
client: StreamTransport,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
dir: Direction): Future[Connection] {.async.} =
|
||||
|
||||
trace "Handling tcp connection", address = $observedAddr,
|
||||
dir = $dir,
|
||||
clients = self.clients[Direction.In].len +
|
||||
self.clients[Direction.Out].len
|
||||
proc connHandler*(
|
||||
self: TcpTransport,
|
||||
client: StreamTransport,
|
||||
observedAddr: Opt[MultiAddress],
|
||||
dir: Direction,
|
||||
): Connection =
|
||||
trace "Handling tcp connection",
|
||||
address = $observedAddr,
|
||||
dir = $dir,
|
||||
clients = self.clients[Direction.In].len + self.clients[Direction.Out].len
|
||||
|
||||
let conn = Connection(
|
||||
ChronosStream.init(
|
||||
client = client,
|
||||
dir = dir,
|
||||
observedAddr = observedAddr,
|
||||
timeout = self.connectionsTimeout
|
||||
))
|
||||
timeout = self.connectionsTimeout,
|
||||
)
|
||||
)
|
||||
|
||||
proc onClose() {.async: (raises: []).} =
|
||||
try:
|
||||
block:
|
||||
let
|
||||
fut1 = client.join()
|
||||
fut2 = conn.join()
|
||||
try: # https://github.com/status-im/nim-chronos/issues/516
|
||||
discard await race(fut1, fut2)
|
||||
except ValueError: raiseAssert("Futures list is not empty")
|
||||
# at least one join() completed, cancel pending one, if any
|
||||
if not fut1.finished: await fut1.cancelAndWait()
|
||||
if not fut2.finished: await fut2.cancelAndWait()
|
||||
await noCancel client.join()
|
||||
|
||||
trace "Cleaning up client", addrs = $client.remoteAddress,
|
||||
conn
|
||||
trace "Cleaning up client", addrs = $client.remoteAddress, conn
|
||||
|
||||
self.clients[dir].keepItIf( it != client )
|
||||
self.clients[dir].keepItIf(it != client)
|
||||
|
||||
block:
|
||||
let
|
||||
fut1 = conn.close()
|
||||
fut2 = client.closeWait()
|
||||
await allFutures(fut1, fut2)
|
||||
if fut1.failed:
|
||||
let err = fut1.error()
|
||||
debug "Error cleaning up client", errMsg = err.msg, conn
|
||||
static: doAssert typeof(fut2).E is void # Cannot fail
|
||||
# Propagate the chronos client being closed to the connection
|
||||
# TODO This is somewhat dubious since it's the connection that owns the
|
||||
# client, but it allows the transport to close all connections when
|
||||
# shutting down (also dubious! it would make more sense that the owner
|
||||
# of all connections closes them, or the next read detects the closed
|
||||
# socket and does the right thing..)
|
||||
|
||||
trace "Cleaned up client", addrs = $client.remoteAddress,
|
||||
conn
|
||||
await conn.close()
|
||||
|
||||
except CancelledError as exc:
|
||||
let useExc {.used.} = exc
|
||||
debug "Error cleaning up client", errMsg = exc.msg, conn
|
||||
trace "Cleaned up client", addrs = $client.remoteAddress, conn
|
||||
|
||||
self.clients[dir].add(client)
|
||||
|
||||
asyncSpawn onClose()
|
||||
|
||||
return conn
|
||||
|
||||
proc new*(
|
||||
T: typedesc[TcpTransport],
|
||||
flags: set[ServerFlags] = {},
|
||||
upgrade: Upgrade,
|
||||
connectionsTimeout = 10.minutes): T {.public.} =
|
||||
T: typedesc[TcpTransport],
|
||||
flags: set[ServerFlags] = {},
|
||||
upgrade: Upgrade,
|
||||
connectionsTimeout = 10.minutes,
|
||||
): T {.public.} =
|
||||
T(
|
||||
flags: flags,
|
||||
clientFlags:
|
||||
if ServerFlags.TcpNoDelay in flags:
|
||||
{SocketFlags.TcpNoDelay}
|
||||
else:
|
||||
default(set[SocketFlags])
|
||||
,
|
||||
upgrader: upgrade,
|
||||
networkReachability: NetworkReachability.Unknown,
|
||||
connectionsTimeout: connectionsTimeout,
|
||||
)
|
||||
|
||||
let
|
||||
transport = T(
|
||||
flags: flags,
|
||||
clientFlags:
|
||||
if ServerFlags.TcpNoDelay in flags:
|
||||
compilesOr:
|
||||
{SocketFlags.TcpNoDelay}
|
||||
do:
|
||||
doAssert(false)
|
||||
default(set[SocketFlags])
|
||||
else:
|
||||
default(set[SocketFlags]),
|
||||
upgrader: upgrade,
|
||||
networkReachability: NetworkReachability.Unknown,
|
||||
connectionsTimeout: connectionsTimeout)
|
||||
method start*(self: TcpTransport, addrs: seq[MultiAddress]): Future[void] =
|
||||
## Start transport listening to the given addresses - for dial-only transports,
|
||||
## start with an empty list
|
||||
|
||||
return transport
|
||||
# TODO remove `impl` indirection throughout when `raises` is added to base
|
||||
|
||||
method start*(
|
||||
self: TcpTransport,
|
||||
addrs: seq[MultiAddress]) {.async.} =
|
||||
## listen on the transport
|
||||
##
|
||||
proc impl(
|
||||
self: TcpTransport, addrs: seq[MultiAddress]
|
||||
): Future[void] {.async: (raises: [transport.TransportError, CancelledError]).} =
|
||||
if self.running:
|
||||
warn "TCP transport already running"
|
||||
return
|
||||
|
||||
if self.running:
|
||||
warn "TCP transport already running"
|
||||
return
|
||||
|
||||
await procCall Transport(self).start(addrs)
|
||||
trace "Starting TCP transport"
|
||||
trackCounter(TcpTransportTrackerName)
|
||||
|
||||
for i, ma in addrs:
|
||||
if not self.handles(ma):
|
||||
trace "Invalid address detected, skipping!", address = ma
|
||||
continue
|
||||
trace "Starting TCP transport"
|
||||
|
||||
self.flags.incl(ServerFlags.ReusePort)
|
||||
let server = createStreamServer(
|
||||
ma = ma,
|
||||
flags = self.flags,
|
||||
udata = self)
|
||||
|
||||
# always get the resolved address in case we're bound to 0.0.0.0:0
|
||||
self.addrs[i] = MultiAddress.init(
|
||||
server.sock.getLocalAddress()
|
||||
).tryGet()
|
||||
var supported: seq[MultiAddress]
|
||||
var initialized = false
|
||||
try:
|
||||
for i, ma in addrs:
|
||||
if not self.handles(ma):
|
||||
trace "Invalid address detected, skipping!", address = ma
|
||||
continue
|
||||
|
||||
self.servers &= server
|
||||
let
|
||||
ta = initTAddress(ma).expect("valid address per handles check above")
|
||||
server =
|
||||
try:
|
||||
createStreamServer(ta, flags = self.flags)
|
||||
except common.TransportError as exc:
|
||||
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
|
||||
|
||||
trace "Listening on", address = ma
|
||||
self.servers &= server
|
||||
|
||||
method stop*(self: TcpTransport) {.async.} =
|
||||
## stop the transport
|
||||
##
|
||||
try:
|
||||
trace "Listening on", address = ma
|
||||
supported.add(
|
||||
MultiAddress.init(server.sock.getLocalAddress()).expect(
|
||||
"Can init from local address"
|
||||
)
|
||||
)
|
||||
|
||||
initialized = true
|
||||
finally:
|
||||
if not initialized:
|
||||
# Clean up partial success on exception
|
||||
await noCancel allFutures(self.servers.mapIt(it.closeWait()))
|
||||
reset(self.servers)
|
||||
|
||||
try:
|
||||
await procCall Transport(self).start(supported)
|
||||
except CatchableError:
|
||||
raiseAssert "Base method does not raise"
|
||||
|
||||
trackCounter(TcpTransportTrackerName)
|
||||
|
||||
impl(self, addrs)
|
||||
|
||||
method stop*(self: TcpTransport): Future[void] =
|
||||
## Stop the transport and close all connections it created
|
||||
proc impl(self: TcpTransport) {.async: (raises: []).} =
|
||||
trace "Stopping TCP transport"
|
||||
self.stopping = true
|
||||
defer:
|
||||
self.stopping = false
|
||||
|
||||
checkFutures(
|
||||
await allFinished(
|
||||
self.clients[Direction.In].mapIt(it.closeWait()) &
|
||||
self.clients[Direction.Out].mapIt(it.closeWait())))
|
||||
if self.running:
|
||||
# Reset the running flag
|
||||
try:
|
||||
await noCancel procCall Transport(self).stop()
|
||||
except CatchableError: # TODO remove when `accept` is annotated with raises
|
||||
raiseAssert "doesn't actually raise"
|
||||
|
||||
if not self.running:
|
||||
# Stop each server by closing the socket - this will cause all accept loops
|
||||
# to fail - since the running flag has been reset, it's also safe to close
|
||||
# all known clients since no more of them will be added
|
||||
await noCancel allFutures(
|
||||
self.servers.mapIt(it.closeWait()) &
|
||||
self.clients[Direction.In].mapIt(it.closeWait()) &
|
||||
self.clients[Direction.Out].mapIt(it.closeWait())
|
||||
)
|
||||
|
||||
self.servers = @[]
|
||||
|
||||
for acceptFut in self.acceptFuts:
|
||||
if acceptFut.completed():
|
||||
await acceptFut.value().closeWait()
|
||||
self.acceptFuts = @[]
|
||||
|
||||
if self.clients[Direction.In].len != 0 or self.clients[Direction.Out].len != 0:
|
||||
# Future updates could consider turning this warn into an assert since
|
||||
# it should never happen if the shutdown code is correct
|
||||
warn "Couldn't clean up clients",
|
||||
len = self.clients[Direction.In].len + self.clients[Direction.Out].len
|
||||
|
||||
trace "Transport stopped"
|
||||
untrackCounter(TcpTransportTrackerName)
|
||||
else:
|
||||
# For legacy reasons, `stop` on a transport that wasn't started is
|
||||
# expected to close outgoing connections created by the transport
|
||||
warn "TCP transport already stopped"
|
||||
return
|
||||
|
||||
await procCall Transport(self).stop() # call base
|
||||
var toWait: seq[Future[void]]
|
||||
for fut in self.acceptFuts:
|
||||
if not fut.finished:
|
||||
toWait.add(fut.cancelAndWait())
|
||||
elif fut.done:
|
||||
toWait.add(fut.read().closeWait())
|
||||
doAssert self.clients[Direction.In].len == 0,
|
||||
"No incoming connections possible without start"
|
||||
await noCancel allFutures(self.clients[Direction.Out].mapIt(it.closeWait()))
|
||||
|
||||
for server in self.servers:
|
||||
server.stop()
|
||||
toWait.add(server.closeWait())
|
||||
impl(self)
|
||||
|
||||
await allFutures(toWait)
|
||||
|
||||
self.servers = @[]
|
||||
|
||||
trace "Transport stopped"
|
||||
untrackCounter(TcpTransportTrackerName)
|
||||
except CatchableError as exc:
|
||||
trace "Error shutting down tcp transport", exc = exc.msg
|
||||
|
||||
method accept*(self: TcpTransport): Future[Connection] {.async.} =
|
||||
## accept a new TCP connection
|
||||
method accept*(self: TcpTransport): Future[Connection] =
|
||||
## accept a new TCP connection, returning nil on non-fatal errors
|
||||
##
|
||||
|
||||
if not self.running:
|
||||
raise newTransportClosedError()
|
||||
|
||||
try:
|
||||
if self.acceptFuts.len <= 0:
|
||||
self.acceptFuts = self.servers.mapIt(Future[StreamTransport](it.accept()))
|
||||
## Raises an exception when the transport is broken and cannot be used for
|
||||
## accepting further connections
|
||||
# TODO returning nil for non-fatal errors is problematic in that error
|
||||
# information is lost and must be logged here instead of being
|
||||
# available to the caller - further refactoring should propagate errors
|
||||
# to the caller instead
|
||||
proc impl(
|
||||
self: TcpTransport
|
||||
): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
|
||||
if not self.running:
|
||||
raise newTransportClosedError()
|
||||
|
||||
if self.acceptFuts.len <= 0:
|
||||
return
|
||||
self.acceptFuts = self.servers.mapIt(it.accept())
|
||||
|
||||
let
|
||||
finished = await one(self.acceptFuts)
|
||||
finished =
|
||||
try:
|
||||
await one(self.acceptFuts)
|
||||
except ValueError:
|
||||
raise (ref TcpTransportError)(msg: "No listeners configured")
|
||||
|
||||
index = self.acceptFuts.find(finished)
|
||||
transp =
|
||||
try:
|
||||
await finished
|
||||
except TransportTooManyError as exc:
|
||||
debug "Too many files opened", exc = exc.msg
|
||||
return nil
|
||||
except TransportAbortedError as exc:
|
||||
debug "Connection aborted", exc = exc.msg
|
||||
return nil
|
||||
except TransportUseClosedError as exc:
|
||||
raise newTransportClosedError(exc)
|
||||
except TransportOsError as exc:
|
||||
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
|
||||
except common.TransportError as exc: # Needed for chronos 4.0.0 support
|
||||
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
|
||||
if not self.running: # Stopped while waiting
|
||||
await transp.closeWait()
|
||||
raise newTransportClosedError()
|
||||
|
||||
self.acceptFuts[index] = self.servers[index].accept()
|
||||
|
||||
let transp = await finished
|
||||
try:
|
||||
let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
|
||||
return await self.connHandler(transp, Opt.some(observedAddr), Direction.In)
|
||||
except CancelledError as exc:
|
||||
transp.close()
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
debug "Failed to handle connection", exc = exc.msg
|
||||
transp.close()
|
||||
except TransportTooManyError as exc:
|
||||
debug "Too many files opened", exc = exc.msg
|
||||
except TransportAbortedError as exc:
|
||||
debug "Connection aborted", exc = exc.msg
|
||||
except TransportUseClosedError as exc:
|
||||
debug "Server was closed", exc = exc.msg
|
||||
raise newTransportClosedError(exc)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except TransportOsError as exc:
|
||||
info "OS Error", exc = exc.msg
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
info "Unexpected error accepting connection", exc = exc.msg
|
||||
raise exc
|
||||
let remote =
|
||||
try:
|
||||
transp.remoteAddress
|
||||
except TransportOsError as exc:
|
||||
# The connection had errors / was closed before `await` returned control
|
||||
await transp.closeWait()
|
||||
debug "Cannot read remote address", exc = exc.msg
|
||||
return nil
|
||||
|
||||
let observedAddr =
|
||||
MultiAddress.init(remote).expect("Can initialize from remote address")
|
||||
self.connHandler(transp, Opt.some(observedAddr), Direction.In)
|
||||
|
||||
impl(self)
|
||||
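Since non-fatal errors now surface as a nil Connection rather than an exception, callers have to tolerate nil; a hedged sketch of what an accept loop built on this method amounts to (`handleConn` is a hypothetical handler, not part of the diff):

# Sketch, not from the repository.
while transport.running:
  let conn = await transport.accept()
  if conn == nil:
    continue # transient failure (too many open files, aborted connection, ...)
  asyncSpawn handleConn(conn) # `handleConn` is a placeholder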
|
||||
method dial*(
|
||||
self: TcpTransport,
|
||||
hostname: string,
|
||||
address: MultiAddress,
|
||||
peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.async.} =
|
||||
self: TcpTransport,
|
||||
hostname: string,
|
||||
address: MultiAddress,
|
||||
peerId: Opt[PeerId] = Opt.none(PeerId),
|
||||
): Future[Connection] =
|
||||
## dial a peer
|
||||
##
|
||||
proc impl(
|
||||
self: TcpTransport, hostname: string, address: MultiAddress, peerId: Opt[PeerId]
|
||||
): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
|
||||
if self.stopping:
|
||||
raise newTransportClosedError()
|
||||
|
||||
trace "Dialing remote peer", address = $address
|
||||
let transp =
|
||||
if self.networkReachability == NetworkReachability.NotReachable and self.addrs.len > 0:
|
||||
self.clientFlags.incl(SocketFlags.ReusePort)
|
||||
await connect(address, flags = self.clientFlags, localAddress = Opt.some(self.addrs[0]))
|
||||
else:
|
||||
await connect(address, flags = self.clientFlags)
|
||||
let ta = initTAddress(address).valueOr:
|
||||
raise (ref TcpTransportError)(msg: "Unsupported address: " & $address)
|
||||
|
||||
try:
|
||||
let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
|
||||
return await self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
|
||||
except CatchableError as err:
|
||||
await transp.closeWait()
|
||||
raise err
|
||||
trace "Dialing remote peer", address = $address
|
||||
let transp =
|
||||
try:
|
||||
await(
|
||||
if self.networkReachability == NetworkReachability.NotReachable and
|
||||
self.addrs.len > 0:
|
||||
let local = initTAddress(self.addrs[0]).expect("self address is valid")
|
||||
self.clientFlags.incl(SocketFlags.ReusePort)
|
||||
connect(ta, flags = self.clientFlags, localAddress = local)
|
||||
else:
|
||||
connect(ta, flags = self.clientFlags)
|
||||
)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
|
||||
|
||||
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
|
||||
# If `stop` is called after `connect` but before `await` returns, we might
|
||||
# end up with a race condition where `stop` returns but not all connections
|
||||
# have been closed - we drop connections in this case in order not to leak
|
||||
# them
|
||||
if self.stopping:
|
||||
# Stopped while waiting for new connection
|
||||
await transp.closeWait()
|
||||
raise newTransportClosedError()
|
||||
|
||||
let observedAddr =
|
||||
try:
|
||||
MultiAddress.init(transp.remoteAddress).expect("remote address is valid")
|
||||
except TransportOsError as exc:
|
||||
await transp.closeWait()
|
||||
raise (ref TcpTransportError)(msg: exc.msg)
|
||||
|
||||
self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
|
||||
|
||||
impl(self, hostname, address, peerId)
|
||||
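A short, hedged dialling example against the refactored signature; the target address is made up and the peer id is left unset:

# Sketch, not from the repository.
let conn = await transport.dial(
  "", MultiAddress.init("/ip4/192.0.2.1/tcp/4001").tryGet()
)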
|
||||
method handles*(t: TcpTransport, address: MultiAddress): bool =
|
||||
if procCall Transport(t).handles(address):
|
||||
if address.protocols.isOk:
|
||||
return TCP.match(address)
|
||||
return TCP.match(address)
|
||||
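A brief, hedged construction-and-capability check for the refactored transport; `upgrade` is assumed to be an existing Upgrade instance:

# Sketch, not from the repository.
let transport = TcpTransport.new(
  flags = {ServerFlags.TcpNoDelay}, # mapped onto SocketFlags.TcpNoDelay for dials
  upgrade = upgrade
)
doAssert transport.handles(MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet())
doAssert not transport.handles(MultiAddress.init("/ip4/127.0.0.1/udp/0").tryGet())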
@@ -200,7 +200,7 @@ method dial*(
|
||||
|
||||
try:
|
||||
await dialPeer(transp, address)
|
||||
return await self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
|
||||
return self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
|
||||
except CatchableError as err:
|
||||
await transp.closeWait()
|
||||
raise err
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
@@ -35,7 +35,7 @@ type
|
||||
upgrader*: Upgrade
|
||||
networkReachability*: NetworkReachability
|
||||
|
||||
proc newTransportClosedError*(parent: ref Exception = nil): ref LPError =
|
||||
proc newTransportClosedError*(parent: ref Exception = nil): ref TransportError =
|
||||
newException(TransportClosedError,
|
||||
"Transport closed, no more connections!", parent)
|
||||
|
||||
@@ -81,25 +81,25 @@ proc dial*(
|
||||
self.dial("", address)
|
||||
|
||||
method upgrade*(
|
||||
self: Transport,
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} =
|
||||
self: Transport,
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]
|
||||
): Future[Muxer] {.base, async: (raises: [
|
||||
CancelledError, LPError], raw: true).} =
|
||||
## base upgrade method that the transport uses to perform
|
||||
## transport specific upgrades
|
||||
##
|
||||
|
||||
self.upgrader.upgrade(conn, peerId)
|
||||
|
||||
method handles*(
|
||||
self: Transport,
|
||||
address: MultiAddress): bool {.base, gcsafe.} =
|
||||
self: Transport,
|
||||
address: MultiAddress): bool {.base, gcsafe.} =
|
||||
## check if transport supports the multiaddress
|
||||
##
|
||||
|
||||
# by default we skip circuit addresses to avoid
|
||||
# having to repeat the check in every transport
|
||||
let protocols = address.protocols.valueOr: return false
|
||||
return protocols
|
||||
protocols
|
||||
.filterIt(
|
||||
it == multiCodec("p2p-circuit")
|
||||
).len == 0
|
||||
|
||||
@@ -25,53 +25,61 @@ type
|
||||
muxers*: seq[MuxerProvider]
|
||||
streamHandler*: StreamHandler
|
||||
|
||||
proc getMuxerByCodec(self: MuxedUpgrade, muxerName: string): MuxerProvider =
|
||||
func getMuxerByCodec(
|
||||
self: MuxedUpgrade, muxerName: string): Opt[MuxerProvider] =
|
||||
if muxerName.len == 0 or muxerName == "na":
|
||||
return Opt.none(MuxerProvider)
|
||||
for m in self.muxers:
|
||||
if muxerName == m.codec:
|
||||
return m
|
||||
return Opt.some(m)
|
||||
Opt.none(MuxerProvider)
|
||||
|
||||
proc mux*(
|
||||
proc mux(
|
||||
self: MuxedUpgrade,
|
||||
conn: Connection): Future[Muxer] {.async.} =
|
||||
conn: Connection
|
||||
): Future[Opt[Muxer]] {.async: (raises: [
|
||||
CancelledError, LPStreamError, MultiStreamError]).} =
|
||||
## mux connection
|
||||
|
||||
trace "Muxing connection", conn
|
||||
if self.muxers.len == 0:
|
||||
warn "no muxers registered, skipping upgrade flow", conn
|
||||
return
|
||||
return Opt.none(Muxer)
|
||||
|
||||
let muxerName =
|
||||
if conn.dir == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
|
||||
else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec))
|
||||
|
||||
if muxerName.len == 0 or muxerName == "na":
|
||||
debug "no muxer available, early exit", conn
|
||||
return
|
||||
let
|
||||
muxerName =
|
||||
case conn.dir
|
||||
of Direction.Out:
|
||||
await self.ms.select(conn, self.muxers.mapIt(it.codec))
|
||||
of Direction.In:
|
||||
await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec))
|
||||
muxerProvider = self.getMuxerByCodec(muxerName).valueOr:
|
||||
debug "no muxer available, early exit", conn, muxerName
|
||||
return Opt.none(Muxer)
|
||||
|
||||
trace "Found a muxer", conn, muxerName
|
||||
|
||||
# create new muxer for connection
|
||||
let muxer = self.getMuxerByCodec(muxerName).newMuxer(conn)
|
||||
let muxer = muxerProvider.newMuxer(conn)
|
||||
|
||||
# install stream handler
|
||||
muxer.streamHandler = self.streamHandler
|
||||
muxer.handler = muxer.handle()
|
||||
return muxer
|
||||
Opt.some(muxer)
|
||||
|
||||
method upgrade*(
|
||||
self: MuxedUpgrade,
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Muxer] {.async.} =
|
||||
peerId: Opt[PeerId]
|
||||
): Future[Muxer] {.async: (raises: [CancelledError, LPError]).} =
|
||||
trace "Upgrading connection", conn, direction = conn.dir
|
||||
|
||||
let sconn = await self.secure(conn, peerId) # secure the connection
|
||||
let sconn = await self.secure(conn, peerId) # secure the connection
|
||||
if sconn == nil:
|
||||
raise newException(UpgradeFailedError,
|
||||
raise (ref UpgradeFailedError)(msg:
|
||||
"unable to secure connection, stopping upgrade")
|
||||
|
||||
let muxer = await self.mux(sconn) # mux it if possible
|
||||
if muxer == nil:
|
||||
raise newException(UpgradeFailedError,
|
||||
let muxer = (await self.mux(sconn)).valueOr: # mux it if possible
|
||||
raise (ref UpgradeFailedError)(msg:
|
||||
"a muxer is required for outgoing connections")
|
||||
|
||||
when defined(libp2p_agents_metrics):
|
||||
@@ -79,11 +87,11 @@ method upgrade*(
|
||||
|
||||
if sconn.closed():
|
||||
await sconn.close()
|
||||
raise newException(UpgradeFailedError,
|
||||
raise (ref UpgradeFailedError)(msg:
|
||||
"Connection closed or missing peer info, stopping upgrade")
|
||||
|
||||
trace "Upgraded connection", conn, sconn, direction = conn.dir
|
||||
return muxer
|
||||
muxer
|
||||
|
||||
proc new*(
|
||||
T: type MuxedUpgrade,
|
||||
@@ -101,8 +109,6 @@ proc new*(
|
||||
await upgrader.ms.handle(conn) # handle incoming connection
|
||||
except CancelledError as exc:
|
||||
return
|
||||
except CatchableError as exc:
|
||||
trace "exception in stream handler", conn, msg = exc.msg
|
||||
finally:
|
||||
await conn.closeWithEOF()
|
||||
trace "Stream handler done", conn
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Nim-LibP2P
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Copyright (c) 2023-2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
@@ -24,8 +24,10 @@ import ../stream/connection,
|
||||
|
||||
export connmanager, connection, identify, secure, multistream
|
||||
|
||||
declarePublicCounter(libp2p_failed_upgrades_incoming, "incoming connections failed upgrades")
|
||||
declarePublicCounter(libp2p_failed_upgrades_outgoing, "outgoing connections failed upgrades")
|
||||
declarePublicCounter(libp2p_failed_upgrades_incoming,
|
||||
"incoming connections failed upgrades")
|
||||
declarePublicCounter(libp2p_failed_upgrades_outgoing,
|
||||
"outgoing connections failed upgrades")
|
||||
|
||||
logScope:
|
||||
topics = "libp2p upgrade"
|
||||
@@ -38,23 +40,28 @@ type
|
||||
secureManagers*: seq[Secure]
|
||||
|
||||
method upgrade*(
|
||||
self: Upgrade,
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Muxer] {.base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
self: Upgrade,
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]
|
||||
): Future[Muxer] {.async: (raises: [
|
||||
CancelledError, LPError], raw: true), base.} =
|
||||
raiseAssert("Not implemented!")
|
||||
|
||||
proc secure*(
|
||||
self: Upgrade,
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Connection] {.async.} =
|
||||
self: Upgrade,
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]
|
||||
): Future[Connection] {.async: (raises: [CancelledError, LPError]).} =
|
||||
if self.secureManagers.len <= 0:
|
||||
raise newException(UpgradeFailedError, "No secure managers registered!")
|
||||
raise (ref UpgradeFailedError)(msg: "No secure managers registered!")
|
||||
|
||||
let codec =
|
||||
if conn.dir == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
|
||||
else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec))
|
||||
if conn.dir == Out:
|
||||
await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
|
||||
else:
|
||||
await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec))
|
||||
if codec.len == 0:
|
||||
raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!")
|
||||
raise (ref UpgradeFailedError)(msg: "Unable to negotiate a secure channel!")
|
||||
|
||||
trace "Securing connection", conn, codec
|
||||
let secureProtocol = self.secureManagers.filterIt(it.codec == codec)
|
||||
@@ -63,4 +70,4 @@ proc secure*(
|
||||
# let's avoid duplicating checks but detect if it fails to do it properly
|
||||
doAssert(secureProtocol.len > 0)
|
||||
|
||||
return await secureProtocol[0].secure(conn, peerId)
|
||||
await secureProtocol[0].secure(conn, peerId)
|
||||
|
||||
@@ -112,24 +112,30 @@ template withValue*[T](self: Opt[T] | Option[T], value, body: untyped): untyped
|
||||
let value {.inject.} = temp.get()
|
||||
body
|
||||
|
||||
macro withValue*[T](self: Opt[T] | Option[T], value, body, body2: untyped): untyped =
|
||||
let elseBody = body2[0]
|
||||
template withValue*[T, E](self: Result[T, E], value, body: untyped): untyped =
|
||||
self.toOpt().withValue(value, body)
|
||||
|
||||
macro withValue*[T](self: Opt[T] | Option[T], value, body, elseStmt: untyped): untyped =
|
||||
let elseBody = elseStmt[0]
|
||||
quote do:
|
||||
if `self`.isSome:
|
||||
let `value` {.inject.} = `self`.get()
|
||||
let temp = (`self`)
|
||||
if temp.isSome:
|
||||
let `value` {.inject.} = temp.get()
|
||||
`body`
|
||||
else:
|
||||
`elseBody`
|
||||
|
||||
template valueOr*[T](self: Option[T], body: untyped): untyped =
|
||||
if self.isSome:
|
||||
self.get()
|
||||
let temp = (self)
|
||||
if temp.isSome:
|
||||
temp.get()
|
||||
else:
|
||||
body
|
||||
|
||||
template toOpt*[T, E](self: Result[T, E]): Opt[T] =
|
||||
if self.isOk:
|
||||
let temp = (self)
|
||||
if temp.isOk:
|
||||
when T is void: Result[void, void].ok()
|
||||
else: Opt.some(self.unsafeGet())
|
||||
else: Opt.some(temp.unsafeGet())
|
||||
else:
|
||||
Opt.none(type(T))
|
||||
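To show how the new Result overload threads through toOpt, a small hedged example of the call-site shape that the wildcard resolver above already relies on:

# Sketch, not from the repository.
let r: Result[int, string] = ok(5)
r.withValue(v):
  doAssert v == 5 # body runs only when the Result holds a value
doAssert r.toOpt() == Opt.some(5)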
|
||||
@@ -13,6 +13,8 @@
|
||||
import chronos, stew/endians2
|
||||
import multiaddress, multicodec, errors, utility
|
||||
|
||||
export multiaddress, chronos
|
||||
|
||||
when defined(windows):
|
||||
import winlean
|
||||
else:
|
||||
@@ -30,7 +32,6 @@ const
|
||||
UDP,
|
||||
)
|
||||
|
||||
|
||||
proc initTAddress*(ma: MultiAddress): MaResult[TransportAddress] =
|
||||
## Initialize ``TransportAddress`` with MultiAddress ``ma``.
|
||||
##
|
||||
@@ -76,7 +77,7 @@ proc connect*(
|
||||
child: StreamTransport = nil,
|
||||
flags = default(set[SocketFlags]),
|
||||
localAddress: Opt[MultiAddress] = Opt.none(MultiAddress)): Future[StreamTransport]
|
||||
{.raises: [LPError, MaInvalidAddress].} =
|
||||
{.async.} =
|
||||
## Open new connection to remote peer with address ``ma`` and create
|
||||
## new transport object ``StreamTransport`` for established connection.
|
||||
## ``bufferSize`` is size of internal buffer for transport.
|
||||
@@ -88,12 +89,12 @@ proc connect*(
|
||||
let transportAddress = initTAddress(ma).tryGet()
|
||||
|
||||
compilesOr:
|
||||
return connect(transportAddress, bufferSize, child,
|
||||
return await connect(transportAddress, bufferSize, child,
|
||||
if localAddress.isSome(): initTAddress(localAddress.expect("just checked")).tryGet() else: TransportAddress(),
|
||||
flags)
|
||||
do:
|
||||
# support for older chronos versions
|
||||
return connect(transportAddress, bufferSize, child)
|
||||
return await connect(transportAddress, bufferSize, child)
|
||||
|
||||
proc createStreamServer*[T](ma: MultiAddress,
|
||||
cbproc: StreamCallback,
|
||||
|
||||
46
tests/di/testdi.nim
Normal file
46
tests/di/testdi.nim
Normal file
@@ -0,0 +1,46 @@
|
||||
{.used.}
|
||||
|
||||
# Nim-Libp2p
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import ../helpers
|
||||
import ../../di/di
|
||||
|
||||
type
|
||||
MyInterface = ref object of RootObj
|
||||
MyImplementation = ref object of MyInterface
|
||||
AnotherImplementation = ref object of MyInterface
|
||||
|
||||
MyObject = object
|
||||
|
||||
method doSomething(obj: MyInterface) {.base.} = discard
|
||||
|
||||
method doSomething(obj: MyImplementation) =
|
||||
echo "MyImplementation doing something!"
|
||||
|
||||
method doSomething(obj: AnotherImplementation) =
|
||||
echo "AnotherImplementation doing something!"
|
||||
|
||||
proc provideMyImplementation(): MyInterface =
|
||||
MyImplementation()
|
||||
|
||||
proc provideAnotherImplementation(): MyInterface =
|
||||
AnotherImplementation()
|
||||
|
||||
suite "DI":
|
||||
|
||||
asyncTest "DI":
|
||||
let container = Container()
|
||||
register[MyInterface](container, provideMyImplementation, "myImplementation")
|
||||
register[MyInterface](container, provideAnotherImplementation, "anotherImplementation")
|
||||
|
||||
let myImplementation = resolve[MyInterface](container, "anotherImplementation")
|
||||
myImplementation.doSomething()
|
||||
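Resolving a qualifier that was never registered is the failure path; a hedged sketch, assuming the container raises BindingNotFoundError for unknown keys as its API suggests:

# Sketch, not from the repository.
try:
  discard resolve[MyInterface](container, "missing")
except BindingNotFoundError as exc:
  echo "not bound: ", exc.msg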
@@ -24,6 +24,6 @@ proc allFuturesThrowing*(args: varargs[FutureBase]): Future[void] =
|
||||
proc allFuturesThrowing*[T](futs: varargs[Future[T]]): Future[void] =
|
||||
allFuturesThrowing(futs.mapIt(FutureBase(it)))
|
||||
|
||||
proc allFuturesThrowing*[T, E](
|
||||
proc allFuturesThrowing*[T, E]( # https://github.com/nim-lang/Nim/issues/23432
|
||||
futs: varargs[InternalRaisesFuture[T, E]]): Future[void] =
|
||||
allFuturesThrowing(futs.mapIt(FutureBase(it)))
|
||||
|
||||
@@ -525,6 +525,17 @@ suite "GossipSub internal":
|
||||
await conn.close()
|
||||
await gossipSub.switch.stop()
|
||||
|
||||
asyncTest "invalid message bytes":
|
||||
let gossipSub = TestGossipSub.init(newStandardSwitch())
|
||||
|
||||
let peerId = randomPeerId()
|
||||
let peer = gossipSub.getPubSubPeer(peerId)
|
||||
|
||||
expect(CatchableError):
|
||||
await gossipSub.rpcHandler(peer, @[byte 1, 2, 3])
|
||||
|
||||
await gossipSub.switch.stop()
|
||||
|
||||
asyncTest "rebalanceMesh fail due to backoff":
|
||||
let gossipSub = TestGossipSub.init(newStandardSwitch())
|
||||
let topic = "foobar"
|
||||
@@ -681,7 +692,7 @@ suite "GossipSub internal":
|
||||
)
|
||||
peer.iHaveBudget = 0
|
||||
let iwants = gossipSub.handleIHave(peer, @[msg])
|
||||
check: iwants.messageIds.len == 0
|
||||
check: iwants.messageIDs.len == 0
|
||||
|
||||
block:
|
||||
# given duplicate ihave should generate only one iwant
|
||||
@@ -696,7 +707,7 @@ suite "GossipSub internal":
|
||||
messageIDs: @[id, id, id]
|
||||
)
|
||||
let iwants = gossipSub.handleIHave(peer, @[msg])
|
||||
check: iwants.messageIds.len == 1
|
||||
check: iwants.messageIDs.len == 1
|
||||
|
||||
block:
|
||||
# given duplicate iwant should generate only one message
|
||||
@@ -779,7 +790,7 @@ suite "GossipSub internal":
|
||||
let (iwantMessageIds, sentMessages) = createMessages(gossip0, gossip1, messageSize, messageSize)
|
||||
|
||||
gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
|
||||
ihave: @[ControlIHave(topicId: "foobar", messageIds: iwantMessageIds)]
|
||||
ihave: @[ControlIHave(topicID: "foobar", messageIDs: iwantMessageIds)]
|
||||
))), isHighPriority = false)
|
||||
|
||||
checkUntilTimeout: receivedMessages[] == sentMessages
|
||||
@@ -796,7 +807,7 @@ suite "GossipSub internal":
|
||||
let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, messageSize, messageSize)
|
||||
|
||||
gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
|
||||
ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)]
|
||||
ihave: @[ControlIHave(topicID: "foobar", messageIDs: bigIWantMessageIds)]
|
||||
))), isHighPriority = false)
|
||||
|
||||
await sleepAsync(300.milliseconds)
|
||||
@@ -813,7 +824,7 @@ suite "GossipSub internal":
|
||||
let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, size1, size2)
|
||||
|
||||
gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
|
||||
ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)]
|
||||
ihave: @[ControlIHave(topicID: "foobar", messageIDs: bigIWantMessageIds)]
|
||||
))), isHighPriority = false)
|
||||
|
||||
checkUntilTimeout: receivedMessages[] == sentMessages
|
||||
@@ -831,7 +842,7 @@ suite "GossipSub internal":
|
||||
let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, size1, size2)
|
||||
|
||||
gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage(
|
||||
ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)]
|
||||
ihave: @[ControlIHave(topicID: "foobar", messageIDs: bigIWantMessageIds)]
|
||||
))), isHighPriority = false)
|
||||
|
||||
var smallestSet: HashSet[seq[byte]]
|
||||
|
||||
@@ -569,8 +569,8 @@ suite "GossipSub":
|
||||
proc slowValidator(topic: string, message: Message): Future[ValidationResult] {.async.} =
|
||||
await cRelayed
|
||||
# Empty A & C caches to detect duplicates
|
||||
gossip1.seen = TimedCache[MessageId].init()
|
||||
gossip3.seen = TimedCache[MessageId].init()
|
||||
gossip1.seen = TimedCache[SaltedId].init()
|
||||
gossip3.seen = TimedCache[SaltedId].init()
|
||||
let msgId = toSeq(gossip2.validationSeen.keys)[0]
|
||||
checkUntilTimeout(try: gossip2.validationSeen[msgId].len > 0 except: false)
|
||||
result = ValidationResult.Accept
|
||||
@@ -911,7 +911,7 @@ suite "GossipSub":
|
||||
check: gossip3.mesh.peers("foobar") == 1
|
||||
|
||||
gossip3.broadcast(gossip3.mesh["foobar"], RPCMsg(control: some(ControlMessage(
|
||||
idontwant: @[ControlIWant(messageIds: @[newSeq[byte](10)])]
|
||||
idontwant: @[ControlIWant(messageIDs: @[newSeq[byte](10)])]
|
||||
))), isHighPriority = true)
|
||||
checkUntilTimeout: gossip2.mesh.getOrDefault("foobar").anyIt(it.heDontWants[^1].len == 1)
|
||||
|
||||
@@ -970,8 +970,9 @@ suite "GossipSub":
|
||||
|
||||
gossip0.broadcast(
|
||||
gossip0.mesh["foobar"],
|
||||
RPCMsg(messages: @[Message(topicIDs: @["foobar"], data: newSeq[byte](10))]),
|
||||
isHighPriority = true)
|
||||
RPCMsg(messages: @[Message(topic: "foobar", data: newSeq[byte](10))]),
|
||||
isHighPriority = true,
|
||||
)
|
||||
await sleepAsync(300.millis)
|
||||
|
||||
check currentRateLimitHits() == rateLimitHits
|
||||
@@ -981,8 +982,9 @@ suite "GossipSub":
|
||||
gossip1.parameters.disconnectPeerAboveRateLimit = true
|
||||
gossip0.broadcast(
|
||||
gossip0.mesh["foobar"],
|
||||
RPCMsg(messages: @[Message(topicIDs: @["foobar"], data: newSeq[byte](12))]),
|
||||
isHighPriority = true)
|
||||
RPCMsg(messages: @[Message(topic: "foobar", data: newSeq[byte](12))]),
|
||||
isHighPriority = true,
|
||||
)
|
||||
await sleepAsync(300.millis)
|
||||
|
||||
check gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == true
|
||||
@@ -1053,7 +1055,7 @@ suite "GossipSub":
|
||||
gossip0.addValidator(topic, execValidator)
|
||||
gossip1.addValidator(topic, execValidator)
|
||||
|
||||
let msg = RPCMsg(messages: @[Message(topicIDs: @[topic], data: newSeq[byte](40))])
|
||||
let msg = RPCMsg(messages: @[Message(topic: topic, data: newSeq[byte](40))])
|
||||
|
||||
gossip0.broadcast(gossip0.mesh[topic], msg, isHighPriority = true)
|
||||
await sleepAsync(300.millis)
|
||||
@@ -1065,8 +1067,9 @@ suite "GossipSub":
|
||||
gossip1.parameters.disconnectPeerAboveRateLimit = true
|
||||
gossip0.broadcast(
|
||||
gossip0.mesh[topic],
|
||||
RPCMsg(messages: @[Message(topicIDs: @[topic], data: newSeq[byte](35))]),
|
||||
isHighPriority = true)
|
||||
RPCMsg(messages: @[Message(topic: topic, data: newSeq[byte](35))]),
|
||||
isHighPriority = true,
|
||||
)
|
||||
|
||||
checkUntilTimeout gossip1.switch.isConnected(gossip0.switch.peerInfo.peerId) == false
|
||||
check currentRateLimitHits() == rateLimitHits + 2
|
||||
|
||||
@@ -1,12 +1,11 @@
{.used.}

import unittest2, options, sets, sequtils
import unittest2, sequtils
import stew/byteutils
import ../../libp2p/[peerid,
crypto/crypto,
protocols/pubsub/mcache,
protocols/pubsub/rpc/messages]
import ./utils
protocols/pubsub/rpc/message]

var rng = newRng()

@@ -27,48 +26,48 @@ suite "MCache":
var mCache = MCache.init(3, 5)

for i in 0..<3:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["foo"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "foo")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

for i in 0..<5:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["bar"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "bar")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

var mids = mCache.window("foo")
check mids.len == 3

var id = toSeq(mids)[0]
check mCache.get(id).get().topicIds[0] == "foo"
check mCache.get(id).get().topic == "foo"

test "shift - shift 1 window at a time":
var mCache = MCache.init(1, 5)

for i in 0..<3:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["foo"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "foo")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

mCache.shift()
check mCache.window("foo").len == 0

for i in 0..<3:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["bar"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "bar")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

mCache.shift()
check mCache.window("bar").len == 0

for i in 0..<3:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["baz"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "baz")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

mCache.shift()
@@ -78,21 +77,21 @@ suite "MCache":
var mCache = MCache.init(1, 5)

for i in 0..<3:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["foo"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "foo")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

for i in 0..<3:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["bar"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "bar")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

for i in 0..<3:
var msg = Message(fromPeer: randomPeerId(),
seqno: "12345".toBytes(),
topicIDs: @["baz"])
var
msg =
Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topic: "baz")
mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenSuccess), msg)

mCache.shift()

@@ -75,14 +75,17 @@ suite "Message":
msgIdResult.error == ValidationResult.Reject

test "byteSize for RPCMsg":
var msg = Message(
fromPeer: PeerId(data: @['a'.byte, 'b'.byte]), # 2 bytes
data: @[1'u8, 2, 3], # 3 bytes
seqno: @[4'u8, 5], # 2 bytes
signature: @['c'.byte, 'd'.byte], # 2 bytes
key: @[6'u8, 7], # 2 bytes
topicIds: @["abc", "defgh"] # 3 + 5 = 8 bytes
)
var
msg =
Message(
fromPeer: PeerId(data: @['a'.byte, 'b'.byte]), # 2 bytes
data: @[1'u8, 2, 3], # 3 bytes
seqno: @[4'u8, 5], # 2 bytes
signature: @['c'.byte, 'd'.byte], # 2 bytes
key: @[6'u8, 7], # 2 bytes
topic: "abcde" # 5 bytes
,
)

var peerInfo = PeerInfoMsg(
peerId: PeerId(data: @['e'.byte]), # 1 byte
@@ -90,20 +93,20 @@ suite "Message":
)

var controlIHave = ControlIHave(
topicId: "ijk", # 3 bytes
messageIds: @[ @['l'.byte], @['m'.byte, 'n'.byte] ] # 1 + 2 = 3 bytes
topicID: "ijk", # 3 bytes
messageIDs: @[ @['l'.byte], @['m'.byte, 'n'.byte] ] # 1 + 2 = 3 bytes
)

var controlIWant = ControlIWant(
messageIds: @[ @['o'.byte, 'p'.byte], @['q'.byte] ] # 2 + 1 = 3 bytes
messageIDs: @[ @['o'.byte, 'p'.byte], @['q'.byte] ] # 2 + 1 = 3 bytes
)

var controlGraft = ControlGraft(
topicId: "rst" # 3 bytes
topicID: "rst" # 3 bytes
)

var controlPrune = ControlPrune(
topicId: "uvw", # 3 bytes
topicID: "uvw", # 3 bytes
peers: @[peerInfo, peerInfo], # (1 + 2) * 2 = 6 bytes
backoff: 12345678 # 8 bytes for uint64
)
@@ -118,10 +121,10 @@ suite "Message":

var rpcMsg = RPCMsg(
subscriptions: @[SubOpts(subscribe: true, topic: "a".repeat(12)), SubOpts(subscribe: false, topic: "b".repeat(14))], # 1 + 12 + 1 + 14 = 28 bytes
messages: @[msg, msg], # 19 * 2 = 38 bytes
messages: @[msg, msg], # 16 * 2 = 32 bytes
ping: @[1'u8, 2], # 2 bytes
pong: @[3'u8, 4], # 2 bytes
control: some(control) # 12 + 3 + 3 + 17 + 3 = 38 bytes
)

check byteSize(rpcMsg) == 28 + 38 + 2 + 2 + 38 # Total: 108 bytes
check byteSize(rpcMsg) == 28 + 32 + 2 + 2 + 38 # Total: 102 bytes

@@ -24,6 +24,8 @@ suite "TimedCache":
2 in cache
3 in cache

cache.addedAt(2) == now + 3.seconds

check:
cache.put(2, now + 7.seconds) # refreshes 2
not cache.put(4, now + 12.seconds) # expires 3
@@ -33,6 +35,23 @@ suite "TimedCache":
3 notin cache
4 in cache

check:
cache.del(4).isSome()
4 notin cache

check:
not cache.put(100, now + 100.seconds) # expires everything
100 in cache
2 notin cache

test "enough items to force cache heap storage growth":
var cache = TimedCache[int].init(5.seconds)

let now = Moment.now()
for i in 101..100000:
check:
not cache.put(i, now)

for i in 101..100000:
check:
i in cache

@@ -43,14 +43,15 @@ proc randomPeerId*(): PeerId =
raise newException(Defect, exc.msg)

func defaultMsgIdProvider*(m: Message): Result[MessageId, ValidationResult] =
let mid =
if m.seqno.len > 0 and m.fromPeer.data.len > 0:
byteutils.toHex(m.seqno) & $m.fromPeer
else:
# This part is irrelevant because it's not standard,
# We use it exclusively for testing basically and users should
# implement their own logic in the case they use anonymization
$m.data.hash & $m.topicIds.hash
let
mid =
if m.seqno.len > 0 and m.fromPeer.data.len > 0:
byteutils.toHex(m.seqno) & $m.fromPeer
else:
# This part is irrelevant because it's not standard,
# We use it exclusively for testing basically and users should
# implement their own logic in the case they use anonymization
$m.data.hash & $m.topic.hash
ok mid.toBytes()

proc generateNodes*(

@@ -26,7 +26,7 @@ import ../libp2p/[switch,
muxers/muxer,
muxers/mplex/mplex,
protocols/secure/noise,
protocols/secure/secio,
protocols/secure/plaintext,
protocols/secure/secure,
upgrademngrs/muxedupgrade,
connmanager]
@@ -53,7 +53,7 @@ method init(p: TestProto) {.gcsafe.} =
{.pop.}


proc createSwitch(ma: MultiAddress; outgoing: bool, secio: bool = false): (Switch, PeerInfo) =
proc createSwitch(ma: MultiAddress; outgoing: bool, plaintext: bool = false): (Switch, PeerInfo) =
var
privateKey = PrivateKey.random(ECDSA, rng[]).get()
peerInfo = PeerInfo.new(privateKey, @[ma])
@@ -66,8 +66,8 @@ proc createSwitch(ma: MultiAddress; outgoing: bool, secio: bool = false): (Switc
peerStore = PeerStore.new(identify)
mplexProvider = MuxerProvider.new(createMplex, MplexCodec)
muxers = @[mplexProvider]
secureManagers = if secio:
[Secure(Secio.new(rng, privateKey))]
secureManagers = if plaintext:
[Secure(PlainText.new())]
else:
[Secure(Noise.new(rng, privateKey, outgoing = outgoing))]
connManager = ConnManager.new()

@@ -315,7 +315,6 @@ suite "Circuit Relay V2":
await sleepAsync(chronos.timer.seconds(ttl + 1))

expect(DialFailedError):
check: conn.atEof()
await conn.close()
await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs)
conn = await src.dial(dst.peerInfo.peerId, @[ addrs ], customProtoCodec)

@@ -1,7 +1,7 @@
{.used.}

# Nim-Libp2p
# Copyright (c) 2023 Status Research & Development GmbH
# Copyright (c) 2023-2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
@@ -88,7 +88,6 @@ suite "Tor transport":

# every incoming connection will be handled in this closure
proc handle(conn: Connection, proto: string) {.async.} =

var resp: array[6, byte]
await conn.readExactly(addr resp, 6)
check string.fromBytes(resp) == "client"
@@ -97,7 +96,7 @@ suite "Tor transport":
# We must close the connections ourselves when we're done with it
await conn.close()

return T(codecs: @[TestCodec], handler: handle)
return T.new(codecs = @[TestCodec], handler = handle)

let rng = newRng()


@@ -9,6 +9,7 @@
# This file may not be copied, modified, or distributed except according to
# those terms.

import options
import ./helpers
import ../libp2p/utility

@@ -71,3 +72,88 @@ suite "Utility":
test "unsuccessful safeConvert from uint to int":
check not (compiles do:
result: uint = safeConvert[int, uint](11.uint))

suite "withValue and valueOr templates":
type
TestObj = ref object
x: int

proc objIncAndOpt(self: TestObj): Opt[TestObj] =
self.x.inc()
return Opt.some(self)

proc objIncAndOption(self: TestObj): Option[TestObj] =
self.x.inc()
return some(self)

test "withValue calls right branch when Opt/Option is none":
var counter = 0
# check Opt/Option withValue with else
Opt.none(TestObj).withValue(v):
fail()
else:
counter.inc()
none(TestObj).withValue(v):
fail()
else:
counter.inc()
check counter == 2

# check Opt/Option withValue without else
Opt.none(TestObj).withValue(v):
fail()
none(TestObj).withValue(v):
fail()

test "withValue calls right branch when Opt/Option is some":
var counter = 1
# check Opt/Option withValue with else
Opt.some(counter).withValue(v):
counter.inc(v)
else:
fail()
some(counter).withValue(v):
counter.inc(v)
else:
fail()

# check Opt/Option withValue without else
Opt.some(counter).withValue(v):
counter.inc(v)
some(counter).withValue(v):
counter.inc(v)
check counter == 16

test "withValue calls right branch when Opt/Option is some with proc call":
var obj = TestObj(x: 0)
# check Opt/Option withValue with else
objIncAndOpt(obj).withValue(v):
v.x.inc()
else:
fail()
objIncAndOption(obj).withValue(v):
v.x.inc()
else:
fail()

# check Opt/Option withValue without else
objIncAndOpt(obj).withValue(v):
v.x.inc()
objIncAndOption(obj).withValue(v):
v.x.inc()

check obj.x == 8

test "valueOr calls with and without proc call":
var obj = none(TestObj).valueOr:
TestObj(x: 0)
check obj.x == 0
obj = some(TestObj(x: 2)).valueOr:
fail()
return
check obj.x == 2

obj = objIncAndOpt(obj).valueOr:
fail()
return
check obj.x == 3

86
tests/testwildcardresolverservice.nim
Normal file
@@ -0,0 +1,86 @@
{.used.}

# Nim-Libp2p
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import std/[options, sequtils]
import stew/[byteutils]
import chronos, metrics
import unittest2
import ../libp2p/[builders, switch]
import ../libp2p/services/wildcardresolverservice
import ../libp2p/[multiaddress, multicodec]
import ./helpers
import ../di/di

type NetworkInterfaceProviderMock* = ref object of NetworkInterfaceProvider

method getAddresses(
networkInterfaceProvider: NetworkInterfaceProviderMock, addrFamily: AddressFamily
): seq[InterfaceAddress] {.gcsafe, raises: [].} =
echo "getAddressesMock"
try:
if addrFamily == AddressFamily.IPv4:
return
@[
InterfaceAddress.init(initTAddress("127.0.0.1:0"), 8),
InterfaceAddress.init(initTAddress("192.168.1.22:0"), 24),
]
else:
return
@[
InterfaceAddress.init(initTAddress("::1:0"), 8),
InterfaceAddress.init(initTAddress("fe80::1:0"), 64),
]
except TransportAddressError as e:
echo "Error: " & $e.msg
fail()

proc networkInterfaceProviderMock(): NetworkInterfaceProvider =
NetworkInterfaceProviderMock.new()

proc createSwitch(): Switch =
SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(
@[
MultiAddress.init("/ip4/0.0.0.0/tcp/0/").tryGet(),
MultiAddress.init("/ip6/::/tcp/0/").tryGet(),
]
)
.withTcpTransport()
.withMplex()
.withNoise()
.withBinding(networkInterfaceProviderMock)
.build()

suite "WildcardAddressResolverService":
teardown:
checkTrackers()

asyncTest "WildcardAddressResolverService must resolve wildcard addresses and stop doing so when stopped":
let switch = createSwitch()
await switch.start()
let tcpIp4 = switch.peerInfo.addrs[0][multiCodec("tcp")].get # tcp port for ip4
let tcpIp6 = switch.peerInfo.addrs[2][multiCodec("tcp")].get # tcp port for ip6

check switch.peerInfo.addrs ==
@[
MultiAddress.init("/ip4/127.0.0.1" & $tcpIp4).get,
MultiAddress.init("/ip4/192.168.1.22" & $tcpIp4).get,
MultiAddress.init("/ip6/::1" & $tcpIp6).get,
MultiAddress.init("/ip6/fe80::1" & $tcpIp6).get,
]
await switch.stop()
check switch.peerInfo.addrs ==
@[
MultiAddress.init("/ip4/0.0.0.0" & $tcpIp4).get,
MultiAddress.init("/ip6/::" & $tcpIp6).get,
]
@@ -377,3 +377,24 @@ suite "Yamux":
expect LPStreamClosedError: discard await streamA.readLp(100)
blocker.complete()
await streamA.close()

asyncTest "Peer must be able to read from stream after closing it for writing":
mSetup()

yamuxb.streamHandler = proc(conn: Connection) {.async: (raises: []).} =
try:
check (await conn.readLp(100)) == fromHex("1234")
except CancelledError, LPStreamError:
return
try:
await conn.writeLp(fromHex("5678"))
except CancelledError, LPStreamError:
return
await conn.close()

let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]

await streamA.writeLp(fromHex("1234"))
await streamA.close()
check (await streamA.readLp(100)) == fromHex("5678")

@@ -11,6 +11,6 @@ COPY . nim-libp2p/

RUN \
cd nim-libp2p && \
nim c --skipProjCfg --skipParentCfg --NimblePath:./nimbledeps/pkgs -p:nim-libp2p -d:chronicles_log_level=WARN --threads:off ./tests/transport-interop/main.nim
nim c --skipProjCfg --skipParentCfg --NimblePath:./nimbledeps/pkgs -p:nim-libp2p -d:chronicles_log_level=WARN -d:chronicles_default_output_device=stderr --threads:off ./tests/transport-interop/main.nim

ENTRYPOINT ["/app/nim-libp2p/tests/transport-interop/main"]